From 4d81bd9e0ecb6d1d3ce95f7b669b9dbae6af7250 Mon Sep 17 00:00:00 2001 From: Gleb Kazantaev Date: Thu, 10 Dec 2020 14:07:20 +0300 Subject: [PATCH] Parametrize NMS5ToLegacy conversion to avoid Convert operatoins insertion that breaks outputs naming (#3480) --- .../cnn_network_ngraph_impl.cpp | 2 +- .../convert_nms_5_to_legacy.hpp | 2 +- .../convert_nms_5_to_legacy.cpp | 9 ++++---- .../cnn_network/cnn_ngraph_impl_tests.cpp | 23 +++++++++++++++++++ 4 files changed, 30 insertions(+), 6 deletions(-) diff --git a/inference-engine/src/inference_engine/cnn_network_ngraph_impl.cpp b/inference-engine/src/inference_engine/cnn_network_ngraph_impl.cpp index 8a4bc188a0d..dd3160cd53d 100644 --- a/inference-engine/src/inference_engine/cnn_network_ngraph_impl.cpp +++ b/inference-engine/src/inference_engine/cnn_network_ngraph_impl.cpp @@ -340,7 +340,7 @@ CNNNetworkNGraphImpl::reshape(const std::map>& OV_ITT_SCOPED_TASK(itt::domains::IE, "CNNNetworkNGraphImpl::ConvertToLegacy"); ::ngraph::pass::Manager manager; // resolves dynamism by replacing dynamic operation with static version - manager.register_pass<::ngraph::pass::ConvertNMS5ToLegacyMatcher>(); + manager.register_pass<::ngraph::pass::ConvertNMS5ToLegacyMatcher>(false); manager.register_pass<::ngraph::pass::ConstantFolding>(); // OneHotToLegacy changes output precision manager.register_pass<::ngraph::pass::ConvertOneHotToOneHotIEMatcher>()->detect_output_type( diff --git a/inference-engine/src/legacy_api/include/legacy/transformations/convert_opset1_to_legacy/convert_nms_5_to_legacy.hpp b/inference-engine/src/legacy_api/include/legacy/transformations/convert_opset1_to_legacy/convert_nms_5_to_legacy.hpp index 375ec6cdc71..a352caa8b07 100644 --- a/inference-engine/src/legacy_api/include/legacy/transformations/convert_opset1_to_legacy/convert_nms_5_to_legacy.hpp +++ b/inference-engine/src/legacy_api/include/legacy/transformations/convert_opset1_to_legacy/convert_nms_5_to_legacy.hpp @@ -28,6 +28,6 @@ class INFERENCE_ENGINE_API_CLASS(ConvertNMS5ToLegacyMatcher); class ngraph::pass::ConvertNMS5ToLegacyMatcher: public ngraph::pass::MatcherPass { public: NGRAPH_RTTI_DECLARATION; - ConvertNMS5ToLegacyMatcher(); + ConvertNMS5ToLegacyMatcher(bool force_i32_output_type = true); }; diff --git a/inference-engine/src/legacy_api/src/transformations/convert_opset1_to_legacy/convert_nms_5_to_legacy.cpp b/inference-engine/src/legacy_api/src/transformations/convert_opset1_to_legacy/convert_nms_5_to_legacy.cpp index d9dad3b51f0..510832ebf02 100644 --- a/inference-engine/src/legacy_api/src/transformations/convert_opset1_to_legacy/convert_nms_5_to_legacy.cpp +++ b/inference-engine/src/legacy_api/src/transformations/convert_opset1_to_legacy/convert_nms_5_to_legacy.cpp @@ -17,10 +17,10 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertNMS5ToLegacyMatcher, "ConvertNMS5ToLegacyMatcher", 0); -ngraph::pass::ConvertNMS5ToLegacyMatcher::ConvertNMS5ToLegacyMatcher() { +ngraph::pass::ConvertNMS5ToLegacyMatcher::ConvertNMS5ToLegacyMatcher(bool force_i32_output_type) { auto nms = ngraph::pattern::wrap_type(); - ngraph::matcher_pass_callback callback = [](pattern::Matcher &m) { + ngraph::matcher_pass_callback callback = [force_i32_output_type](pattern::Matcher &m) { auto nms_5 = std::dynamic_pointer_cast(m.get_match_root()); if (!nms_5) { return false; @@ -72,6 +72,7 @@ ngraph::pass::ConvertNMS5ToLegacyMatcher::ConvertNMS5ToLegacyMatcher() { std::shared_ptr nms_legacy{nullptr}; + auto output_type = force_i32_output_type ? element::i32 : nms_5->get_output_type(); if (num_of_inputs > 5 && nms_5->soft_nms_sigma_from_input() != 0.0f) { new_soft_nms_sigma = std::make_shared(new_args.at(5), new_shape_for_soft_nms_sigma, true); new_ops.emplace_back(new_soft_nms_sigma.get_node_shared_ptr()); @@ -84,7 +85,7 @@ ngraph::pass::ConvertNMS5ToLegacyMatcher::ConvertNMS5ToLegacyMatcher() { new_soft_nms_sigma, center_point_box, nms_5->get_sort_result_descending(), - element::i32); + output_type); new_ops.push_back(nms_legacy); } else { nms_legacy = std::make_shared( @@ -95,7 +96,7 @@ ngraph::pass::ConvertNMS5ToLegacyMatcher::ConvertNMS5ToLegacyMatcher() { new_score_threshold, center_point_box, nms_5->get_sort_result_descending(), - element::i32); + output_type); new_ops.push_back(nms_legacy); } diff --git a/inference-engine/tests/functional/inference_engine/cnn_network/cnn_ngraph_impl_tests.cpp b/inference-engine/tests/functional/inference_engine/cnn_network/cnn_ngraph_impl_tests.cpp index a5789bee038..fdbdd8dbc13 100644 --- a/inference-engine/tests/functional/inference_engine/cnn_network/cnn_ngraph_impl_tests.cpp +++ b/inference-engine/tests/functional/inference_engine/cnn_network/cnn_ngraph_impl_tests.cpp @@ -23,6 +23,7 @@ #include #include +#include #include #include #include @@ -44,6 +45,28 @@ using namespace InferenceEngine; IE_SUPPRESS_DEPRECATED_START +TEST(CNNNGraphImplTests, TestNMS5OutputNames) { + std::shared_ptr f; + { + auto boxes = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 1000, 4}); + auto scores = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 1, 1000}); + auto max_output_boxes_per_class = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{}, {10}); + auto iou_threshold = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{}, {0.75}); + auto score_threshold = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{}, {0.7}); + auto nms = std::make_shared(boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, + ngraph::opset5::NonMaxSuppression::BoxEncodingType::CORNER, true); + nms->set_friendly_name("nms"); + f = std::make_shared(ngraph::OutputVector{nms->output(0), nms->output(1), nms->output(2)}, ngraph::ParameterVector{boxes, scores}); + } + + InferenceEngine::CNNNetwork cnnNet(f); + auto outputs_info = cnnNet.getOutputsInfo(); + ASSERT_EQ(outputs_info.size(), 3); + ASSERT_EQ(outputs_info.count("nms.0"), 1); + ASSERT_EQ(outputs_info.count("nms.1"), 1); + ASSERT_EQ(outputs_info.count("nms.2"), 1); +} + TEST(CNNNGraphImplTests, TestConvertWithRemoveLastLayerNetwork) { std::shared_ptr ngraph; {