Do not remove convert after the topK (#1950)

* Do not remove convert after the topK

* Added debug message

* Removed xFail

* Revert "Added debug message"

This reverts commit a01ace4ade88d73e2797b47c58db33943b0f508d.

* Added test
This commit is contained in:
Ilya Churaev 2020-09-01 14:36:30 +03:00 committed by GitHub
parent 7796b0f277
commit c9820a9588
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 63 additions and 4 deletions

View File

@ -179,7 +179,9 @@ bool ngraph::pass::ConvertPrecision::run_on_function(std::shared_ptr<ngraph::Fun
// TODO: we need to split NopElimination pass to separate MatcherPasses and call Convert elimination here
for (auto &node : f->get_ordered_ops()) {
if (auto convert = std::dynamic_pointer_cast<opset4::Convert>(node)) {
if (convert->input(0).get_element_type() == convert->get_convert_element_type()) {
// WA for topK, dont remove fake convert
if (convert->input(0).get_element_type() == convert->get_convert_element_type() &&
convert->input_value(0).get_node_shared_ptr()->get_output_size() == 1) {
replace_output_update_name(convert->output(0), convert->input_value(0));
}
}

View File

@ -7,6 +7,12 @@
#include <cpp/ie_cnn_network.h>
#include <legacy/cnn_network_impl.hpp> // deprecated API
#include <ngraph/pass/manager.hpp>
#include <transformations/common_optimizations/common_optimizations.hpp>
#include <transformations/convert_opset1_to_legacy/convert_opset1_to_legacy.hpp>
#include <transformations/convert_opset2_to_opset1/convert_opset2_to_opset1.hpp>
#include <transformations/convert_opset3_to_opset2/convert_opset3_to_opset2.hpp>
#include <transformations/convert_precision.hpp>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset4.hpp>
@ -143,3 +149,56 @@ TEST(ConvertFunctionToCNNNetworkTests, OpsShouldBeConvertedToIERepresentation) {
}
}
}
TEST(ConvertFunctionToCNNNetworkTests, ConvertTopKWithOneInput) {
std::shared_ptr<ngraph::Function> f;
{
auto param = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 22, 22});
ngraph::Shape const_shape = {};
std::vector<int64_t> val = {5};
auto k = std::make_shared<ngraph::opset4::Constant>(ngraph::element::i64, const_shape, val);
auto topK = std::make_shared<ngraph::opset4::TopK>(param, k, 2, ngraph::opset4::TopK::Mode::MAX, ngraph::opset4::TopK::SortType::SORT_VALUES);
topK->set_friendly_name("topK");
auto result = std::make_shared<ngraph::op::Result>(topK->output(1));
f = std::make_shared<ngraph::Function>(ngraph::ResultVector{result},
ngraph::ParameterVector{param});
ngraph::pass::InitNodeInfo().run_on_function(f);
}
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::CommonOptimizations>();
manager.register_pass<ngraph::pass::ConvertOpSet3ToOpSet2>();
manager.register_pass<ngraph::pass::ConvertOpSet2ToOpSet1>();
std::vector<std::pair<ngraph::element::Type, ngraph::element::Type>> convert_precision_list {
{ngraph::element::i64, ngraph::element::i32},
{ngraph::element::u64, ngraph::element::i32},
{ngraph::element::u16, ngraph::element::i32},
{ngraph::element::u32, ngraph::element::i32},
{ngraph::element::f16, ngraph::element::f32},
{ngraph::element::boolean, ngraph::element::u8},
};
for (auto & precision : convert_precision_list) {
manager.register_pass<ngraph::pass::ConvertPrecision>(precision.first, precision.second);
}
manager.register_pass<ngraph::pass::ConvertOpSet1ToLegacy>();
manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i64, ngraph::element::i32);
manager.run_passes(f);
InferenceEngine::CNNNetwork nGraphImpl(f);
nGraphImpl = CNNNetwork(InferenceEngine::details::convertFunctionToICNNNetwork(f, nGraphImpl));
try {
OutputsDataMap outputs = nGraphImpl.getOutputsInfo();
ASSERT_EQ(outputs.size(), 1);
ASSERT_EQ(outputs.begin()->first, "topK.1");
} catch (InferenceEngine::details::InferenceEngineException &err) {
const std::string ref_msg = "Error of validate layer: prelu with type: PReLU. Number of inputs (2) is not equal to expected ones: 1";
const std::string resp_msg = err.what();
ASSERT_TRUE(resp_msg.find(ref_msg) != std::string::npos) << resp_msg;
}
}

View File

@ -18,7 +18,7 @@ import onnx
import pytest
from tests.test_onnx.utils import run_node
from tests import xfail_issue_35925, xfail_issue_36437
from tests import xfail_issue_35925
reduce_data = np.array([[[5, 1], [20, 2]], [[30, 1], [40, 2]], [[55, 1], [60, 2]]], dtype=np.float32)
reduce_axis_parameters = [
@ -268,7 +268,6 @@ def test_reduce_sum_square_default_axes():
assert np.allclose(expected, ng_result)
@xfail_issue_36437
def test_reduce_argmin():
def argmin(ndarray, axis, keepdims=False):
res = np.argmin(ndarray, axis=axis)
@ -292,7 +291,6 @@ def test_reduce_argmin():
)
@xfail_issue_36437
def test_reduce_argmax():
def argmax(ndarray, axis, keepdims=False):
res = np.argmax(ndarray, axis=axis)