From 0b05653d7a3c02786a16a9440b6b95bef8b31946 Mon Sep 17 00:00:00 2001 From: Mateusz Bencer Date: Mon, 21 Dec 2020 12:32:15 +0100 Subject: [PATCH] Resolved problems with ssd_resnet34_mlperf_opset10 (#3487) * Resolved problems with ssd_resnet34_1200 * removed debug code * Added correct handling onnx nodes from parent graph scope * removed unnecessary include * fixed calcution index to replace * fixed LoopParentParametersUsedInBody test * added set_friendly_name * apply Unsqueeze for each concatenated Loop output * added handling trip count with value max_int * merge from upstream/master * update xfail list * added checking is trip_count is constant --- .../transformations/sr_sub_graph_ops.cpp | 132 ++++++++++++- ngraph/core/src/op/loop.cpp | 37 ++-- .../include/onnx_import/core/graph.hpp | 9 +- .../include/onnx_import/core/graph_cache.hpp | 29 +++ .../frontend/onnx_import/src/core/graph.cpp | 55 +++--- .../onnx_import/src/core/graph_cache.cpp | 21 ++ ngraph/frontend/onnx_import/src/op/loop.cpp | 35 +++- ngraph/frontend/onnx_import/src/op/topk.cpp | 16 +- ngraph/python/tests/__init__.py | 6 +- .../python/tests/test_onnx/test_zoo_models.py | 4 +- ...op_2d_add_input_from_parent_graph.prototxt | 166 ++++++++++++++++ .../loop_2d_add_trip_count_max_int.prototxt | 177 +++++++++++++++++ ..._scope_used_in_parent_and_in_body.prototxt | 181 ++++++++++++++++++ ...11_const_k_smallest_negative_axis.prototxt | 97 ++++++++++ ngraph/test/onnx/onnx_import.in.cpp | 14 ++ .../test/onnx/onnx_import_controlflow.in.cpp | 75 ++++++-- ngraph/test/runtime/ie/unit_test.manifest | 2 + ngraph/test/type_prop/loop.cpp | 180 ++++++++++++++++- 18 files changed, 1152 insertions(+), 84 deletions(-) create mode 100644 ngraph/test/models/onnx/loop/loop_2d_add_input_from_parent_graph.prototxt create mode 100644 ngraph/test/models/onnx/loop/loop_2d_add_trip_count_max_int.prototxt create mode 100644 ngraph/test/models/onnx/loop/loop_add_node_from_parent_scope_used_in_parent_and_in_body.prototxt create mode 100644 ngraph/test/models/onnx/top_k_opset_11_const_k_smallest_negative_axis.prototxt diff --git a/inference-engine/tests/functional/inference_engine/transformations/sr_sub_graph_ops.cpp b/inference-engine/tests/functional/inference_engine/transformations/sr_sub_graph_ops.cpp index 72f71cdc200..c454c4288f7 100644 --- a/inference-engine/tests/functional/inference_engine/transformations/sr_sub_graph_ops.cpp +++ b/inference-engine/tests/functional/inference_engine/transformations/sr_sub_graph_ops.cpp @@ -256,4 +256,134 @@ TEST(SmartReshapeTests, LoopDynamicParameters) { // concat output ASSERT_TRUE(network.getFunction()->get_results()[2]->get_output_partial_shape(0).compatible({32, 10, 10})); ASSERT_TRUE(network.getFunction()->get_results()[3]->get_output_partial_shape(0).compatible({32, 1, 1})); -} \ No newline at end of file +} + +TEST(SmartReshapeTests, LoopParentParametersUsedInBody) { + std::shared_ptr f(nullptr); + { + // That which we iterate over + auto X = std::make_shared(element::f32, PartialShape::dynamic()); + auto Y = std::make_shared(element::f32, PartialShape::dynamic()); + auto add_Y = std::make_shared(Y, + std::make_shared(ngraph::element::f32, ngraph::Shape{}, std::vector{0.f})); + auto M = std::make_shared(element::f32, PartialShape::dynamic()); + X->set_friendly_name("X"); + Y->set_friendly_name("Y"); + M->set_friendly_name("M"); + + // Set up the cell body, a function from (Xi, add_Y) -> (Zo) + // Body parameters + auto current_iteration = std::make_shared(element::i64, Shape{}); + auto Xi = 
std::make_shared(element::f32, PartialShape::dynamic()); + auto Yi = std::make_shared(element::f32, PartialShape::dynamic()); + auto M_body = std::make_shared(element::f32, PartialShape::dynamic()); + auto body_condition = + std::make_shared(ngraph::element::boolean, ngraph::Shape{}, true); + + auto trip_count = + std::make_shared(ngraph::element::i64, ngraph::Shape{}, 10); + auto exec_condition = + std::make_shared(ngraph::element::boolean, ngraph::Shape{}, true); + // Body + auto sum = std::make_shared(Xi, Yi); + auto Zo = std::make_shared(sum, M_body); + auto body = std::make_shared(OutputVector{Zo, body_condition, sum}, + ParameterVector{Xi, current_iteration, Yi, M_body}); + + auto loop = std::make_shared(trip_count, exec_condition); + loop->set_function(body); + loop->set_special_body_ports(ngraph::opset5::Loop::SpecialBodyPorts{1, 1}); + + loop->set_sliced_input(Xi, X, 0, 1, 1, -1, 2); + loop->set_merged_input(M_body, M, Zo); + // Set invariant input which uses parameter from parent graph + loop->set_invariant_input(Yi, add_Y); + + // Output 0 is last Zo + auto out0 = loop->get_iter_value(body_condition, -1); + auto out1 = loop->get_iter_value(Zo, -1); + // Output 1 is concat of Zos + // start=0, stride=1, part_size=1, end=-1, axis=1 + auto out2 = loop->get_concatenated_slices(Zo, 0, 1, 1, -1, 1); + auto out3 = loop->get_iter_value(sum, -1); + + f = std::make_shared(OutputVector{out0, out1, out2, out3}, ParameterVector{X, Y, M}); + } + + InferenceEngine::CNNNetwork network(f); + ASSERT_TRUE(network.getFunction()->get_results()[0]->get_output_partial_shape(0).compatible({})); + ASSERT_TRUE(network.getFunction()->get_results()[1]->get_output_partial_shape(0).compatible(PartialShape::dynamic())); + // concat output + ASSERT_TRUE(network.getFunction()->get_results()[2]->get_output_partial_shape(0).compatible(PartialShape::dynamic())); + ASSERT_TRUE(network.getFunction()->get_results()[3]->get_output_partial_shape(0).compatible(PartialShape::dynamic())); + + ASSERT_NO_THROW(network.reshape({{"X", {4, 3, 2}}, {"Y", {4, 3, 2}}, {"M", {4, 3, 2}}})); + + ASSERT_TRUE(network.getFunction()->get_results()[0]->get_output_partial_shape(0).compatible({})); + ASSERT_TRUE(network.getFunction()->get_results()[1]->get_output_partial_shape(0).compatible({4, 3, 2})); + // concat output + ASSERT_TRUE(network.getFunction()->get_results()[2]->get_output_partial_shape(0).compatible({4, 30, 2})); + ASSERT_TRUE(network.getFunction()->get_results()[3]->get_output_partial_shape(0).compatible({4, 3, 2})); +} + +TEST(SmartReshapeTests, TensorIteratorParentParameterUsedInBody) { + std::shared_ptr f(nullptr); + { + // That which we iterate over + auto X = std::make_shared(element::f32, Shape{1, 1, 1}); + auto Y = std::make_shared(element::f32, Shape{1, 1, 1}); + auto add_Y = std::make_shared(Y, + std::make_shared(ngraph::element::f32, ngraph::Shape{}, std::vector{0.f})); + auto M = std::make_shared(element::f32, Shape{1, 1, 1}); + X->set_friendly_name("X"); + Y->set_friendly_name("Y"); + M->set_friendly_name("M"); + + // Set up the cell body, a function from (Xi, add_Y) -> (Zo) + // Body parameters + auto Xi = std::make_shared(element::f32, PartialShape::dynamic()); + auto Yi = std::make_shared(element::f32, PartialShape::dynamic()); + auto M_body = std::make_shared(element::f32, PartialShape::dynamic()); + auto body_condition = + std::make_shared(ngraph::element::boolean, ngraph::Shape{}, true); + + // Body + auto sum = std::make_shared(Xi, Yi); + auto Zo = std::make_shared(sum, M_body); + auto body = 
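        // Note: unlike the Loop test above, a TensorIterator body takes no
        // current_iteration parameter and no special body ports are set.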
std::make_shared(OutputVector{Zo, body_condition, sum}, + ParameterVector{Xi, Yi, M_body}); + + auto tensor_iterator = std::make_shared(); + tensor_iterator->set_function(body); + + tensor_iterator->set_sliced_input(Xi, X, 0, 1, 1, -1, 2); + tensor_iterator->set_merged_input(M_body, M, Zo); + // Set invariant input which uses parameter from parent graph + tensor_iterator->set_invariant_input(Yi, add_Y); + + // Output 0 is last Zo + auto out0 = tensor_iterator->get_iter_value(body_condition, -1); + auto out1 = tensor_iterator->get_iter_value(Zo, -1); + // Output 1 is concat of Zos + // start=0, stride=1, part_size=1, end=-1, axis=1 + auto out2 = tensor_iterator->get_concatenated_slices(Zo, 0, 1, 1, -1, 1); + auto out3 = tensor_iterator->get_iter_value(sum, -1); + + f = std::make_shared(OutputVector{out0, out1, out2, out3}, ParameterVector{X, Y, M}); + } + + InferenceEngine::CNNNetwork network(f); + ASSERT_TRUE(network.getFunction()->get_results()[0]->get_output_partial_shape(0).compatible({})); + ASSERT_TRUE(network.getFunction()->get_results()[1]->get_output_partial_shape(0).compatible({1, 1, 1})); + // concat output (seq len = 1, so it means num_iter = 1) + ASSERT_TRUE(network.getFunction()->get_results()[2]->get_output_partial_shape(0).compatible({1, 1, 1})); + ASSERT_TRUE(network.getFunction()->get_results()[3]->get_output_partial_shape(0).compatible({1, 1, 1})); + + ASSERT_NO_THROW(network.reshape({{"X", {32, 1, 10}}, {"Y", {1, 1, 1}}, {"M", {32, 1, 10}}})); + + ASSERT_TRUE(network.getFunction()->get_results()[0]->get_output_partial_shape(0).compatible({})); + ASSERT_TRUE(network.getFunction()->get_results()[1]->get_output_partial_shape(0).compatible({32, 1, 10})); + // concat output + ASSERT_TRUE(network.getFunction()->get_results()[2]->get_output_partial_shape(0).compatible({32, 10, 10})); + ASSERT_TRUE(network.getFunction()->get_results()[3]->get_output_partial_shape(0).compatible({32, 1, 1})); +} diff --git a/ngraph/core/src/op/loop.cpp b/ngraph/core/src/op/loop.cpp index e00db9a3148..65892380db8 100644 --- a/ngraph/core/src/op/loop.cpp +++ b/ngraph/core/src/op/loop.cpp @@ -247,15 +247,22 @@ void op::v5::Loop::validate_and_infer_types() as_type_ptr(output_description)) { const auto& body_value_partial_shape = body_value.get_partial_shape(); - set_output_type(index, body_value.get_element_type(), PartialShape::dynamic()); - if (body_value_partial_shape.is_static()) + if (body_value_partial_shape.rank().is_dynamic()) + { + set_output_type(index, body_value.get_element_type(), PartialShape::dynamic()); + } + else { - auto body_value_shape = body_value_partial_shape.to_shape(); auto axis = concat_output_description->m_axis; - Shape out_shape{body_value_shape}; + NODE_VALIDATION_CHECK(this, + axis < body_value_partial_shape.rank().get_length(), + "Concatenation axis must be less than sliced output rank"); - if (body_value_shape.empty()) + PartialShape out_shape{body_value_partial_shape}; + + if (body_value_partial_shape.is_static() && + ngraph::is_scalar(body_value_partial_shape.to_shape())) { NODE_VALIDATION_CHECK( this, @@ -266,23 +273,23 @@ void op::v5::Loop::validate_and_infer_types() out_shape = Shape(1); } - if (m_num_iterations != -1) + if (m_num_iterations != -1 && body_value_partial_shape[axis].is_static()) { - out_shape[axis] = m_num_iterations * body_value_shape[axis]; + out_shape[axis] = + m_num_iterations * body_value_partial_shape[axis].get_length(); if (zero_number_of_iter) { - out_shape.at(0) = 0; + out_shape[0] = 0; } - set_output_type(index, 
body_value.get_element_type(), out_shape); } - } - else - { - set_output_type(index, - body_value.get_element_type(), - PartialShape::dynamic(body_value.get_partial_shape().rank())); + else + { + out_shape[axis] = Dimension::dynamic(); + } + set_output_type(index, body_value.get_element_type(), out_shape); } } + else if (auto body_output_description = as_type_ptr(output_description)) { diff --git a/ngraph/frontend/onnx_import/include/onnx_import/core/graph.hpp b/ngraph/frontend/onnx_import/include/onnx_import/core/graph.hpp index 40a9da798c1..f8c6e7c17be 100644 --- a/ngraph/frontend/onnx_import/include/onnx_import/core/graph.hpp +++ b/ngraph/frontend/onnx_import/include/onnx_import/core/graph.hpp @@ -67,10 +67,10 @@ namespace ngraph protected: ParameterVector m_parameters; + std::unique_ptr m_cache; private: const ONNX_NAMESPACE::GraphProto* m_graph_proto; - std::unique_ptr m_cache; std::vector m_nodes; std::vector m_inputs; std::vector m_outputs; @@ -91,6 +91,13 @@ namespace ngraph Subgraph(const ONNX_NAMESPACE::GraphProto& proto, Model& model, const Graph& parent_graph); + + /// \brief Return outputs which are on the edge the subgraph and the parent graph. + /// \return Vector of edge nodes from parent scope. + const std::vector> get_outputs_from_parent() const; + + private: + std::vector> m_outputs_from_parent; }; inline std::ostream& operator<<(std::ostream& outs, const Graph& graph) diff --git a/ngraph/frontend/onnx_import/include/onnx_import/core/graph_cache.hpp b/ngraph/frontend/onnx_import/include/onnx_import/core/graph_cache.hpp index cb31f64b600..13297845e78 100644 --- a/ngraph/frontend/onnx_import/include/onnx_import/core/graph_cache.hpp +++ b/ngraph/frontend/onnx_import/include/onnx_import/core/graph_cache.hpp @@ -25,6 +25,17 @@ namespace ngraph { namespace onnx_import { + /// \brief Enum which determines scope (visibility) of nodes in GraphCache. + enum class NodeScope + { + // in parent graph scope + ParentGraph = 1, + // in subgraph scope + SubGraph, + // not available at all + Lack + }; + /// \brief GraphCache stores and provides access to ONNX graph initializers. class GraphCache { @@ -53,6 +64,16 @@ namespace ngraph /// \return true if the node named `name` exist in the cache, false otherwise. virtual bool contains(const std::string& name) const; + /// \brief Return NodeScope enum which determines scope of the node. + /// \note If the method is called on GraphCache the ParentGraph enum + /// value is retunred always. + /// + /// \param[in] name The name of the node. + /// + /// \return SubGraph if node belongs to SubgraphCache, ParentGraph if + /// is avalible in parent_graph_cache, otherwise Lack + virtual NodeScope node_scope(const std::string& name) const; + private: std::map> m_graph_cache_map; }; @@ -82,6 +103,14 @@ namespace ngraph /// (subgraph or parent graph), false otherwise. bool contains(const std::string& name) const override; + /// \brief Return NodeScope enum which determines scope of the node. + /// + /// \param[in] name The name of the node. 
+ /// + /// \return SubGraph if the node belongs to SubgraphCache, ParentGraph if + /// is avalible in parent_graph_cache, otherwise Lack + NodeScope node_scope(const std::string& name) const override; + private: const GraphCache* m_parent_graph_cache; }; diff --git a/ngraph/frontend/onnx_import/src/core/graph.cpp b/ngraph/frontend/onnx_import/src/core/graph.cpp index 30c4f56541e..605fdfba27c 100644 --- a/ngraph/frontend/onnx_import/src/core/graph.cpp +++ b/ngraph/frontend/onnx_import/src/core/graph.cpp @@ -315,39 +315,48 @@ namespace ngraph model, std::unique_ptr(new SubgraphCache(parent_graph.get_graph_cache()))) { - std::vector> subgraph_root_nodes; - const auto& outputs = as_result_vector(get_ng_outputs()); - for (auto& out : outputs) + // find all nodes on edge parent graph-subgraph + // (it means input of node from parent graph, output from subgraph) + for (const auto& node_proto : proto.node()) { - subgraph_root_nodes.push_back(out); - } - const auto& params = get_ng_parameters(); - for (auto& param : params) - { - subgraph_root_nodes.push_back(param); - } - const auto subgraph_nodes = topological_sort(subgraph_root_nodes); - - const auto& parent_graph_parameters = parent_graph.get_ng_parameters(); - for (const auto& node : subgraph_nodes) - { - if (op::is_parameter(node)) + int input_index = 0; + for (const auto& in_name : node_proto.input()) { - const auto sub_it = std::find(m_parameters.begin(), m_parameters.end(), node); - // not present as subgraph parameter - if (sub_it == m_parameters.end()) + if (m_cache->node_scope(in_name) == NodeScope::ParentGraph) { - const auto parent_it = std::find( - parent_graph_parameters.begin(), parent_graph_parameters.end(), node); - if (parent_it != m_parameters.end()) + const auto& from_parent_node = m_cache->get_node(in_name); + // constants are skipped + if (!ngraph::is_type( + from_parent_node.get_node_shared_ptr())) { - m_parameters.push_back(*parent_it); + for (const auto& out_name : node_proto.output()) + { + if (m_cache->node_scope(out_name) == NodeScope::SubGraph) + { + auto out_node_to_replace_input = m_cache->get_node(out_name); + auto new_param = std::make_shared( + from_parent_node.get_element_type(), + from_parent_node.get_partial_shape()); + // replace input from parent scope with parameter + out_node_to_replace_input.get_node() + ->input(input_index) + .replace_source_output(new_param); + m_parameters.push_back(new_param); + m_outputs_from_parent.push_back(from_parent_node); + } + } } } + ++input_index; } } } + const std::vector> Subgraph::get_outputs_from_parent() const + { + return m_outputs_from_parent; + } + } // namespace onnx_import } // namespace ngraph diff --git a/ngraph/frontend/onnx_import/src/core/graph_cache.cpp b/ngraph/frontend/onnx_import/src/core/graph_cache.cpp index 54305b0ecec..2155bb0e01d 100644 --- a/ngraph/frontend/onnx_import/src/core/graph_cache.cpp +++ b/ngraph/frontend/onnx_import/src/core/graph_cache.cpp @@ -43,6 +43,11 @@ namespace ngraph return (m_graph_cache_map.count(name) > 0); } + NodeScope GraphCache::node_scope(const std::string& name) const + { + return contains(name) ? 
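                                    // The base cache sees only a single scope,
                                    // so any hit is in the parent-graph scope;
                                    // NodeScope::SubGraph is reported only by
                                    // the SubgraphCache override below.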
NodeScope::ParentGraph : NodeScope::Lack; + } + SubgraphCache::SubgraphCache(const GraphCache& parent_graph_cache) : m_parent_graph_cache{&parent_graph_cache} { @@ -71,5 +76,21 @@ namespace ngraph return GraphCache::contains(name) || m_parent_graph_cache->contains(name); } + NodeScope SubgraphCache::node_scope(const std::string& name) const + { + if (GraphCache::contains(name)) + { + return NodeScope::SubGraph; + } + else if (m_parent_graph_cache->contains(name)) + { + return NodeScope::ParentGraph; + } + else + { + return NodeScope::Lack; + } + } + } // namespace onnx_import } // namespace ngraph diff --git a/ngraph/frontend/onnx_import/src/op/loop.cpp b/ngraph/frontend/onnx_import/src/op/loop.cpp index 2039b12b46f..faa3419a6b7 100644 --- a/ngraph/frontend/onnx_import/src/op/loop.cpp +++ b/ngraph/frontend/onnx_import/src/op/loop.cpp @@ -87,7 +87,12 @@ namespace ngraph // optional inputs Output trip_count; - if (ngraph::op::is_null(ng_inputs.at(0))) // trip count skipped + // trip count skipped or has value max(int64_t) means infinitive loop + if (ngraph::op::is_null(ng_inputs.at(0)) || + (ngraph::op::is_constant(ng_inputs.at(0).get_node_shared_ptr()) && + as_type_ptr(ng_inputs.at(0).get_node_shared_ptr()) + ->cast_vector()[0] == + std::numeric_limits::max())) { // -1 means infinite Loop trip_count = ngraph::op::Constant::create(ngraph::element::i64, {1}, {-1}); @@ -132,17 +137,13 @@ namespace ngraph const int64_t concat_axis = 0; const auto concat_axis_const = ngraph::op::Constant::create(ngraph::element::i64, {1}, {concat_axis}); - // provide scalar handing for scan outputs + // add dimension along which scan outputs will be concatenated for (size_t i = loop_carried_dependencies.size() + 1; i < body_outputs.size(); ++i) { - auto body_output_shape = body_outputs[i].get_partial_shape(); - if (body_output_shape.is_static() && - ngraph::is_scalar(body_output_shape.to_shape())) - { - body_outputs[i] = std::make_shared( - body_outputs[i], concat_axis_const); - } + const auto& body_output_shape = body_outputs[i].get_partial_shape(); + body_outputs[i] = std::make_shared( + body_outputs[i], concat_axis_const); } const auto& body_loop_out_cond = body_outputs.at(0).get_node_shared_ptr(); @@ -193,6 +194,22 @@ namespace ngraph final_values.push_back(loop->get_iter_value(*body_outputs_it++, -1)); } + const auto& outputs_from_parent = body_graph.get_outputs_from_parent(); + CHECK_VALID_NODE(node, + std::distance(body_inputs_it, body_inputs.end()) == + outputs_from_parent.size(), + "Expected number of invariant parameters is" + " not equal number of provided outputs from parent scope"); + + // Set-up parameters from parent graph which are not changed during Loop's + // iterations + for (auto out_from_parent_it = outputs_from_parent.begin(); + body_inputs_it != body_inputs.end(); + ++body_inputs_it, ++out_from_parent_it) + { + loop->set_invariant_input(*body_inputs_it, *out_from_parent_it); + } + // Set-up scan outputs OutputVector scan_outputs; for (; body_outputs_it != body_outputs.end(); body_outputs_it++) diff --git a/ngraph/frontend/onnx_import/src/op/topk.cpp b/ngraph/frontend/onnx_import/src/op/topk.cpp index 8dfb1ecb4ec..a6ae196a435 100644 --- a/ngraph/frontend/onnx_import/src/op/topk.cpp +++ b/ngraph/frontend/onnx_import/src/op/topk.cpp @@ -29,16 +29,6 @@ namespace { - /// \return Parse node attribute value for axis and adjust for negative value if needed. 
- std::int64_t get_axis(const ngraph::onnx_import::Node& node) - { - std::int64_t axis{node.get_attribute_value("axis", -1)}; - - const auto data = node.get_ng_inputs().at(0); - const auto data_rank = data.get_partial_shape().rank(); - return ngraph::normalize_axis(node.get_description(), axis, data_rank); - } - /// \return Return the second input to the TopK node reshaped to a scalar. ngraph::Output get_k(const ngraph::onnx_import::Node& node) { @@ -64,7 +54,7 @@ namespace ngraph auto data = node.get_ng_inputs().at(0); std::int64_t k{node.get_attribute_value("k")}; auto k_node = default_opset::Constant::create(element::i64, Shape{}, {k}); - auto axis = get_axis(node); + const std::int64_t axis{node.get_attribute_value("axis", -1)}; std::shared_ptr top_k = std::make_shared( data, @@ -84,7 +74,7 @@ namespace ngraph { auto data = node.get_ng_inputs().at(0); auto k = get_k(node); - auto axis = get_axis(node); + const std::int64_t axis{node.get_attribute_value("axis", -1)}; std::shared_ptr top_k = std::make_shared( data, @@ -107,7 +97,7 @@ namespace ngraph auto k = get_k(node); // Process attributes - const auto axis = get_axis(node); + const std::int64_t axis{node.get_attribute_value("axis", -1)}; const auto largest = node.get_attribute_value("largest", 1); const auto sorted = node.get_attribute_value("sorted", 1); diff --git a/ngraph/python/tests/__init__.py b/ngraph/python/tests/__init__.py index aca4c4409f6..3ae971b3157 100644 --- a/ngraph/python/tests/__init__.py +++ b/ngraph/python/tests/__init__.py @@ -137,10 +137,8 @@ xfail_issue_38714 = xfail_test(reason="RuntimeError: While validating ONNX node "Argument element types are inconsistent.") xfail_issue_43742 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations:" "If") -xfail_issue_43439 = xfail_test(reason="Check 'tensor_rank.is_static()' failed at " - "ngraph/core/src/validation_util.cpp:884:" - "map_1/while/select_bboxes/sort_bboxes_10/TopKV2 " - "Rank must be static in order to normalize negative axis=-1") +xfail_issue_45457 = xfail_test(reason="RuntimeError: Unsupported dynamic ops: v5::Loop" + "Not constant termination condition body output is not supported") xfail_issue_38715 = xfail_test(reason="RuntimeError: While validating ONNX node '':" "While validating node 'v1::OneHot OneHot_" "(Convert_13525[0]:i64{3}, depth[0]:f32{}," diff --git a/ngraph/python/tests/test_onnx/test_zoo_models.py b/ngraph/python/tests/test_onnx/test_zoo_models.py index 2dd566631b0..11c55c2ecf9 100644 --- a/ngraph/python/tests/test_onnx/test_zoo_models.py +++ b/ngraph/python/tests/test_onnx/test_zoo_models.py @@ -29,7 +29,7 @@ from tests import ( xfail_issue_38701, xfail_issue_43742, xfail_issue_43380, - xfail_issue_43439, + xfail_issue_45457, xfail_issue_39684, xfail_issue_40957, xfail_issue_39685, @@ -152,7 +152,6 @@ if len(zoo_models) > 0: # Model MSFT (xfail_issue_43742, "test_MSFT_opset10_mlperf_ssd_mobilenet_300_ssd_mobilenet_v1_coco_2018_01_28_cpu"), - (xfail_issue_43439, "test_MSFT_opset10_mlperf_ssd_resnet34_1200_ssd_resnet34_mAP_20.2_cpu"), (xfail_issue_37957, "test_MSFT_opset10_mask_rcnn_keras_mask_rcnn_keras_cpu"), ] for test_case in import_xfail_list: @@ -178,6 +177,7 @@ if len(zoo_models) > 0: (xfail_issue_38084, "test_onnx_model_zoo_vision_object_detection_segmentation_mask_rcnn_model_MaskRCNN_10_mask_rcnn_R_50_FPN_1x_cpu"), (xfail_issue_38084, "test_onnx_model_zoo_vision_object_detection_segmentation_faster_rcnn_model_FasterRCNN_10_faster_rcnn_R_50_FPN_1x_cpu"), (xfail_issue_43380, 
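         # Note: the ssd_resnet34_1200 entry moves from the import-time xfail
         # list above (xfail_issue_43439, removed) to this execution xfail list
         # under xfail_issue_45457: after the TopK negative-axis fix the model
         # now imports, but its dynamic Loop is still rejected at execution time.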
"test_onnx_model_zoo_vision_object_detection_segmentation_tiny_yolov3_model_tiny_yolov3_11_yolov3_tiny_cpu"), + (xfail_issue_45457, "test_MSFT_opset10_mlperf_ssd_resnet34_1200_ssd_resnet34_mAP_20.2_cpu"), # Model MSFT (xfail_issue_37973, "test_MSFT_opset7_tf_inception_v2_model_cpu"), diff --git a/ngraph/test/models/onnx/loop/loop_2d_add_input_from_parent_graph.prototxt b/ngraph/test/models/onnx/loop/loop_2d_add_input_from_parent_graph.prototxt new file mode 100644 index 00000000000..46c928876d2 --- /dev/null +++ b/ngraph/test/models/onnx/loop/loop_2d_add_input_from_parent_graph.prototxt @@ -0,0 +1,166 @@ +ir_version: 6 +producer_name: "nGraph ONNX Importer" +graph { + name: "basic loop" + node { + input: "trip_count" + input: "" + input: "a_init" + output: "a_final" + output: "a_values" + op_type: "Loop" + attribute { + name: "body" + g { + node { + input: "a_in" + input: "b" + output: "current_a" + name: "loop_body_add" + op_type: "Add" + } + node { + input: "cond_in" + output: "cond_out" + name: "cond_identity" + op_type: "Identity" + } + node { + input: "current_a" + output: "a_out" + name: "output_accumulator" + op_type: "Identity" + } + name: "simple add" + input { + name: "i" + type { + tensor_type { + elem_type: 7 + shape { + } + } + } + } + input { + name: "cond_in" + type { + tensor_type { + elem_type: 9 + shape { + } + } + } + } + input { + name: "a_in" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 2 + } + } + } + } + } + output { + name: "cond_out" + type { + tensor_type { + elem_type: 9 + shape { + } + } + } + } + output { + name: "current_a" + type { + tensor_type { + elem_type: 1 + shape { + } + } + } + } + output { + name: "a_out" + type { + tensor_type { + elem_type: 1 + shape { + } + } + } + } + } + type: GRAPH + } + } + initializer { + dims: 1 + data_type: 7 + int64_data: 3 + name: "trip_count" + } + input { + name: "a_init" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 2 + } + } + } + } + } + input { + name: "b" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 2 + } + } + } + } + } + output { + name: "a_final" + type { + tensor_type { + elem_type: 1 + shape { + } + } + } + } + output { + name: "a_values" + type { + tensor_type { + elem_type: 1 + shape { + } + } + } + } +} +opset_import { + version: 11 +} diff --git a/ngraph/test/models/onnx/loop/loop_2d_add_trip_count_max_int.prototxt b/ngraph/test/models/onnx/loop/loop_2d_add_trip_count_max_int.prototxt new file mode 100644 index 00000000000..227e00d9e56 --- /dev/null +++ b/ngraph/test/models/onnx/loop/loop_2d_add_trip_count_max_int.prototxt @@ -0,0 +1,177 @@ +ir_version: 6 +producer_name: "nGraph ONNX Importer" +graph { + name: "basic loop" + node { + input: "trip_count" + input: "cond_in" + input: "a_init" + output: "a_final" + output: "a_values" + op_type: "Loop" + attribute { + name: "body" + g { + node { + input: "a_in" + input: "b" + output: "current_a" + name: "loop_body_add" + op_type: "Add" + } + node { + input: "i" + input: "threshold" + output: "cond_out" + name: "condition_calc" + op_type: "Less" + } + node { + input: "current_a" + output: "a_out" + name: "output_accumulator" + op_type: "Identity" + } + name: "simple add" + initializer { + dims: 1 + dims: 2 + data_type: 1 + float_data: 1 + float_data: 1 + name: "b" + } + input { + name: "i" + type { + tensor_type { + elem_type: 7 + shape { + dim { + dim_value: 1 + } + } + } + } + } + input { + name: "cond" + 
type { + tensor_type { + elem_type: 9 + } + } + } + input { + name: "a_in" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 2 + } + } + } + } + } + output { + name: "cond_out" + type { + tensor_type { + elem_type: 9 + } + } + } + output { + name: "current_a" + type { + tensor_type { + elem_type: 1 + shape { + } + } + } + } + output { + name: "a_out" + type { + tensor_type { + elem_type: 1 + shape { + } + } + } + } + } + type: GRAPH + } + } + initializer { + dims: 1 + data_type: 7 + int64_data: 5 + name: "threshold" + } + initializer { + dims: 1 + data_type: 7 + int64_data: 9223372036854775807 + name: "trip_count" + } + input { + name: "cond_in" + type { + tensor_type { + elem_type: 9 + shape { + dim { + dim_value: 1 + } + } + } + } + } + input { + name: "a_init" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 2 + } + } + } + } + } + output { + name: "a_final" + type { + tensor_type { + elem_type: 1 + shape { + } + } + } + } + output { + name: "a_values" + type { + tensor_type { + elem_type: 1 + shape { + } + } + } + } +} +opset_import { + version: 11 +} diff --git a/ngraph/test/models/onnx/loop/loop_add_node_from_parent_scope_used_in_parent_and_in_body.prototxt b/ngraph/test/models/onnx/loop/loop_add_node_from_parent_scope_used_in_parent_and_in_body.prototxt new file mode 100644 index 00000000000..ad594c795e4 --- /dev/null +++ b/ngraph/test/models/onnx/loop/loop_add_node_from_parent_scope_used_in_parent_and_in_body.prototxt @@ -0,0 +1,181 @@ +ir_version: 6 +producer_name: "nGraph ONNX Importer" +graph { + name: "basic loop" + node { + input: "parent_input" + input: "scale" + name: "mul_node" + op_type: "Mul" + output: "b" + } + node { + input: "parent_input" + input: "b" + name: "parent_add_node" + op_type: "Add" + output: "c" + } + node { + input: "trip_count" + input: "cond_in" + input: "a_init" + output: "a_final" + output: "a_values" + op_type: "Loop" + attribute { + name: "body" + g { + name: "simple add" + node { + input: "b" + input: "a_in" + output: "current_a" + name: "loop_body_add" + op_type: "Add" + } + node { + input: "cond" + output: "cond_out" + name: "cond_identity" + op_type: "Identity" + } + node { + input: "current_a" + output: "a_out" + name: "output_accumulator" + op_type: "Identity" + } + input { + name: "i" + type { + tensor_type { + elem_type: 7 + shape { + dim { + dim_value: 1 + } + } + } + } + } + input { + name: "cond" + type { + tensor_type { + elem_type: 9 + } + } + } + input { + name: "a_in" + type { + tensor_type { + elem_type: 1 + } + } + } + output { + name: "cond_out" + type { + tensor_type { + elem_type: 9 + } + } + } + output { + name: "current_a" + type { + tensor_type { + elem_type: 1 + } + } + } + output { + name: "a_out" + type { + tensor_type { + elem_type: 1 + } + } + } + } + type: GRAPH + } + } + initializer { + dims: 1 + data_type: 7 + int64_data: 3 + name: "trip_count" + } + initializer { + dims: 1 + data_type: 9 + int32_data: 00000001 + name: "cond_in" + } + initializer { + dims: 1 + data_type: 1 + float_data: 2 + name: "scale" + } + + input { + name: "a_init" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 2 + } + } + } + } + } + input { + name: "parent_input" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + } + } + } + } + output { + name: "a_final" + type { + tensor_type { + elem_type: 1 + } + } + } + output { + name: "a_values" + type { + tensor_type { + elem_type: 
1 + } + } + } + output { + name: "c" + type { + tensor_type { + elem_type: 1 + } + } + } +} +opset_import { + version: 11 +} diff --git a/ngraph/test/models/onnx/top_k_opset_11_const_k_smallest_negative_axis.prototxt b/ngraph/test/models/onnx/top_k_opset_11_const_k_smallest_negative_axis.prototxt new file mode 100644 index 00000000000..18d24404971 --- /dev/null +++ b/ngraph/test/models/onnx/top_k_opset_11_const_k_smallest_negative_axis.prototxt @@ -0,0 +1,97 @@ +ir_version: 5 +producer_name: "nGraph ONNX Importer" +graph { + node { + input: "x" + input: "k" + output: "values" + output: "indices" + op_type: "TopK" + attribute { + name: "axis" + i: -1 + type: INT + } + attribute { + name: "largest" + i: 0 + type: INT + } + attribute { + name: "sorted" + i: 1 + type: INT + } + } + name: "test_top_k" + input { + name: "x" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + } + } + } + } + input { + name: "k" + type { + tensor_type { + elem_type: 7 + shape { + dim { + dim_value: 1 + } + } + } + } + } + initializer { + dims: 1 + data_type: 7 + int64_data: 3 + name: "k" + } + output { + name: "values" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 3 + } + } + } + } + } + output { + name: "indices" + type { + tensor_type { + elem_type: 7 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 3 + } + } + } + } + } +} +opset_import { + version: 11 +} diff --git a/ngraph/test/onnx/onnx_import.in.cpp b/ngraph/test/onnx/onnx_import.in.cpp index ac1d1702e68..e33db620546 100644 --- a/ngraph/test/onnx/onnx_import.in.cpp +++ b/ngraph/test/onnx/onnx_import.in.cpp @@ -2308,6 +2308,20 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_top_k_opset_11_const_k_smallest) test_case.run(); } +NGRAPH_TEST(${BACKEND_NAME}, onnx_top_k_opset_11_const_k_smallest_negative_axis) +{ + auto function = onnx_import::import_onnx_model(file_util::path_join( + SERIALIZED_ZOO, "onnx/top_k_opset_11_const_k_smallest_negative_axis.prototxt")); + + auto test_case = test::TestCase(function); + test_case.add_input({0, 1, 2, 3, 4, 5, 6, 7, 11, 10, 9, 8}); + + test_case.add_expected_output(Shape{3, 3}, {0, 1, 2, 4, 5, 6, 8, 9, 10}); // values + test_case.add_expected_output(Shape{3, 3}, + {0, 1, 2, 0, 1, 2, 3, 2, 1}); // indices + test_case.run(); +} + NGRAPH_TEST(${BACKEND_NAME}, onnx_model_acosh) { auto function = diff --git a/ngraph/test/onnx/onnx_import_controlflow.in.cpp b/ngraph/test/onnx/onnx_import_controlflow.in.cpp index 827c5b4d716..278371bb883 100644 --- a/ngraph/test/onnx/onnx_import_controlflow.in.cpp +++ b/ngraph/test/onnx/onnx_import_controlflow.in.cpp @@ -60,14 +60,14 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_2d_add) EXPECT_EQ(function->get_output_shape(0), (Shape{1, 2})); EXPECT_EQ(function->get_output_element_type(1), ngraph::element::f32); EXPECT_TRUE(function->get_output_partial_shape(1).is_static()); - EXPECT_EQ(function->get_output_shape(1), (Shape{3, 2})); + EXPECT_EQ(function->get_output_shape(1), (Shape{3, 1, 2})); auto test_case = test::TestCase(function); // a_init test_case.add_input({0.f, 0.f}); test_case.add_expected_output(Shape{1, 2}, {3.f, 3.f}); - test_case.add_expected_output(Shape{3, 2}, {1.f, 1.f, 2.f, 2.f, 3.f, 3.f}); + test_case.add_expected_output(Shape{3, 1, 2}, {1.f, 1.f, 2.f, 2.f, 3.f, 3.f}); test_case.run(); } @@ -89,7 +89,24 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_2d_no_identity_termination_co test_case.add_expected_output(Shape{1, 2}, {6.f, 6.f}); 
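    // The importer now unsqueezes every scan output along the concatenation
    // axis (see onnx_import/src/op/loop.cpp above), so each per-iteration
    // slice keeps its own dimension: the concatenated shape grows from
    // {6, 2} to {6, 1, 2}.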
test_case.add_expected_output( - Shape{6, 2}, {1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 4.f, 5.f, 5.f, 6.f, 6.f}); + Shape{6, 1, 2}, {1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 4.f, 5.f, 5.f, 6.f, 6.f}); + test_case.run(); +} + +NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_2d_trip_count_max_int) +{ + const auto function = onnx_import::import_onnx_model( + file_util::path_join(SERIALIZED_ZOO, "onnx/loop/loop_2d_add_trip_count_max_int.prototxt")); + + auto test_case = test::TestCase(function); + // termination condition + test_case.add_input({true}); + // a_init + test_case.add_input({0.f, 0.f}); + + test_case.add_expected_output(Shape{1, 2}, {6.f, 6.f}); + test_case.add_expected_output( + Shape{6, 1, 2}, {1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 4.f, 5.f, 5.f, 6.f, 6.f}); test_case.run(); } @@ -140,7 +157,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_2d_const_no_identity_terminat test_case.add_input({0.f, 0.f}); test_case.add_expected_output(Shape{1, 2}, {4.f, 4.f}); - test_case.add_expected_output(Shape{4, 2}, {1, 1, 2, 2, 3, 3, 4, 4}); + test_case.add_expected_output(Shape{4, 1, 2}, {1, 1, 2, 2, 3, 3, 4, 4}); test_case.run(); } @@ -182,7 +199,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_2d_both_cond_and_trip_count_a test_case.add_expected_output(Shape{1, 2}, {6.f, 6.f}); test_case.add_expected_output( - Shape{6, 2}, {1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 4.f, 5.f, 5.f, 6.f, 6.f}); + Shape{6, 1, 2}, {1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 4.f, 5.f, 5.f, 6.f, 6.f}); test_case.run(); } @@ -220,7 +237,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_add_initializer_from_parent_s test_case.add_input({0.f, 0.f}); test_case.add_expected_output(Shape{1, 2}, {6.f, 6.f}); - test_case.add_expected_output(Shape{3, 2}, {2.f, 2.f, 4.f, 4.f, 6.f, 6.f}); + test_case.add_expected_output(Shape{3, 1, 2}, {2.f, 2.f, 4.f, 4.f, 6.f, 6.f}); test_case.run(); } @@ -234,7 +251,26 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_add_node_from_parent_scope) test_case.add_input({0.f, 0.f}); test_case.add_expected_output(Shape{1, 2}, {12.f, 12.f}); - test_case.add_expected_output(Shape{3, 2}, {4.f, 4.f, 8.f, 8.f, 12.f, 12.f}); + test_case.add_expected_output(Shape{3, 1, 2}, {4.f, 4.f, 8.f, 8.f, 12.f, 12.f}); + test_case.run(); +} + +NGRAPH_TEST(${BACKEND_NAME}, + onnx_controlflow_loop_add_node_from_parent_scope_used_in_parent_and_in_body) +{ + const auto function = onnx_import::import_onnx_model(file_util::path_join( + SERIALIZED_ZOO, + "onnx/loop/loop_add_node_from_parent_scope_used_in_parent_and_in_body.prototxt")); + + auto test_case = test::TestCase(function); + // a_init + test_case.add_input({0.f, 0.f}); + // parent_input + test_case.add_input({3.f}); + + test_case.add_expected_output(Shape{1, 2}, {18.f, 18.f}); + test_case.add_expected_output(Shape{3, 1, 2}, {6.f, 6.f, 12.f, 12.f, 18.f, 18.f}); + test_case.add_expected_output(Shape{1}, {9.f}); test_case.run(); } @@ -268,7 +304,23 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_add_value_the_same_node_from_ test_case.add_input({0.f, 0.f}); test_case.add_expected_output(Shape{1, 2}, {3.f, 3.f}); - test_case.add_expected_output(Shape{3, 2}, {1.f, 1.f, 2.f, 2.f, 3.f, 3.f}); + test_case.add_expected_output(Shape{3, 1, 2}, {1.f, 1.f, 2.f, 2.f, 3.f, 3.f}); + test_case.run(); +} + +NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_add_input_from_parent_graph) +{ + const auto function = onnx_import::import_onnx_model(file_util::path_join( + SERIALIZED_ZOO, "onnx/loop/loop_2d_add_input_from_parent_graph.prototxt")); + + auto test_case = 
test::TestCase(function); + // a_init + test_case.add_input({0.f, 0.f}); + // b input + test_case.add_input({1.f, 1.f}); + + test_case.add_expected_output(Shape{1, 2}, {3.f, 3.f}); + test_case.add_expected_output(Shape{3, 1, 2}, {1.f, 1.f, 2.f, 2.f, 3.f, 3.f}); test_case.run(); } @@ -321,7 +373,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_2d_add_const_cond) test_case.add_input({0.f, 0.f}); test_case.add_expected_output(Shape{1, 2}, {3.f, 3.f}); - test_case.add_expected_output(Shape{3, 2}, {1.f, 1.f, 2.f, 2.f, 3.f, 3.f}); + test_case.add_expected_output(Shape{3, 1, 2}, {1.f, 1.f, 2.f, 2.f, 3.f, 3.f}); test_case.run(); } @@ -379,8 +431,9 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_2d_trip_count_and_cond_skippe EXPECT_TRUE(function->get_output_partial_shape(0).is_static()); EXPECT_EQ(function->get_output_shape(0), (Shape{1, 2})); EXPECT_EQ(function->get_output_element_type(1), ngraph::element::f32); - // scan_outputs shape is not know if trip_count and termination condition is not determined - EXPECT_TRUE(function->get_output_partial_shape(1).rank().is_dynamic()); + EXPECT_TRUE(function->get_output_partial_shape(1).rank().is_static()); + EXPECT_EQ(function->get_output_partial_shape(1).rank(), 3); + EXPECT_EQ(function->get_output_partial_shape(1), (PartialShape{Dimension::dynamic(), 1, 2})); } // infinitive loop execution diff --git a/ngraph/test/runtime/ie/unit_test.manifest b/ngraph/test/runtime/ie/unit_test.manifest index 706abec717d..9fbf986df28 100644 --- a/ngraph/test/runtime/ie/unit_test.manifest +++ b/ngraph/test/runtime/ie/unit_test.manifest @@ -71,6 +71,7 @@ onnx_model_split_equal_parts_2d onnx_model_split_variable_parts_2d onnx_top_k_opset_10_const_k onnx_top_k_opset_11_const_k_smallest +onnx_top_k_opset_11_const_k_smallest_negative_axis split_1d split_2d_axis_0 split_2d_axis_1 @@ -1520,6 +1521,7 @@ IE_GPU.onnx_model_fake_quantize_nonconst_inputs_infer # Not supported dynamic shapes cases for Loop onnx_controlflow_loop_2d_no_identity_termination_cond onnx_controlflow_loop_2d_no_identity_termination_cond_false +onnx_controlflow_loop_2d_trip_count_max_int onnx_controlflow_loop_2d_const_no_identity_termination_cond onnx_controlflow_loop_2d_both_cond_and_trip_count_as_inputs onnx_controlflow_loop_no_variadic_inputs_and_outputs diff --git a/ngraph/test/type_prop/loop.cpp b/ngraph/test/type_prop/loop.cpp index f4dfe846d88..fd1f4d91248 100644 --- a/ngraph/test/type_prop/loop.cpp +++ b/ngraph/test/type_prop/loop.cpp @@ -334,7 +334,7 @@ TEST(type_prop, loop_operation_for_and_condition_mode_dynamic_iter_static_shapes // trip_count = 10 // execution_condition = true // body_condition is not a Constant -// concat output will be dynamic, another outputs are static +// concat output has only dynamic rank, another outputs are static TEST(type_prop, loop_operation_for_and_condition_mode_dynamic_iter_dynamic_shapes) { // That which we iterate over @@ -397,7 +397,7 @@ TEST(type_prop, loop_operation_for_and_condition_mode_dynamic_iter_dynamic_shape // Output 0 is last Zo auto out0 = loop->get_iter_value(body_condition, -1); auto out1 = loop->get_iter_value(Zo, -1); - auto out2 = loop->get_concatenated_slices(Zo, 0, 1, 1, -1, 1); + auto out2 = loop->get_concatenated_slices(Zo, 0, 1, 1, -1, 0); // check output descriptors for (auto& desc : loop->get_output_descriptions()) @@ -422,9 +422,9 @@ TEST(type_prop, loop_operation_for_and_condition_mode_dynamic_iter_dynamic_shape auto result2 = make_shared(out2); Shape out0_shape{1}; Shape out1_shape{1}; - PartialShape 
out2_shape{PartialShape::dynamic()}; + PartialShape out2_shape{PartialShape::dynamic(1)}; - auto results = ResultVector{result0, result1}; + auto results = ResultVector{result0, result1, result2}; auto f = make_shared(results, ParameterVector{X, Y, M}); EXPECT_EQ(result0->get_output_shape(0), out0_shape); EXPECT_EQ(result1->get_output_shape(0), out1_shape); @@ -435,6 +435,176 @@ TEST(type_prop, loop_operation_for_and_condition_mode_dynamic_iter_dynamic_shape EXPECT_EQ(loop->get_output_partial_shape(2), out2_shape); } +// trip_count = 10 +// execution_condition = true +// body_condition is not a Constant +// inputs have partially known shape +// concat output has dynamic dimension on axis position, another outputs are static +TEST(type_prop, loop_operation_for_and_condition_mode_dynamic_iter_partially_dynamic_shapes) +{ + // That which we iterate over + auto X = + make_shared(element::f32, PartialShape{1, 2, 3, Dimension::dynamic()}); + auto Y = + make_shared(element::f32, PartialShape{1, 2, 3, Dimension::dynamic()}); + auto M = make_shared(element::f32, Shape{1}); + + // Set up the cell body, a function from (Xi, Yi) -> (Zo) + // Body parameters + auto current_iteration = make_shared(element::i64, Shape{1}); + auto Xi = make_shared(element::f32, PartialShape::dynamic()); + auto Yi = make_shared(element::f32, PartialShape::dynamic()); + auto M_body = make_shared(element::f32, PartialShape::dynamic()); + auto condition_const = + std::make_shared(ngraph::element::f32, ngraph::Shape{1}, 10); + auto body_condition = std::make_shared(M_body, condition_const); + + auto trip_count = + std::make_shared(ngraph::element::i64, ngraph::Shape{1}, 10); + auto exec_condition = std::make_shared( + ngraph::element::boolean, ngraph::Shape{1}, true); + // Body + auto sum = make_shared(Xi, Yi); + auto Zo = make_shared(sum, M_body); + auto body = make_shared(OutputVector{body_condition, Zo}, + ParameterVector{current_iteration, Xi, Yi, M_body}); + + auto loop = make_shared(trip_count, exec_condition); + loop->set_function(body); + loop->set_special_body_ports(ngraph::opset5::Loop::SpecialBodyPorts{-1, 0}); + + loop->set_invariant_input(Xi, X); + loop->set_invariant_input(Yi, Y); + loop->set_merged_input(M_body, M, Zo); + + // check input descriptors + for (auto& desc : loop->get_input_descriptions()) + { + auto type_info = desc->get_type_info(); + if (std::strcmp(type_info.name, "InvariantInputDescription") == 0) + { + auto input_desc = + as_type_ptr(desc); + EXPECT_NE(input_desc, nullptr); + } + else if (std::strcmp(type_info.name, "SliceInputDescription") == 0) + { + auto input_desc = + as_type_ptr(desc); + EXPECT_NE(input_desc, nullptr); + } + else if (std::strcmp(type_info.name, "MergedInputDescription") == 0) + { + auto input_desc = + as_type_ptr(desc); + EXPECT_NE(input_desc, nullptr); + } + } + + // Output 0 is last Zo + auto out0 = loop->get_iter_value(body_condition, -1); + auto out1 = loop->get_iter_value(Zo, -1); + // axis=1 so sliced output on this dimension will be dynamic + auto out2 = loop->get_concatenated_slices(Zo, 0, 1, 1, -1, 1); + + // check output descriptors + for (auto& desc : loop->get_output_descriptions()) + { + auto type_info = desc->get_type_info(); + if (std::strcmp(type_info.name, "ConcatOutputDescription") == 0) + { + auto output_desc = + as_type_ptr(desc); + EXPECT_NE(output_desc, nullptr); + } + else if (std::strcmp(type_info.name, "BodyOutputDescription") == 0) + { + auto output_desc = + as_type_ptr(desc); + EXPECT_NE(output_desc, nullptr); + } + } + + auto result0 = 
make_shared(out0); + auto result1 = make_shared(out1); + auto result2 = make_shared(out2); + Shape out0_shape{1}; + PartialShape out1_shape{1, 2, 3, Dimension::dynamic()}; + PartialShape out2_shape{1, Dimension::dynamic(), 3, Dimension::dynamic()}; + + auto results = ResultVector{result0, result1, result2}; + auto f = make_shared(results, ParameterVector{X, Y, M}); + EXPECT_EQ(result0->get_output_shape(0), out0_shape); + EXPECT_EQ(result1->get_output_partial_shape(0), out1_shape); + EXPECT_EQ(result2->get_output_partial_shape(0), out2_shape); + + EXPECT_EQ(loop->get_output_shape(0), out0_shape); + EXPECT_EQ(loop->get_output_partial_shape(1), out1_shape); + EXPECT_EQ(loop->get_output_partial_shape(2), out2_shape); +} + +// trip_count = 10 +// execution_condition = true +// body_condition is not a Constant +// inputs have partially known shape +// Axis of silced output is set as incorrect +TEST(type_prop, loop_operation_for_and_condition_mode_dynamic_iter_incorrect_sliced_output_axis) +{ + // That which we iterate over + auto X = + make_shared(element::f32, PartialShape{1, 2, 3, Dimension::dynamic()}); + auto Y = + make_shared(element::f32, PartialShape{1, 2, 3, Dimension::dynamic()}); + auto M = make_shared(element::f32, Shape{1}); + + // Set up the cell body, a function from (Xi, Yi) -> (Zo) + // Body parameters + auto current_iteration = make_shared(element::i64, Shape{1}); + auto Xi = make_shared(element::f32, PartialShape::dynamic()); + auto Yi = make_shared(element::f32, PartialShape::dynamic()); + auto M_body = make_shared(element::f32, PartialShape::dynamic()); + auto condition_const = + std::make_shared(ngraph::element::f32, ngraph::Shape{1}, 10); + auto body_condition = std::make_shared(M_body, condition_const); + + auto trip_count = + std::make_shared(ngraph::element::i64, ngraph::Shape{1}, 10); + auto exec_condition = std::make_shared( + ngraph::element::boolean, ngraph::Shape{1}, true); + // Body + auto sum = make_shared(Xi, Yi); + auto Zo = make_shared(sum, M_body); + auto body = make_shared(OutputVector{body_condition, Zo}, + ParameterVector{current_iteration, Xi, Yi, M_body}); + + auto loop = make_shared(trip_count, exec_condition); + loop->set_function(body); + loop->set_special_body_ports(ngraph::opset5::Loop::SpecialBodyPorts{-1, 0}); + + loop->set_invariant_input(Xi, X); + loop->set_invariant_input(Yi, Y); + loop->set_merged_input(M_body, M, Zo); + + const auto sliced_output_axis = 4; + auto out = loop->get_concatenated_slices(Zo, 0, 1, 1, -1, sliced_output_axis); + + auto result = make_shared(out); + try + { + auto f = make_shared(ResultVector{result}, ParameterVector{X, Y, M}); + FAIL() << "Loop was created with incorrect axis of concatenated slices output."; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING( + error.what(), std::string("Concatenation axis must be less than sliced output rank")); + } + catch (...) + { + FAIL() << "Construction loop operator failed for unexpected reason."; + } +} + // trip_count = -1 // execution_condition = true // body_condition = true @@ -527,7 +697,7 @@ TEST(type_prop, loop_operation_infinite_loop_mode_dynamic_iter_dynamic_shapes) auto result2 = make_shared(out2); Shape out0_shape{1}; Shape out1_shape{32, 1, 10}; - PartialShape out2_shape{PartialShape::dynamic()}; + PartialShape out2_shape{32, Dimension::dynamic(), 10}; auto results = ResultVector{result0, result1, result2}; auto f = make_shared(results, ParameterVector{X, Y, M});
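
Note on the shape-inference change in ngraph/core/src/op/loop.cpp above: the new
rule for ConcatOutputDescription outputs can be summarized with the minimal
sketch below. This is a standalone model for illustration only, not nGraph
code; the helper name concat_output_shape is hypothetical, -1 stands in for
ngraph::Dimension::dynamic(), and the rank-dynamic case (where the whole output
becomes PartialShape::dynamic()) is omitted.

    #include <cstdint>
    #include <stdexcept>
    #include <vector>

    // Hypothetical model of the concat-output rule added to
    // op::v5::Loop::validate_and_infer_types(); -1 encodes a dynamic dimension.
    std::vector<std::int64_t>
        concat_output_shape(std::vector<std::int64_t> body_shape,
                            std::int64_t axis,
                            std::int64_t num_iterations, // -1 if unknown
                            bool zero_number_of_iter)
    {
        constexpr std::int64_t dynamic = -1;
        // A scalar body output is promoted to Shape{1} (the real code also
        // validates that this may only happen for axis 0).
        if (body_shape.empty())
            body_shape = {1};
        // Mirrors the new NODE_VALIDATION_CHECK: "Concatenation axis must be
        // less than sliced output rank".
        if (axis >= static_cast<std::int64_t>(body_shape.size()))
            throw std::invalid_argument("concat axis out of range");
        if (num_iterations != -1 && body_shape[axis] != dynamic)
        {
            // Static trip count and static dimension: the output dimension
            // is their product.
            body_shape[axis] *= num_iterations;
            if (zero_number_of_iter)
                body_shape[0] = 0;
        }
        else
        {
            // Otherwise only the concatenation axis becomes dynamic; the
            // other dimensions keep the body output's (partial) shape.
            body_shape[axis] = dynamic;
        }
        return body_shape;
    }

Combined with the importer change that unconditionally unsqueezes each scan
output before concatenation, a per-iteration output of Shape{1, 2} becomes
{1, 1, 2}, and three iterations concatenated on axis 0 yield {3, 1, 2}; with an
unknown trip count only the axis-0 dimension stays dynamic ({?, 1, 2}), which
is exactly what the updated controlflow tests assert.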