Resolve dynamism in Serialize transformation (#3260)

* Added Interval shape inference for NMS5

* Added dynamic shape resolver for Serialize pass

* Refactored Serializatoin tests; Added tests with dynamic shapes(NMS5, ShapeOf)

* Fixed python NMS5 test

* Fixed LowLatecnyTests
This commit is contained in:
Gleb Kazantaev 2020-11-24 12:38:56 +03:00 committed by GitHub
parent 3d75c3e863
commit 27efd1fc7e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 328 additions and 207 deletions

View File

@ -5,6 +5,7 @@
#include <string> #include <string>
#include "ngraph/opsets/opset.hpp"
#include "ngraph/pass/pass.hpp" #include "ngraph/pass/pass.hpp"
#include "transformations_visibility.hpp" #include "transformations_visibility.hpp"

View File

@ -300,9 +300,73 @@ bool is_exec_graph(const ngraph::Function& f) {
return false; return false;
} }
bool resolve_dynamic_shapes(const ngraph::Function& f) {
const auto & f_results = f.get_results();
if (std::all_of(f_results.begin(), f_results.end(),
[](std::shared_ptr<Node> results) { return !results->is_dynamic(); })) {
return false;
}
auto f_clone = ngraph::clone_function(f);
const auto & f_ops = f.get_ordered_ops();
const auto & f_clone_ops = f_clone->get_ordered_ops();
NGRAPH_CHECK(f_ops.size() == f_clone_ops.size(), "Unexpected get_ordered_ops method behaviour");
for (size_t id = 0; id < f_ops.size(); ++id) {
auto & op = f_ops[id];
auto & clone_op = f_clone_ops[id];
if (auto op_subgraph = std::dynamic_pointer_cast<op::util::SubGraphOp>(op)) {
resolve_dynamic_shapes(*op_subgraph->get_function());
}
op->validate_and_infer_types();
clone_op->validate_and_infer_types();
// dynamic_to_static function converts dynamic dimensions to static using
// upperbound (get_max_length) dimension value.
auto dynamic_to_static = [](const PartialShape & shape) -> PartialShape {
if (shape.is_static() || shape.rank().is_dynamic()) {
return shape;
}
auto out_shape = PartialShape::dynamic(shape.rank());
for (size_t i = 0; i < shape.rank().get_length(); ++i) {
const auto & in_dim = shape[i];
out_shape[i] = (in_dim.is_dynamic() ? Dimension(in_dim.get_max_length()) : in_dim);
}
return out_shape;
};
OutputVector replacements(clone_op->get_output_size());
if (!clone_op->constant_fold(replacements, clone_op->input_values())) {
for (size_t output_id = 0; output_id < clone_op->get_output_size(); ++output_id) {
clone_op->set_output_type(output_id, clone_op->output(output_id).get_element_type(),
dynamic_to_static(clone_op->output(output_id).get_partial_shape()));
op->set_output_type(output_id, clone_op->output(output_id).get_element_type(),
clone_op->output(output_id).get_partial_shape());
}
} else {
for (size_t output_id = 0; output_id < clone_op->get_output_size(); ++output_id) {
op->set_output_type(output_id, replacements[output_id].get_element_type(),
replacements[output_id].get_partial_shape());
}
for (size_t i = 0; i < replacements.size(); ++i) {
auto node_output = clone_op->output(i);
auto replacement = replacements.at(i);
if (replacement.get_node_shared_ptr() && (node_output != replacement)) {
node_output.replace(replacement);
}
}
}
}
return true;
}
void ngfunction_2_irv10( void ngfunction_2_irv10(
pugi::xml_document& doc, std::vector<uint8_t>& bin, pugi::xml_document& doc, std::vector<uint8_t>& bin,
const ngraph::Function& f, ngraph::Function& f,
const std::map<std::string, ngraph::OpSet>& custom_opsets) { const std::map<std::string, ngraph::OpSet>& custom_opsets) {
const bool exec_graph = is_exec_graph(f); const bool exec_graph = is_exec_graph(f);
@ -315,6 +379,8 @@ void ngfunction_2_irv10(
create_layer_ids(f); create_layer_ids(f);
std::unordered_set<std::string> unique_names; std::unordered_set<std::string> unique_names;
bool has_dynamic_shapes = resolve_dynamic_shapes(f);
for (const auto& n : f.get_ordered_ops()) { for (const auto& n : f.get_ordered_ops()) {
ngraph::Node* node = n.get(); ngraph::Node* node = n.get();
@ -332,7 +398,7 @@ void ngfunction_2_irv10(
// <layers/data> // <layers/data>
pugi::xml_node data = layer.append_child("data"); pugi::xml_node data = layer.append_child("data");
// <layers/data> general atributes // <layers/data> general attributes
std::string node_type_name{node->get_type_name()}; std::string node_type_name{node->get_type_name()};
if (exec_graph) { if (exec_graph) {
visit_exec_graph_node(data, node_type_name, node); visit_exec_graph_node(data, node_type_name, node);
@ -403,6 +469,10 @@ void ngfunction_2_irv10(
edge.append_attribute("to-layer").set_value(e.to_layer); edge.append_attribute("to-layer").set_value(e.to_layer);
edge.append_attribute("to-port").set_value(e.to_port); edge.append_attribute("to-port").set_value(e.to_port);
} }
// move back dynamic shapes
if (has_dynamic_shapes) {
f.validate_nodes_and_infer_types();
}
} }
} // namespace } // namespace

View File

@ -0,0 +1,91 @@
<?xml version="1.0"?>
<net name="Function_0" version="10">
<layers>
<layer id="0" name="Parameter_69" type="Parameter" version="opset1">
<data cacheable="false" shape="1,1,1000" element_type="f32" />
<output>
<port id="0" precision="FP32">
<dim>1</dim>
<dim>1</dim>
<dim>1000</dim>
</port>
</output>
</layer>
<layer id="1" name="Parameter_68" type="Parameter" version="opset1">
<data cacheable="false" shape="1,1000,4" element_type="f32" />
<output>
<port id="0" precision="FP32">
<dim>1</dim>
<dim>1000</dim>
<dim>4</dim>
</port>
</output>
</layer>
<layer id="2" name="Constant_70" type="Const" version="opset1">
<data element_type="i64" shape="" offset="0" size="8" />
<output>
<port id="0" precision="I64" />
</output>
</layer>
<layer id="3" name="Constant_71" type="Const" version="opset1">
<data element_type="f32" shape="" offset="8" size="4" />
<output>
<port id="0" precision="FP32" />
</output>
</layer>
<layer id="4" name="Constant_72" type="Const" version="opset1">
<data element_type="f32" shape="" offset="12" size="4" />
<output>
<port id="0" precision="FP32" />
</output>
</layer>
<layer id="5" name="NonMaxSuppression_73" type="NonMaxSuppression" version="opset5">
<data box_encoding="corner" sort_result_descending="true" output_type="i64" />
<input>
<port id="0">
<dim>1</dim>
<dim>1000</dim>
<dim>4</dim>
</port>
<port id="1">
<dim>1</dim>
<dim>1</dim>
<dim>1000</dim>
</port>
<port id="2" />
<port id="3" />
<port id="4" />
</input>
<output>
<port id="5" precision="I64">
<dim>10</dim>
<dim>3</dim>
</port>
<port id="6" precision="FP32">
<dim>10</dim>
<dim>3</dim>
</port>
<port id="7" precision="I64">
<dim>1</dim>
</port>
</output>
</layer>
<layer id="6" name="Result_74" type="Result" version="opset1">
<data />
<input>
<port id="0">
<dim>10</dim>
<dim>3</dim>
</port>
</input>
</layer>
</layers>
<edges>
<edge from-layer="0" from-port="0" to-layer="5" to-port="1" />
<edge from-layer="1" from-port="0" to-layer="5" to-port="0" />
<edge from-layer="2" from-port="0" to-layer="5" to-port="2" />
<edge from-layer="3" from-port="0" to-layer="5" to-port="3" />
<edge from-layer="4" from-port="0" to-layer="5" to-port="4" />
<edge from-layer="5" from-port="5" to-layer="6" to-port="0" />
</edges>
</net>

View File

@ -0,0 +1,84 @@
<?xml version="1.0"?>
<net name="Function_0" version="10">
<layers>
<layer id="0" name="Parameter_68" type="Parameter" version="opset1">
<data cacheable="false" shape="1,2,3" element_type="f32" />
<output>
<port id="0" precision="FP32">
<dim>1</dim>
<dim>2</dim>
<dim>3</dim>
</port>
</output>
</layer>
<layer id="1" name="Relu_70" type="Relu" version="opset1">
<data />
<input>
<port id="0">
<dim>1</dim>
<dim>2</dim>
<dim>3</dim>
</port>
</input>
<output>
<port id="1" precision="FP32">
<dim>1</dim>
<dim>2</dim>
<dim>3</dim>
</port>
</output>
</layer>
<layer id="2" name="ShapeOf_69" type="ShapeOf" version="opset3">
<data output_type="i64" />
<input>
<port id="0">
<dim>1</dim>
<dim>2</dim>
<dim>3</dim>
</port>
</input>
<output>
<port id="1" precision="I64">
<dim>3</dim>
</port>
</output>
</layer>
<layer id="3" name="Reshape_71" type="Reshape" version="opset1">
<data special_zero="true" />
<input>
<port id="0">
<dim>1</dim>
<dim>2</dim>
<dim>3</dim>
</port>
<port id="1">
<dim>3</dim>
</port>
</input>
<output>
<port id="2" precision="FP32">
<dim>1</dim>
<dim>2</dim>
<dim>3</dim>
</port>
</output>
</layer>
<layer id="4" name="Result_72" type="Result" version="opset1">
<data />
<input>
<port id="0">
<dim>1</dim>
<dim>2</dim>
<dim>3</dim>
</port>
</input>
</layer>
</layers>
<edges>
<edge from-layer="0" from-port="0" to-layer="1" to-port="0" />
<edge from-layer="0" from-port="0" to-layer="2" to-port="0" />
<edge from-layer="1" from-port="1" to-layer="3" to-port="0" />
<edge from-layer="2" from-port="1" to-layer="3" to-port="1" />
<edge from-layer="3" from-port="2" to-layer="4" to-port="0" />
</edges>
</net>

View File

@ -12,12 +12,31 @@
#define IR_SERIALIZATION_MODELS_PATH "" #define IR_SERIALIZATION_MODELS_PATH ""
#endif #endif
class SerializationTest : public ::testing::Test { typedef std::tuple<std::string> SerializationParams;
protected:
std::string test_name = class SerializationTest: public CommonTestUtils::TestsCommon,
::testing::UnitTest::GetInstance()->current_test_info()->name(); public testing::WithParamInterface<SerializationParams> {
std::string m_out_xml_path = test_name + ".xml"; public:
std::string m_out_bin_path = test_name + ".bin"; std::string m_out_xml_path;
std::string m_out_bin_path;
void SetUp() override {
const auto & model_path = IR_SERIALIZATION_MODELS_PATH + std::get<0>(GetParam());
const std::string test_name = "test"; // ::testing::UnitTest::GetInstance()->current_test_info()->name();
m_out_xml_path = test_name + ".xml";
m_out_bin_path = test_name + ".bin";
InferenceEngine::Core ie;
auto expected = ie.ReadNetwork(model_path);
expected.serialize(m_out_xml_path, m_out_bin_path);
auto result = ie.ReadNetwork(m_out_xml_path, m_out_bin_path);
bool success;
std::string message;
std::tie(success, message) = compare_functions(result.getFunction(), expected.getFunction());
ASSERT_TRUE(success) << message;
}
void TearDown() override { void TearDown() override {
std::remove(m_out_xml_path.c_str()); std::remove(m_out_xml_path.c_str());
@ -25,177 +44,21 @@ protected:
} }
}; };
TEST_F(SerializationTest, BasicModel_MO) { TEST_P(SerializationTest, CompareFunctions) {
const std::string model = IR_SERIALIZATION_MODELS_PATH "add_abc.xml";
const std::string weights = IR_SERIALIZATION_MODELS_PATH "add_abc.bin";
InferenceEngine::Core ie;
auto expected = ie.ReadNetwork(model, weights);
expected.serialize(m_out_xml_path, m_out_bin_path);
auto result = ie.ReadNetwork(m_out_xml_path, m_out_bin_path);
bool success;
std::string message;
std::tie(success, message) =
compare_functions(result.getFunction(), expected.getFunction());
ASSERT_TRUE(success) << message;
} }
TEST_F(SerializationTest, BasicModel_ONNXImporter) { INSTANTIATE_TEST_CASE_P(IRSerialization, SerializationTest,
const std::string model = IR_SERIALIZATION_MODELS_PATH "add_abc.prototxt"; testing::Values(std::make_tuple("add_abc.xml"),
std::make_tuple("split_equal_parts_2d.xml"),
std::make_tuple("addmul_abc.xml"),
std::make_tuple("add_abc_initializers.xml"),
std::make_tuple("experimental_detectron_roi_feature_extractor.xml"),
std::make_tuple("experimental_detectron_detection_output.xml"),
std::make_tuple("nms5.xml"),
std::make_tuple("shape_of.xml")));
InferenceEngine::Core ie; INSTANTIATE_TEST_CASE_P(ONNXSerialization, SerializationTest,
auto expected = ie.ReadNetwork(model); testing::Values(std::make_tuple("add_abc.prototxt"),
expected.serialize(m_out_xml_path, m_out_bin_path); std::make_tuple("split_equal_parts_2d.prototxt"),
auto result = ie.ReadNetwork(m_out_xml_path, m_out_bin_path); std::make_tuple("addmul_abc.prototxt"),
std::make_tuple("add_abc_initializers.prototxt")));
bool success;
std::string message;
std::tie(success, message) =
compare_functions(result.getFunction(), expected.getFunction());
ASSERT_TRUE(success) << message;
}
TEST_F(SerializationTest, ModelWithMultipleOutputs_MO) {
const std::string model =
IR_SERIALIZATION_MODELS_PATH "split_equal_parts_2d.xml";
const std::string weights =
IR_SERIALIZATION_MODELS_PATH "split_equal_parts_2d.bin";
InferenceEngine::Core ie;
auto expected = ie.ReadNetwork(model, weights);
expected.serialize(m_out_xml_path, m_out_bin_path);
auto result = ie.ReadNetwork(m_out_xml_path, m_out_bin_path);
// Compare function does not support models with multiple outputs
bool success;
std::string message;
std::tie(success, message) =
compare_functions(result.getFunction(), expected.getFunction());
ASSERT_FALSE(success) << message;
}
TEST_F(SerializationTest, ModelWithMultipleOutputs_ONNXImporter) {
const std::string model =
IR_SERIALIZATION_MODELS_PATH "split_equal_parts_2d.prototxt";
InferenceEngine::Core ie;
auto expected = ie.ReadNetwork(model);
expected.serialize(m_out_xml_path, m_out_bin_path);
auto result = ie.ReadNetwork(m_out_xml_path, m_out_bin_path);
// Compare function does not support models with multiple outputs
bool success;
std::string message;
std::tie(success, message) =
compare_functions(result.getFunction(), expected.getFunction());
ASSERT_FALSE(success) << message;
}
TEST_F(SerializationTest, ModelWithMultipleLayers_MO) {
const std::string model = IR_SERIALIZATION_MODELS_PATH "addmul_abc.xml";
const std::string weights = IR_SERIALIZATION_MODELS_PATH "addmul_abc.bin";
InferenceEngine::Core ie;
auto expected = ie.ReadNetwork(model, weights);
expected.serialize(m_out_xml_path, m_out_bin_path);
auto result = ie.ReadNetwork(m_out_xml_path, m_out_bin_path);
bool success;
std::string message;
std::tie(success, message) =
compare_functions(result.getFunction(), expected.getFunction());
ASSERT_TRUE(success) << message;
}
TEST_F(SerializationTest, ModelWithMultipleLayers_ONNXImporter) {
const std::string model =
IR_SERIALIZATION_MODELS_PATH "addmul_abc.prototxt";
InferenceEngine::Core ie;
auto expected = ie.ReadNetwork(model);
expected.serialize(m_out_xml_path, m_out_bin_path);
auto result = ie.ReadNetwork(m_out_xml_path, m_out_bin_path);
bool success;
std::string message;
std::tie(success, message) =
compare_functions(result.getFunction(), expected.getFunction());
ASSERT_TRUE(success) << message;
}
TEST_F(SerializationTest, ModelWithConstants_MO) {
const std::string model =
IR_SERIALIZATION_MODELS_PATH "add_abc_initializers.xml";
const std::string weights =
IR_SERIALIZATION_MODELS_PATH "add_abc_initializers.bin";
InferenceEngine::Core ie;
auto expected = ie.ReadNetwork(model, weights);
expected.serialize(m_out_xml_path, m_out_bin_path);
auto result = ie.ReadNetwork(m_out_xml_path, m_out_bin_path);
bool success;
std::string message;
std::tie(success, message) =
compare_functions(result.getFunction(), expected.getFunction());
ASSERT_TRUE(success) << message;
}
TEST_F(SerializationTest, ModelWithConstants_ONNXImporter) {
const std::string model =
IR_SERIALIZATION_MODELS_PATH "add_abc_initializers.prototxt";
InferenceEngine::Core ie;
auto expected = ie.ReadNetwork(model);
expected.serialize(m_out_xml_path, m_out_bin_path);
auto result = ie.ReadNetwork(m_out_xml_path, m_out_bin_path);
bool success;
std::string message;
std::tie(success, message) =
compare_functions(result.getFunction(), expected.getFunction());
ASSERT_TRUE(success) << message;
}
TEST_F(SerializationTest, ExperimentalDetectronROIFeatureExtractor_MO) {
const std::string model = IR_SERIALIZATION_MODELS_PATH
"experimental_detectron_roi_feature_extractor.xml";
InferenceEngine::Core ie;
auto expected = ie.ReadNetwork(model);
expected.serialize(m_out_xml_path, m_out_bin_path);
auto result = ie.ReadNetwork(m_out_xml_path, m_out_bin_path);
bool success;
std::string message;
std::tie(success, message) =
compare_functions(result.getFunction(), expected.getFunction());
ASSERT_TRUE(success) << message;
}
TEST_F(SerializationTest, ExperimentalDetectronDetectionOutput_MO) {
const std::string model = IR_SERIALIZATION_MODELS_PATH
"experimental_detectron_detection_output.xml";
InferenceEngine::Core ie;
auto expected = ie.ReadNetwork(model);
expected.serialize(m_out_xml_path, m_out_bin_path);
auto result = ie.ReadNetwork(m_out_xml_path, m_out_bin_path);
bool success;
std::string message;
std::tie(success, message) =
compare_functions(result.getFunction(), expected.getFunction());
ASSERT_TRUE(success) << message;
}

View File

@ -95,10 +95,10 @@ TEST(TransformationTests, LowLatencyLSTM) {
auto lstm_cell = std::make_shared<opset5::LSTMCell>(squeeze, read_value_H, read_value_C, W, R, B, 128); auto lstm_cell = std::make_shared<opset5::LSTMCell>(squeeze, read_value_H, read_value_C, W, R, B, 128);
auto assign_H = std::make_shared<opset5::Assign>(lstm_cell->output(0), variable_name_H); auto assign_H = std::make_shared<opset5::Assign>(lstm_cell->output(0), variable_name_H);
auto assign_C = std::make_shared<opset5::Assign>(lstm_cell->output(1), variable_name_C); auto assign_C = std::make_shared<opset5::Assign>(lstm_cell->output(1), variable_name_C);
auto res_1 = std::make_shared<opset5::Result>(lstm_cell->output(0));
auto unsqueeze = std::make_shared<opset5::Unsqueeze>(lstm_cell->output(0), axis); auto unsqueeze = std::make_shared<opset5::Unsqueeze>(lstm_cell->output(0), axis);
auto res_2 = std::make_shared<opset5::Result>(unsqueeze); auto res_2 = std::make_shared<opset5::Result>(unsqueeze);
f_ref = std::make_shared<ngraph::Function>(OutputVector{unsqueeze, res_1}, ParameterVector{Xi, H_t, C_t}); auto res_1 = std::make_shared<opset5::Result>(lstm_cell->output(0));
f_ref = std::make_shared<ngraph::Function>(OutputVector{res_1, res_2}, ParameterVector{Xi, H_t, C_t});
f_ref->add_sinks({assign_C, assign_H}); f_ref->add_sinks({assign_C, assign_H});
assign_H->add_control_dependency(read_value_H); assign_H->add_control_dependency(read_value_H);
assign_C->add_control_dependency(read_value_C); assign_C->add_control_dependency(read_value_C);
@ -340,10 +340,10 @@ TEST(TransformationTests, LowLatencyLSTMReshape) {
auto lstm_cell = std::make_shared<opset5::LSTMCell>(squeeze, read_value_H, read_value_C, W, R, B, 128); auto lstm_cell = std::make_shared<opset5::LSTMCell>(squeeze, read_value_H, read_value_C, W, R, B, 128);
auto assign_H = std::make_shared<opset5::Assign>(lstm_cell->output(0), variable_name_H); auto assign_H = std::make_shared<opset5::Assign>(lstm_cell->output(0), variable_name_H);
auto assign_C = std::make_shared<opset5::Assign>(lstm_cell->output(1), variable_name_C); auto assign_C = std::make_shared<opset5::Assign>(lstm_cell->output(1), variable_name_C);
auto res_1 = std::make_shared<opset5::Result>(lstm_cell->output(0));
auto unsqueeze = std::make_shared<opset5::Unsqueeze>(lstm_cell->output(0), axis); auto unsqueeze = std::make_shared<opset5::Unsqueeze>(lstm_cell->output(0), axis);
auto res_2 = std::make_shared<opset5::Result>(unsqueeze); auto res_2 = std::make_shared<opset5::Result>(unsqueeze);
f_ref = std::make_shared<ngraph::Function>(OutputVector{unsqueeze, res_1}, ParameterVector{Xi, H_t, C_t}); auto res_1 = std::make_shared<opset5::Result>(lstm_cell->output(0));
f_ref = std::make_shared<ngraph::Function>(OutputVector{res_1, res_2}, ParameterVector{Xi, H_t, C_t});
f_ref->add_sinks({assign_C, assign_H}); f_ref->add_sinks({assign_C, assign_H});
assign_H->add_control_dependency(read_value_H); assign_H->add_control_dependency(read_value_H);
assign_C->add_control_dependency(read_value_C); assign_C->add_control_dependency(read_value_C);

View File

@ -84,8 +84,12 @@ std::pair<bool, std::string> compare_functions(
* - Do not check nodes attributes (requires visitor mechanism to be completed) * - Do not check nodes attributes (requires visitor mechanism to be completed)
*/ */
const auto& f1_results = f1->get_results(); auto f1_results = f1->get_results();
const auto& f2_results = f2->get_results(); auto f2_results = f2->get_results();
auto compare_nodes_by_name = [](const std::shared_ptr<ngraph::Node> & l, const std::shared_ptr<ngraph::Node> & r)
{ return l->get_friendly_name() < r->get_friendly_name(); };
std::sort(f1_results.begin(), f1_results.end(), compare_nodes_by_name);
std::sort(f2_results.begin(), f2_results.end(), compare_nodes_by_name);
if (f1_results.size() != f2_results.size()) { if (f1_results.size() != f2_results.size()) {
return { false, "Number of results is different: " + std::to_string(f1_results.size()) + " and " + std::to_string(f2_results.size()) }; return { false, "Number of results is different: " + std::to_string(f1_results.size()) + " and " + std::to_string(f2_results.size()) };
} }

View File

@ -902,6 +902,23 @@ void op::v5::NonMaxSuppression::validate_and_infer_types()
validate(); validate();
if (boxes_ps.rank().is_static() && scores_ps.rank().is_static() && get_input_size() > 2)
{
const auto num_boxes_boxes = boxes_ps[1];
const auto max_output_boxes_per_class_node = input_value(2).get_node_shared_ptr();
if (num_boxes_boxes.is_static() && scores_ps[0].is_static() && scores_ps[1].is_static() &&
op::is_constant(max_output_boxes_per_class_node))
{
const auto num_boxes = num_boxes_boxes.get_length();
const auto num_classes = scores_ps[1].get_length();
const auto max_output_boxes_per_class = max_boxes_output_from_input();
out_shape[0] = Dimension(0,
std::min(num_boxes, max_output_boxes_per_class) * num_classes *
scores_ps[0].get_length());
}
}
set_output_type(0, m_output_type, out_shape); set_output_type(0, m_output_type, out_shape);
set_output_type(1, element::f32, out_shape); set_output_type(1, element::f32, out_shape);
set_output_type(2, m_output_type, Shape{1}); set_output_type(2, m_output_type, Shape{1});

View File

@ -15,9 +15,10 @@
# ****************************************************************************** # ******************************************************************************
import numpy as np import numpy as np
import pytest import pytest
from _pyngraph import PartialShape from _pyngraph import PartialShape, Dimension
import ngraph as ng import ngraph as ng
from ngraph.utils.types import make_constant_node
from tests.runtime import get_runtime from tests.runtime import get_runtime
from tests.test_ngraph.util import run_op_node from tests.test_ngraph.util import run_op_node
from tests import xfail_issue_40957 from tests import xfail_issue_40957
@ -108,12 +109,12 @@ def test_non_max_suppression():
boxes_parameter = ng.parameter(boxes_shape, name="Boxes", dtype=np.float32) boxes_parameter = ng.parameter(boxes_shape, name="Boxes", dtype=np.float32)
scores_parameter = ng.parameter(scores_shape, name="Scores", dtype=np.float32) scores_parameter = ng.parameter(scores_shape, name="Scores", dtype=np.float32)
node = ng.non_max_suppression(boxes_parameter, scores_parameter) node = ng.non_max_suppression(boxes_parameter, scores_parameter, make_constant_node(1000, np.int64))
assert node.get_type_name() == "NonMaxSuppression" assert node.get_type_name() == "NonMaxSuppression"
assert node.get_output_size() == 3 assert node.get_output_size() == 3
assert node.get_output_partial_shape(0).same_scheme(PartialShape([-1, 3])) assert node.get_output_partial_shape(0) == PartialShape([Dimension(0, 1000), Dimension(3)])
assert node.get_output_partial_shape(1).same_scheme(PartialShape([-1, 3])) assert node.get_output_partial_shape(1) == PartialShape([Dimension(0, 1000), Dimension(3)])
assert list(node.get_output_shape(2)) == [1] assert list(node.get_output_shape(2)) == [1]

View File

@ -691,11 +691,9 @@ TEST(type_prop, nms_v5_output_shape_2)
ASSERT_EQ(nms->get_output_element_type(0), element::i64); ASSERT_EQ(nms->get_output_element_type(0), element::i64);
ASSERT_EQ(nms->get_output_element_type(1), element::f32); ASSERT_EQ(nms->get_output_element_type(1), element::f32);
ASSERT_EQ(nms->get_output_element_type(2), element::i64); ASSERT_EQ(nms->get_output_element_type(2), element::i64);
ASSERT_TRUE(
nms->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic(), 3}));
ASSERT_TRUE(
nms->get_output_partial_shape(1).same_scheme(PartialShape{Dimension::dynamic(), 3}));
EXPECT_EQ(nms->get_output_partial_shape(0), PartialShape({Dimension(0, 30), Dimension(3)}));
EXPECT_EQ(nms->get_output_partial_shape(1), PartialShape({Dimension(0, 30), Dimension(3)}));
EXPECT_EQ(nms->get_output_shape(2), (Shape{1})); EXPECT_EQ(nms->get_output_shape(2), (Shape{1}));
} }
@ -713,11 +711,8 @@ TEST(type_prop, nms_v5_output_shape_3)
ASSERT_EQ(nms->get_output_element_type(0), element::i64); ASSERT_EQ(nms->get_output_element_type(0), element::i64);
ASSERT_EQ(nms->get_output_element_type(1), element::f32); ASSERT_EQ(nms->get_output_element_type(1), element::f32);
ASSERT_EQ(nms->get_output_element_type(2), element::i64); ASSERT_EQ(nms->get_output_element_type(2), element::i64);
ASSERT_TRUE( EXPECT_EQ(nms->get_output_partial_shape(0), PartialShape({Dimension(0, 70), Dimension(3)}));
nms->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic(), 3})); EXPECT_EQ(nms->get_output_partial_shape(1), PartialShape({Dimension(0, 70), Dimension(3)}));
ASSERT_TRUE(
nms->get_output_partial_shape(1).same_scheme(PartialShape{Dimension::dynamic(), 3}));
EXPECT_EQ(nms->get_output_shape(2), (Shape{1})); EXPECT_EQ(nms->get_output_shape(2), (Shape{1}));
} }
@ -742,11 +737,9 @@ TEST(type_prop, nms_v5_output_shape_i32)
ASSERT_EQ(nms->get_output_element_type(0), element::i32); ASSERT_EQ(nms->get_output_element_type(0), element::i32);
ASSERT_EQ(nms->get_output_element_type(1), element::f32); ASSERT_EQ(nms->get_output_element_type(1), element::f32);
ASSERT_EQ(nms->get_output_element_type(2), element::i32); ASSERT_EQ(nms->get_output_element_type(2), element::i32);
ASSERT_TRUE(
nms->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic(), 3}));
ASSERT_TRUE(
nms->get_output_partial_shape(1).same_scheme(PartialShape{Dimension::dynamic(), 3}));
EXPECT_EQ(nms->get_output_partial_shape(0), PartialShape({Dimension(0, 30), Dimension(3)}));
EXPECT_EQ(nms->get_output_partial_shape(1), PartialShape({Dimension(0, 30), Dimension(3)}));
EXPECT_EQ(nms->get_output_shape(2), (Shape{1})); EXPECT_EQ(nms->get_output_shape(2), (Shape{1}));
} }
@ -764,10 +757,7 @@ TEST(type_prop, nms_v5_dynamic_boxes_and_scores)
ASSERT_EQ(nms->get_output_element_type(0), element::i64); ASSERT_EQ(nms->get_output_element_type(0), element::i64);
ASSERT_EQ(nms->get_output_element_type(1), element::f32); ASSERT_EQ(nms->get_output_element_type(1), element::f32);
ASSERT_EQ(nms->get_output_element_type(2), element::i64); ASSERT_EQ(nms->get_output_element_type(2), element::i64);
ASSERT_TRUE( EXPECT_EQ(nms->get_output_partial_shape(0), PartialShape({Dimension::dynamic(), 3}));
nms->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic(), 3})); EXPECT_EQ(nms->get_output_partial_shape(1), PartialShape({Dimension::dynamic(), 3}));
ASSERT_TRUE(
nms->get_output_partial_shape(1).same_scheme(PartialShape{Dimension::dynamic(), 3}));
EXPECT_EQ(nms->get_output_shape(2), (Shape{1})); EXPECT_EQ(nms->get_output_shape(2), (Shape{1}));
} }