Do not remove convert after the topK (#1950)
* Do not remove convert after the topK

* Added debug message

* Removed xFail

* Revert "Added debug message"

This reverts commit a01ace4ade88d73e2797b47c58db33943b0f508d.

* Added test
commit c9820a9588
parent 7796b0f277
@@ -179,7 +179,9 @@ bool ngraph::pass::ConvertPrecision::run_on_function(std::shared_ptr<ngraph::Fun
     // TODO: we need to split NopElimination pass to separate MatcherPasses and call Convert elimination here
     for (auto &node : f->get_ordered_ops()) {
         if (auto convert = std::dynamic_pointer_cast<opset4::Convert>(node)) {
-            if (convert->input(0).get_element_type() == convert->get_convert_element_type()) {
+            // WA for topK, dont remove fake convert
+            if (convert->input(0).get_element_type() == convert->get_convert_element_type() &&
+                convert->input_value(0).get_node_shared_ptr()->get_output_size() == 1) {
                 replace_output_update_name(convert->output(0), convert->input_value(0));
             }
         }
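The extra condition above makes the Convert-elimination step in ConvertPrecision skip no-op Convert nodes whose producer has more than one output, so the "fake" Convert that follows TopK is kept and the output name checked by the new test ("topK.1") survives conversion. A minimal sketch of that rule as a standalone helper (not part of the commit, assuming the ngraph headers already used in this diff; the name try_remove_noop_convert is illustrative, not OpenVINO API):

// Sketch only: mirrors the elimination rule introduced by this change.
#include <memory>
#include <ngraph/graph_util.hpp>
#include <ngraph/opsets/opset4.hpp>

// Returns true if the no-op Convert was folded away. Converts whose producer
// has several outputs (e.g. TopK) are intentionally left in the graph.
bool try_remove_noop_convert(const std::shared_ptr<ngraph::opset4::Convert>& convert) {
    // A Convert is a no-op when its destination type equals the type it already receives.
    const bool is_noop =
        convert->input(0).get_element_type() == convert->get_convert_element_type();
    // Only fold it when the producing node has exactly one output.
    const bool producer_has_single_output =
        convert->input_value(0).get_node_shared_ptr()->get_output_size() == 1;
    if (is_noop && producer_has_single_output) {
        ngraph::replace_output_update_name(convert->output(0), convert->input_value(0));
        return true;
    }
    return false;
}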
@@ -7,6 +7,12 @@
 #include <cpp/ie_cnn_network.h>
 #include <legacy/cnn_network_impl.hpp>  // deprecated API
 
+#include <ngraph/pass/manager.hpp>
+#include <transformations/common_optimizations/common_optimizations.hpp>
+#include <transformations/convert_opset1_to_legacy/convert_opset1_to_legacy.hpp>
+#include <transformations/convert_opset2_to_opset1/convert_opset2_to_opset1.hpp>
+#include <transformations/convert_opset3_to_opset2/convert_opset3_to_opset2.hpp>
+#include <transformations/convert_precision.hpp>
 #include <ngraph/function.hpp>
 #include <ngraph/opsets/opset1.hpp>
 #include <ngraph/opsets/opset4.hpp>
@@ -143,3 +149,56 @@ TEST(ConvertFunctionToCNNNetworkTests, OpsShouldBeConvertedToIERepresentation) {
         }
     }
 }
+
+TEST(ConvertFunctionToCNNNetworkTests, ConvertTopKWithOneInput) {
+    std::shared_ptr<ngraph::Function> f;
+    {
+        auto param = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 22, 22});
+        ngraph::Shape const_shape = {};
+        std::vector<int64_t> val = {5};
+        auto k = std::make_shared<ngraph::opset4::Constant>(ngraph::element::i64, const_shape, val);
+        auto topK = std::make_shared<ngraph::opset4::TopK>(param, k, 2, ngraph::opset4::TopK::Mode::MAX, ngraph::opset4::TopK::SortType::SORT_VALUES);
+        topK->set_friendly_name("topK");
+        auto result = std::make_shared<ngraph::op::Result>(topK->output(1));
+
+        f = std::make_shared<ngraph::Function>(ngraph::ResultVector{result},
+                                               ngraph::ParameterVector{param});
+        ngraph::pass::InitNodeInfo().run_on_function(f);
+    }
+
+    ngraph::pass::Manager manager;
+    manager.register_pass<ngraph::pass::CommonOptimizations>();
+    manager.register_pass<ngraph::pass::ConvertOpSet3ToOpSet2>();
+    manager.register_pass<ngraph::pass::ConvertOpSet2ToOpSet1>();
+
+    std::vector<std::pair<ngraph::element::Type, ngraph::element::Type>> convert_precision_list {
+            {ngraph::element::i64, ngraph::element::i32},
+            {ngraph::element::u64, ngraph::element::i32},
+            {ngraph::element::u16, ngraph::element::i32},
+            {ngraph::element::u32, ngraph::element::i32},
+            {ngraph::element::f16, ngraph::element::f32},
+            {ngraph::element::boolean, ngraph::element::u8},
+    };
+
+    for (auto & precision : convert_precision_list) {
+        manager.register_pass<ngraph::pass::ConvertPrecision>(precision.first, precision.second);
+    }
+
+    manager.register_pass<ngraph::pass::ConvertOpSet1ToLegacy>();
+    manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i64, ngraph::element::i32);
+
+    manager.run_passes(f);
+
+    InferenceEngine::CNNNetwork nGraphImpl(f);
+    nGraphImpl = CNNNetwork(InferenceEngine::details::convertFunctionToICNNNetwork(f, nGraphImpl));
+
+    try {
+        OutputsDataMap outputs = nGraphImpl.getOutputsInfo();
+        ASSERT_EQ(outputs.size(), 1);
+        ASSERT_EQ(outputs.begin()->first, "topK.1");
+    } catch (InferenceEngine::details::InferenceEngineException &err) {
+        const std::string ref_msg = "Error of validate layer: prelu with type: PReLU. Number of inputs (2) is not equal to expected ones: 1";
+        const std::string resp_msg = err.what();
+        ASSERT_TRUE(resp_msg.find(ref_msg) != std::string::npos) << resp_msg;
+    }
+}
@@ -18,7 +18,7 @@ import onnx
 import pytest
 
 from tests.test_onnx.utils import run_node
-from tests import xfail_issue_35925, xfail_issue_36437
+from tests import xfail_issue_35925
 
 reduce_data = np.array([[[5, 1], [20, 2]], [[30, 1], [40, 2]], [[55, 1], [60, 2]]], dtype=np.float32)
 reduce_axis_parameters = [
@@ -268,7 +268,6 @@ def test_reduce_sum_square_default_axes():
     assert np.allclose(expected, ng_result)
 
 
-@xfail_issue_36437
 def test_reduce_argmin():
     def argmin(ndarray, axis, keepdims=False):
         res = np.argmin(ndarray, axis=axis)
@@ -292,7 +291,6 @@ def test_reduce_argmin():
     )
 
 
-@xfail_issue_36437
 def test_reduce_argmax():
     def argmax(ndarray, axis, keepdims=False):
        res = np.argmax(ndarray, axis=axis)