From 27efd1fc7e99017764ae4b7e9cccfb3366b62a45 Mon Sep 17 00:00:00 2001 From: Gleb Kazantaev Date: Tue, 24 Nov 2020 12:38:56 +0300 Subject: [PATCH] Resolve dynamism in Serialize transformation (#3260) * Added Interval shape inference for NMS5 * Added dynamic shape resolver for Serialize pass * Refactored Serializatoin tests; Added tests with dynamic shapes(NMS5, ShapeOf) * Fixed python NMS5 test * Fixed LowLatecnyTests --- .../include/transformations/serialize.hpp | 1 + .../src/transformations/serialize.cpp | 74 +++++- .../ir_serialization/models/nms5.bin | Bin 0 -> 16 bytes .../ir_serialization/models/nms5.xml | 91 ++++++++ .../ir_serialization/models/shape_of.xml | 84 +++++++ .../ir_serialization/serialize.cpp | 217 ++++-------------- .../transformations/low_latency_test.cpp | 8 +- .../common_test_utils/ngraph_test_utils.cpp | 8 +- ngraph/core/src/op/non_max_suppression.cpp | 17 ++ .../tests/test_ngraph/test_reduction.py | 9 +- ngraph/test/type_prop/non_max_suppression.cpp | 26 +-- 11 files changed, 328 insertions(+), 207 deletions(-) create mode 100644 inference-engine/tests/functional/inference_engine/ir_serialization/models/nms5.bin create mode 100644 inference-engine/tests/functional/inference_engine/ir_serialization/models/nms5.xml create mode 100644 inference-engine/tests/functional/inference_engine/ir_serialization/models/shape_of.xml diff --git a/inference-engine/src/transformations/include/transformations/serialize.hpp b/inference-engine/src/transformations/include/transformations/serialize.hpp index 7a26024695b..e400e975399 100644 --- a/inference-engine/src/transformations/include/transformations/serialize.hpp +++ b/inference-engine/src/transformations/include/transformations/serialize.hpp @@ -5,6 +5,7 @@ #include +#include "ngraph/opsets/opset.hpp" #include "ngraph/pass/pass.hpp" #include "transformations_visibility.hpp" diff --git a/inference-engine/src/transformations/src/transformations/serialize.cpp b/inference-engine/src/transformations/src/transformations/serialize.cpp index 6d52fd9b349..fa1f473ddab 100644 --- a/inference-engine/src/transformations/src/transformations/serialize.cpp +++ b/inference-engine/src/transformations/src/transformations/serialize.cpp @@ -300,9 +300,73 @@ bool is_exec_graph(const ngraph::Function& f) { return false; } +bool resolve_dynamic_shapes(const ngraph::Function& f) { + const auto & f_results = f.get_results(); + if (std::all_of(f_results.begin(), f_results.end(), + [](std::shared_ptr results) { return !results->is_dynamic(); })) { + return false; + } + + auto f_clone = ngraph::clone_function(f); + + const auto & f_ops = f.get_ordered_ops(); + const auto & f_clone_ops = f_clone->get_ordered_ops(); + NGRAPH_CHECK(f_ops.size() == f_clone_ops.size(), "Unexpected get_ordered_ops method behaviour"); + + for (size_t id = 0; id < f_ops.size(); ++id) { + auto & op = f_ops[id]; + auto & clone_op = f_clone_ops[id]; + + if (auto op_subgraph = std::dynamic_pointer_cast(op)) { + resolve_dynamic_shapes(*op_subgraph->get_function()); + } + + op->validate_and_infer_types(); + clone_op->validate_and_infer_types(); + + // dynamic_to_static function converts dynamic dimensions to static using + // upperbound (get_max_length) dimension value. 
+ auto dynamic_to_static = [](const PartialShape & shape) -> PartialShape { + if (shape.is_static() || shape.rank().is_dynamic()) { + return shape; + } + auto out_shape = PartialShape::dynamic(shape.rank()); + for (size_t i = 0; i < shape.rank().get_length(); ++i) { + const auto & in_dim = shape[i]; + out_shape[i] = (in_dim.is_dynamic() ? Dimension(in_dim.get_max_length()) : in_dim); + } + return out_shape; + }; + + OutputVector replacements(clone_op->get_output_size()); + if (!clone_op->constant_fold(replacements, clone_op->input_values())) { + for (size_t output_id = 0; output_id < clone_op->get_output_size(); ++output_id) { + clone_op->set_output_type(output_id, clone_op->output(output_id).get_element_type(), + dynamic_to_static(clone_op->output(output_id).get_partial_shape())); + op->set_output_type(output_id, clone_op->output(output_id).get_element_type(), + clone_op->output(output_id).get_partial_shape()); + } + } else { + for (size_t output_id = 0; output_id < clone_op->get_output_size(); ++output_id) { + op->set_output_type(output_id, replacements[output_id].get_element_type(), + replacements[output_id].get_partial_shape()); + } + + for (size_t i = 0; i < replacements.size(); ++i) { + auto node_output = clone_op->output(i); + auto replacement = replacements.at(i); + if (replacement.get_node_shared_ptr() && (node_output != replacement)) { + node_output.replace(replacement); + } + } + } + } + return true; +} + void ngfunction_2_irv10( pugi::xml_document& doc, std::vector& bin, - const ngraph::Function& f, + ngraph::Function& f, const std::map& custom_opsets) { const bool exec_graph = is_exec_graph(f); @@ -315,6 +379,8 @@ void ngfunction_2_irv10( create_layer_ids(f); std::unordered_set unique_names; + bool has_dynamic_shapes = resolve_dynamic_shapes(f); + for (const auto& n : f.get_ordered_ops()) { ngraph::Node* node = n.get(); @@ -332,7 +398,7 @@ void ngfunction_2_irv10( // pugi::xml_node data = layer.append_child("data"); - // general atributes + // general attributes std::string node_type_name{node->get_type_name()}; if (exec_graph) { visit_exec_graph_node(data, node_type_name, node); @@ -403,6 +469,10 @@ void ngfunction_2_irv10( edge.append_attribute("to-layer").set_value(e.to_layer); edge.append_attribute("to-port").set_value(e.to_port); } + // move back dynamic shapes + if (has_dynamic_shapes) { + f.validate_nodes_and_infer_types(); + } } } // namespace diff --git a/inference-engine/tests/functional/inference_engine/ir_serialization/models/nms5.bin b/inference-engine/tests/functional/inference_engine/ir_serialization/models/nms5.bin new file mode 100644 index 0000000000000000000000000000000000000000..59fda1f234d6484bb298e831a207f177dd0613cd GIT binary patch literal 16 Qcmd;LfB^@4V`F1`00ngcVgLXD literal 0 HcmV?d00001 diff --git a/inference-engine/tests/functional/inference_engine/ir_serialization/models/nms5.xml b/inference-engine/tests/functional/inference_engine/ir_serialization/models/nms5.xml new file mode 100644 index 00000000000..aefaf59022e --- /dev/null +++ b/inference-engine/tests/functional/inference_engine/ir_serialization/models/nms5.xml @@ -0,0 +1,91 @@ + + + + + + + + 1 + 1 + 1000 + + + + + + + + 1 + 1000 + 4 + + + + + + + + + + + + + + + + + + + + + + + + + + 1 + 1000 + 4 + + + 1 + 1 + 1000 + + + + + + + + 10 + 3 + + + 10 + 3 + + + 1 + + + + + + + + 10 + 3 + + + + + + + + + + + + + diff --git a/inference-engine/tests/functional/inference_engine/ir_serialization/models/shape_of.xml 
b/inference-engine/tests/functional/inference_engine/ir_serialization/models/shape_of.xml new file mode 100644 index 00000000000..747f5fdebfa --- /dev/null +++ b/inference-engine/tests/functional/inference_engine/ir_serialization/models/shape_of.xml @@ -0,0 +1,84 @@ + + + + + + + + 1 + 2 + 3 + + + + + + + + 1 + 2 + 3 + + + + + 1 + 2 + 3 + + + + + + + + 1 + 2 + 3 + + + + + 3 + + + + + + + + 1 + 2 + 3 + + + 3 + + + + + 1 + 2 + 3 + + + + + + + + 1 + 2 + 3 + + + + + + + + + + + + diff --git a/inference-engine/tests/functional/inference_engine/ir_serialization/serialize.cpp b/inference-engine/tests/functional/inference_engine/ir_serialization/serialize.cpp index a901309ddb7..7357e7cb713 100644 --- a/inference-engine/tests/functional/inference_engine/ir_serialization/serialize.cpp +++ b/inference-engine/tests/functional/inference_engine/ir_serialization/serialize.cpp @@ -12,12 +12,31 @@ #define IR_SERIALIZATION_MODELS_PATH "" #endif -class SerializationTest : public ::testing::Test { -protected: - std::string test_name = - ::testing::UnitTest::GetInstance()->current_test_info()->name(); - std::string m_out_xml_path = test_name + ".xml"; - std::string m_out_bin_path = test_name + ".bin"; +typedef std::tuple SerializationParams; + +class SerializationTest: public CommonTestUtils::TestsCommon, + public testing::WithParamInterface { +public: + std::string m_out_xml_path; + std::string m_out_bin_path; + + void SetUp() override { + const auto & model_path = IR_SERIALIZATION_MODELS_PATH + std::get<0>(GetParam()); + + const std::string test_name = "test"; // ::testing::UnitTest::GetInstance()->current_test_info()->name(); + m_out_xml_path = test_name + ".xml"; + m_out_bin_path = test_name + ".bin"; + + InferenceEngine::Core ie; + auto expected = ie.ReadNetwork(model_path); + expected.serialize(m_out_xml_path, m_out_bin_path); + auto result = ie.ReadNetwork(m_out_xml_path, m_out_bin_path); + + bool success; + std::string message; + std::tie(success, message) = compare_functions(result.getFunction(), expected.getFunction()); + ASSERT_TRUE(success) << message; + } void TearDown() override { std::remove(m_out_xml_path.c_str()); @@ -25,177 +44,21 @@ protected: } }; -TEST_F(SerializationTest, BasicModel_MO) { - const std::string model = IR_SERIALIZATION_MODELS_PATH "add_abc.xml"; - const std::string weights = IR_SERIALIZATION_MODELS_PATH "add_abc.bin"; - - InferenceEngine::Core ie; - auto expected = ie.ReadNetwork(model, weights); - expected.serialize(m_out_xml_path, m_out_bin_path); - auto result = ie.ReadNetwork(m_out_xml_path, m_out_bin_path); - - bool success; - std::string message; - std::tie(success, message) = - compare_functions(result.getFunction(), expected.getFunction()); - - ASSERT_TRUE(success) << message; +TEST_P(SerializationTest, CompareFunctions) { } -TEST_F(SerializationTest, BasicModel_ONNXImporter) { - const std::string model = IR_SERIALIZATION_MODELS_PATH "add_abc.prototxt"; +INSTANTIATE_TEST_CASE_P(IRSerialization, SerializationTest, + testing::Values(std::make_tuple("add_abc.xml"), + std::make_tuple("split_equal_parts_2d.xml"), + std::make_tuple("addmul_abc.xml"), + std::make_tuple("add_abc_initializers.xml"), + std::make_tuple("experimental_detectron_roi_feature_extractor.xml"), + std::make_tuple("experimental_detectron_detection_output.xml"), + std::make_tuple("nms5.xml"), + std::make_tuple("shape_of.xml"))); - InferenceEngine::Core ie; - auto expected = ie.ReadNetwork(model); - expected.serialize(m_out_xml_path, m_out_bin_path); - auto result = ie.ReadNetwork(m_out_xml_path, 
m_out_bin_path); - - bool success; - std::string message; - std::tie(success, message) = - compare_functions(result.getFunction(), expected.getFunction()); - - ASSERT_TRUE(success) << message; -} - -TEST_F(SerializationTest, ModelWithMultipleOutputs_MO) { - const std::string model = - IR_SERIALIZATION_MODELS_PATH "split_equal_parts_2d.xml"; - const std::string weights = - IR_SERIALIZATION_MODELS_PATH "split_equal_parts_2d.bin"; - - InferenceEngine::Core ie; - auto expected = ie.ReadNetwork(model, weights); - expected.serialize(m_out_xml_path, m_out_bin_path); - auto result = ie.ReadNetwork(m_out_xml_path, m_out_bin_path); - - // Compare function does not support models with multiple outputs - bool success; - std::string message; - std::tie(success, message) = - compare_functions(result.getFunction(), expected.getFunction()); - - ASSERT_FALSE(success) << message; -} - -TEST_F(SerializationTest, ModelWithMultipleOutputs_ONNXImporter) { - const std::string model = - IR_SERIALIZATION_MODELS_PATH "split_equal_parts_2d.prototxt"; - - InferenceEngine::Core ie; - auto expected = ie.ReadNetwork(model); - expected.serialize(m_out_xml_path, m_out_bin_path); - auto result = ie.ReadNetwork(m_out_xml_path, m_out_bin_path); - - // Compare function does not support models with multiple outputs - bool success; - std::string message; - std::tie(success, message) = - compare_functions(result.getFunction(), expected.getFunction()); - - ASSERT_FALSE(success) << message; -} - -TEST_F(SerializationTest, ModelWithMultipleLayers_MO) { - const std::string model = IR_SERIALIZATION_MODELS_PATH "addmul_abc.xml"; - const std::string weights = IR_SERIALIZATION_MODELS_PATH "addmul_abc.bin"; - - InferenceEngine::Core ie; - auto expected = ie.ReadNetwork(model, weights); - expected.serialize(m_out_xml_path, m_out_bin_path); - auto result = ie.ReadNetwork(m_out_xml_path, m_out_bin_path); - - bool success; - std::string message; - std::tie(success, message) = - compare_functions(result.getFunction(), expected.getFunction()); - - ASSERT_TRUE(success) << message; -} - -TEST_F(SerializationTest, ModelWithMultipleLayers_ONNXImporter) { - const std::string model = - IR_SERIALIZATION_MODELS_PATH "addmul_abc.prototxt"; - - InferenceEngine::Core ie; - auto expected = ie.ReadNetwork(model); - expected.serialize(m_out_xml_path, m_out_bin_path); - auto result = ie.ReadNetwork(m_out_xml_path, m_out_bin_path); - - bool success; - std::string message; - std::tie(success, message) = - compare_functions(result.getFunction(), expected.getFunction()); - - ASSERT_TRUE(success) << message; -} - -TEST_F(SerializationTest, ModelWithConstants_MO) { - const std::string model = - IR_SERIALIZATION_MODELS_PATH "add_abc_initializers.xml"; - const std::string weights = - IR_SERIALIZATION_MODELS_PATH "add_abc_initializers.bin"; - - InferenceEngine::Core ie; - auto expected = ie.ReadNetwork(model, weights); - expected.serialize(m_out_xml_path, m_out_bin_path); - auto result = ie.ReadNetwork(m_out_xml_path, m_out_bin_path); - - bool success; - std::string message; - std::tie(success, message) = - compare_functions(result.getFunction(), expected.getFunction()); - - ASSERT_TRUE(success) << message; -} - -TEST_F(SerializationTest, ModelWithConstants_ONNXImporter) { - const std::string model = - IR_SERIALIZATION_MODELS_PATH "add_abc_initializers.prototxt"; - - InferenceEngine::Core ie; - auto expected = ie.ReadNetwork(model); - expected.serialize(m_out_xml_path, m_out_bin_path); - auto result = ie.ReadNetwork(m_out_xml_path, m_out_bin_path); - - bool success; 
- std::string message; - std::tie(success, message) = - compare_functions(result.getFunction(), expected.getFunction()); - - ASSERT_TRUE(success) << message; -} - -TEST_F(SerializationTest, ExperimentalDetectronROIFeatureExtractor_MO) { - const std::string model = IR_SERIALIZATION_MODELS_PATH - "experimental_detectron_roi_feature_extractor.xml"; - - InferenceEngine::Core ie; - auto expected = ie.ReadNetwork(model); - expected.serialize(m_out_xml_path, m_out_bin_path); - auto result = ie.ReadNetwork(m_out_xml_path, m_out_bin_path); - - bool success; - std::string message; - std::tie(success, message) = - compare_functions(result.getFunction(), expected.getFunction()); - - ASSERT_TRUE(success) << message; -} - -TEST_F(SerializationTest, ExperimentalDetectronDetectionOutput_MO) { - const std::string model = IR_SERIALIZATION_MODELS_PATH - "experimental_detectron_detection_output.xml"; - - InferenceEngine::Core ie; - auto expected = ie.ReadNetwork(model); - expected.serialize(m_out_xml_path, m_out_bin_path); - auto result = ie.ReadNetwork(m_out_xml_path, m_out_bin_path); - - bool success; - std::string message; - std::tie(success, message) = - compare_functions(result.getFunction(), expected.getFunction()); - - ASSERT_TRUE(success) << message; -} +INSTANTIATE_TEST_CASE_P(ONNXSerialization, SerializationTest, + testing::Values(std::make_tuple("add_abc.prototxt"), + std::make_tuple("split_equal_parts_2d.prototxt"), + std::make_tuple("addmul_abc.prototxt"), + std::make_tuple("add_abc_initializers.prototxt"))); diff --git a/inference-engine/tests/functional/inference_engine/transformations/low_latency_test.cpp b/inference-engine/tests/functional/inference_engine/transformations/low_latency_test.cpp index dc1db93da6f..3d15d463367 100644 --- a/inference-engine/tests/functional/inference_engine/transformations/low_latency_test.cpp +++ b/inference-engine/tests/functional/inference_engine/transformations/low_latency_test.cpp @@ -95,10 +95,10 @@ TEST(TransformationTests, LowLatencyLSTM) { auto lstm_cell = std::make_shared(squeeze, read_value_H, read_value_C, W, R, B, 128); auto assign_H = std::make_shared(lstm_cell->output(0), variable_name_H); auto assign_C = std::make_shared(lstm_cell->output(1), variable_name_C); - auto res_1 = std::make_shared(lstm_cell->output(0)); auto unsqueeze = std::make_shared(lstm_cell->output(0), axis); auto res_2 = std::make_shared(unsqueeze); - f_ref = std::make_shared(OutputVector{unsqueeze, res_1}, ParameterVector{Xi, H_t, C_t}); + auto res_1 = std::make_shared(lstm_cell->output(0)); + f_ref = std::make_shared(OutputVector{res_1, res_2}, ParameterVector{Xi, H_t, C_t}); f_ref->add_sinks({assign_C, assign_H}); assign_H->add_control_dependency(read_value_H); assign_C->add_control_dependency(read_value_C); @@ -340,10 +340,10 @@ TEST(TransformationTests, LowLatencyLSTMReshape) { auto lstm_cell = std::make_shared(squeeze, read_value_H, read_value_C, W, R, B, 128); auto assign_H = std::make_shared(lstm_cell->output(0), variable_name_H); auto assign_C = std::make_shared(lstm_cell->output(1), variable_name_C); - auto res_1 = std::make_shared(lstm_cell->output(0)); auto unsqueeze = std::make_shared(lstm_cell->output(0), axis); auto res_2 = std::make_shared(unsqueeze); - f_ref = std::make_shared(OutputVector{unsqueeze, res_1}, ParameterVector{Xi, H_t, C_t}); + auto res_1 = std::make_shared(lstm_cell->output(0)); + f_ref = std::make_shared(OutputVector{res_1, res_2}, ParameterVector{Xi, H_t, C_t}); f_ref->add_sinks({assign_C, assign_H}); 
assign_H->add_control_dependency(read_value_H); assign_C->add_control_dependency(read_value_C); diff --git a/inference-engine/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp b/inference-engine/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp index 6e7dc4d2f96..3794307e0c8 100644 --- a/inference-engine/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp +++ b/inference-engine/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp @@ -84,8 +84,12 @@ std::pair compare_functions( * - Do not check nodes attributes (requires visitor mechanism to be completed) */ - const auto& f1_results = f1->get_results(); - const auto& f2_results = f2->get_results(); + auto f1_results = f1->get_results(); + auto f2_results = f2->get_results(); + auto compare_nodes_by_name = [](const std::shared_ptr & l, const std::shared_ptr & r) + { return l->get_friendly_name() < r->get_friendly_name(); }; + std::sort(f1_results.begin(), f1_results.end(), compare_nodes_by_name); + std::sort(f2_results.begin(), f2_results.end(), compare_nodes_by_name); if (f1_results.size() != f2_results.size()) { return { false, "Number of results is different: " + std::to_string(f1_results.size()) + " and " + std::to_string(f2_results.size()) }; } diff --git a/ngraph/core/src/op/non_max_suppression.cpp b/ngraph/core/src/op/non_max_suppression.cpp index 3b5e92be0fa..d5e715b6865 100644 --- a/ngraph/core/src/op/non_max_suppression.cpp +++ b/ngraph/core/src/op/non_max_suppression.cpp @@ -902,6 +902,23 @@ void op::v5::NonMaxSuppression::validate_and_infer_types() validate(); + if (boxes_ps.rank().is_static() && scores_ps.rank().is_static() && get_input_size() > 2) + { + const auto num_boxes_boxes = boxes_ps[1]; + const auto max_output_boxes_per_class_node = input_value(2).get_node_shared_ptr(); + if (num_boxes_boxes.is_static() && scores_ps[0].is_static() && scores_ps[1].is_static() && + op::is_constant(max_output_boxes_per_class_node)) + { + const auto num_boxes = num_boxes_boxes.get_length(); + const auto num_classes = scores_ps[1].get_length(); + const auto max_output_boxes_per_class = max_boxes_output_from_input(); + + out_shape[0] = Dimension(0, + std::min(num_boxes, max_output_boxes_per_class) * num_classes * + scores_ps[0].get_length()); + } + } + set_output_type(0, m_output_type, out_shape); set_output_type(1, element::f32, out_shape); set_output_type(2, m_output_type, Shape{1}); diff --git a/ngraph/python/tests/test_ngraph/test_reduction.py b/ngraph/python/tests/test_ngraph/test_reduction.py index 26739b92d98..30be0337f5e 100644 --- a/ngraph/python/tests/test_ngraph/test_reduction.py +++ b/ngraph/python/tests/test_ngraph/test_reduction.py @@ -15,9 +15,10 @@ # ****************************************************************************** import numpy as np import pytest -from _pyngraph import PartialShape +from _pyngraph import PartialShape, Dimension import ngraph as ng +from ngraph.utils.types import make_constant_node from tests.runtime import get_runtime from tests.test_ngraph.util import run_op_node from tests import xfail_issue_40957 @@ -108,12 +109,12 @@ def test_non_max_suppression(): boxes_parameter = ng.parameter(boxes_shape, name="Boxes", dtype=np.float32) scores_parameter = ng.parameter(scores_shape, name="Scores", dtype=np.float32) - node = ng.non_max_suppression(boxes_parameter, scores_parameter) + node = ng.non_max_suppression(boxes_parameter, scores_parameter, make_constant_node(1000, np.int64)) assert node.get_type_name() == "NonMaxSuppression" assert node.get_output_size() == 3 - 
assert node.get_output_partial_shape(0).same_scheme(PartialShape([-1, 3])) - assert node.get_output_partial_shape(1).same_scheme(PartialShape([-1, 3])) + assert node.get_output_partial_shape(0) == PartialShape([Dimension(0, 1000), Dimension(3)]) + assert node.get_output_partial_shape(1) == PartialShape([Dimension(0, 1000), Dimension(3)]) assert list(node.get_output_shape(2)) == [1] diff --git a/ngraph/test/type_prop/non_max_suppression.cpp b/ngraph/test/type_prop/non_max_suppression.cpp index ab70f7cb457..8202486b25d 100644 --- a/ngraph/test/type_prop/non_max_suppression.cpp +++ b/ngraph/test/type_prop/non_max_suppression.cpp @@ -691,11 +691,9 @@ TEST(type_prop, nms_v5_output_shape_2) ASSERT_EQ(nms->get_output_element_type(0), element::i64); ASSERT_EQ(nms->get_output_element_type(1), element::f32); ASSERT_EQ(nms->get_output_element_type(2), element::i64); - ASSERT_TRUE( - nms->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic(), 3})); - ASSERT_TRUE( - nms->get_output_partial_shape(1).same_scheme(PartialShape{Dimension::dynamic(), 3})); + EXPECT_EQ(nms->get_output_partial_shape(0), PartialShape({Dimension(0, 30), Dimension(3)})); + EXPECT_EQ(nms->get_output_partial_shape(1), PartialShape({Dimension(0, 30), Dimension(3)})); EXPECT_EQ(nms->get_output_shape(2), (Shape{1})); } @@ -713,11 +711,8 @@ TEST(type_prop, nms_v5_output_shape_3) ASSERT_EQ(nms->get_output_element_type(0), element::i64); ASSERT_EQ(nms->get_output_element_type(1), element::f32); ASSERT_EQ(nms->get_output_element_type(2), element::i64); - ASSERT_TRUE( - nms->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic(), 3})); - ASSERT_TRUE( - nms->get_output_partial_shape(1).same_scheme(PartialShape{Dimension::dynamic(), 3})); - + EXPECT_EQ(nms->get_output_partial_shape(0), PartialShape({Dimension(0, 70), Dimension(3)})); + EXPECT_EQ(nms->get_output_partial_shape(1), PartialShape({Dimension(0, 70), Dimension(3)})); EXPECT_EQ(nms->get_output_shape(2), (Shape{1})); } @@ -742,11 +737,9 @@ TEST(type_prop, nms_v5_output_shape_i32) ASSERT_EQ(nms->get_output_element_type(0), element::i32); ASSERT_EQ(nms->get_output_element_type(1), element::f32); ASSERT_EQ(nms->get_output_element_type(2), element::i32); - ASSERT_TRUE( - nms->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic(), 3})); - ASSERT_TRUE( - nms->get_output_partial_shape(1).same_scheme(PartialShape{Dimension::dynamic(), 3})); + EXPECT_EQ(nms->get_output_partial_shape(0), PartialShape({Dimension(0, 30), Dimension(3)})); + EXPECT_EQ(nms->get_output_partial_shape(1), PartialShape({Dimension(0, 30), Dimension(3)})); EXPECT_EQ(nms->get_output_shape(2), (Shape{1})); } @@ -764,10 +757,7 @@ TEST(type_prop, nms_v5_dynamic_boxes_and_scores) ASSERT_EQ(nms->get_output_element_type(0), element::i64); ASSERT_EQ(nms->get_output_element_type(1), element::f32); ASSERT_EQ(nms->get_output_element_type(2), element::i64); - ASSERT_TRUE( - nms->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic(), 3})); - ASSERT_TRUE( - nms->get_output_partial_shape(1).same_scheme(PartialShape{Dimension::dynamic(), 3})); - + EXPECT_EQ(nms->get_output_partial_shape(0), PartialShape({Dimension::dynamic(), 3})); + EXPECT_EQ(nms->get_output_partial_shape(1), PartialShape({Dimension::dynamic(), 3})); EXPECT_EQ(nms->get_output_shape(2), (Shape{1})); }
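
Note (editorial, not part of the patch): the interval shape inference added to op::v5::NonMaxSuppression::validate_and_infer_types() bounds the first two outputs as [0, min(num_boxes, max_output_boxes_per_class) * num_classes * num_batches] x 3, and only when the third input is a Constant and the relevant dimensions are static. A standalone sketch of that bound, using the dimensions of the nms5.xml test model above (boxes [1,1000,4], scores [1,1,1000], max_output_boxes_per_class = 10); the program itself is illustrative and not taken from the patch:

    // Illustration of the upper-bound formula from the NMS-5 shape inference hunk.
    #include <algorithm>
    #include <cstdint>
    #include <iostream>

    int main() {
        const int64_t num_batches = 1;                  // scores_ps[0]
        const int64_t num_classes = 1;                  // scores_ps[1]
        const int64_t num_boxes = 1000;                 // boxes_ps[1]
        const int64_t max_output_boxes_per_class = 10;  // third NMS input, must be a Constant

        const int64_t upper_bound =
            std::min(num_boxes, max_output_boxes_per_class) * num_classes * num_batches;

        // Outputs 0 and 1 get PartialShape{Dimension(0, upper_bound), 3}; the Serialize
        // pass then writes the upper bound as the static dimension (10 x 3 in nms5.xml).
        std::cout << "selected_indices: [0.." << upper_bound << ", 3]" << std::endl;
        return 0;
    }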
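The refactored SerializationTest exercises the new behaviour end to end: an IR containing an op with interval (dynamic) output shapes is read, serialized (the pass temporarily resolves dynamic dimensions to their upper bounds, then restores them via validate_nodes_and_infer_types), read back, and compared. A condensed sketch of that round trip, paraphrasing the test's SetUp with illustrative file names and a main() wrapper that is not in the patch:

    #include <ie_core.hpp>

    int main() {
        InferenceEngine::Core ie;
        // nms5.xml / nms5.bin stand in for any IR whose ops produce dynamic shapes
        auto expected = ie.ReadNetwork("nms5.xml", "nms5.bin");
        // Serialization writes upper-bound (static) dimensions, then moves the
        // original dynamic shapes back onto the source function
        expected.serialize("out.xml", "out.bin");
        auto result = ie.ReadNetwork("out.xml", "out.bin");
        // In the test, compare_functions(result.getFunction(), expected.getFunction())
        // is expected to report a match
        return 0;
    }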