diff --git a/src/core/tests/models/onnx/controlflow/if_with_only_indentity_in_else_branch.prototxt b/src/core/tests/models/onnx/controlflow/if_with_only_indentity_in_else_branch.prototxt new file mode 100644 index 00000000000..3526f78c8be --- /dev/null +++ b/src/core/tests/models/onnx/controlflow/if_with_only_indentity_in_else_branch.prototxt @@ -0,0 +1,238 @@ +ir_version: 6 +graph { + node { + output: "zero" + name: "Constant_6" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 1 + data_type: 7 + int64_data: 0 + } + type: TENSOR + } + } + node { + input: "input" + input: "zero" + output: "unsqueeze" + op_type: "Unsqueeze" + } + node { + output: "pads" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 10 + data_type: 7 + int64_data: 0 + int64_data: 0 + int64_data: 1 + int64_data: 0 + int64_data: 0 + int64_data: 0 + int64_data: 0 + int64_data: 1 + int64_data: 0 + int64_data: 0 + } + type: TENSOR + } + } + node { + input: "unsqueeze" + input: "pads" + output: "pad" + name: "Pad_1" + op_type: "Pad" + attribute { + name: "mode" + type: STRING + s: "constant" + } + } + node { + input: "pad" + output: "avgpool" + name: "AveragePool_2" + op_type: "AveragePool" + attribute { + name: "ceil_mode" + i: 0 + type: INT + } + attribute { + name: "kernel_shape" + ints: 3 + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "pads" + ints: 0 + ints: 0 + ints: 0 + ints: 0 + ints: 0 + ints: 0 + type: INTS + } + attribute { + name: "strides" + ints: 1 + ints: 1 + ints: 1 + type: INTS + } + } + node { + output: "index" + name: "Constant_3" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 1 + data_type: 7 + int64_data: 1 + } + type: TENSOR + } + } + node { + input: "avgpool" + output: "avgpool_shape" + name: "Shape_4" + op_type: "Shape" + } + node { + input: "avgpool_shape" + input: "index" + output: "gather" + name: "Gather_5" + op_type: "Gather" + attribute { + name: "axis" + i: 0 + type: INT + } + } + node { + output: "one" + name: "Constant_6" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 1 + data_type: 7 + int64_data: 1 + } + type: TENSOR + } + } + node { + input: "gather" + input: "one" + output: "equal" + name: "Equal_7" + op_type: "Equal" + } + node { + input: "equal" + output: "if" + name: "If_8" + op_type: "If" + attribute { + name: "then_branch" + g { + node { + input: "avgpool" + input: "one" + output: "then_output" + name: "Squeeze_9" + op_type: "Squeeze" + } + name: "then" + output { + name: "then_output" + } + } + type: GRAPH + } + attribute { + name: "else_branch" + g { + node { + input: "avgpool" + output: "else_output" + name: "Identity_10" + op_type: "Identity" + } + name: "else" + output { + name: "else_output" + } + } + type: GRAPH + } + } + node { + input: "input" + input: "if" + output: "output" + name: "Add_11" + op_type: "Add" + } + input { + name: "input" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 5 + } + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + } + } + } + } + output { + name: "output" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 5 + } + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + } + } + } + } +} +opset_import { + version: 13 +} diff --git a/src/core/tests/onnx/onnx_import_controlflow.in.cpp b/src/core/tests/onnx/onnx_import_controlflow.in.cpp index 4fe87d7c174..fbb456c8f18 100644 --- a/src/core/tests/onnx/onnx_import_controlflow.in.cpp +++ b/src/core/tests/onnx/onnx_import_controlflow.in.cpp @@ -692,6 +692,32 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_if_inside_loop) { test_case.run(); } +NGRAPH_TEST(${BACKEND_NAME}, onnx_if_with_only_indentity_in_else_branch) { + /* + unsq = unsqueeze(input) + padded = pad(unsq) + avgpool = avgpool(padded, kernel=[3, 1, 1]) + if_output = if (avgpool.shape[1] == 1) { + squeeze(avgpool) + } else { + identity(avgpool) + } + output = add(input, if_output) + */ + const auto function = onnx_import::import_onnx_model( + file_util::path_join(SERIALIZED_ZOO, "onnx/controlflow/if_with_only_indentity_in_else_branch.onnx")); + + auto test_case = test::TestCase(function, s_device); + + std::vector x(shape_size(Shape{1, 5, 2, 2})); + std::iota(x.begin(), x.end(), 0); + std::vector expected{1.333333, 3, 4.666666, 6.333333, 8, 10, 12, 14, 16, 18, + 20, 22, 24, 26, 28, 30, 25.33333, 27, 28.666667, 30.33333}; + test_case.add_input(x); + test_case.add_expected_output(expected); + test_case.run(); +} + NGRAPH_TEST(${BACKEND_NAME}, onnx_if_dynamic_inputs) { /* if (condition) { diff --git a/src/frontends/onnx/frontend/include/onnx_import/core/node.hpp b/src/frontends/onnx/frontend/include/onnx_import/core/node.hpp index 3d193e928e2..2f249b1b912 100644 --- a/src/frontends/onnx/frontend/include/onnx_import/core/node.hpp +++ b/src/frontends/onnx/frontend/include/onnx_import/core/node.hpp @@ -60,6 +60,9 @@ public: /// \return Description of Node const std::string& get_description() const; + const std::string& input(int index) const; + std::size_t get_inputs_size() const; + const std::vector>& get_output_names() const; const std::string& output(int index) const; std::size_t get_outputs_size() const; @@ -191,4 +194,4 @@ inline std::ostream& operator<<(std::ostream& outs, const Node& node) { } // namespace onnx_import -} // namespace ngraph \ No newline at end of file +} // namespace ngraph diff --git a/src/frontends/onnx/frontend/src/core/graph.cpp b/src/frontends/onnx/frontend/src/core/graph.cpp index b043435b7da..7adbb7c3c46 100644 --- a/src/frontends/onnx/frontend/src/core/graph.cpp +++ b/src/frontends/onnx/frontend/src/core/graph.cpp @@ -186,34 +186,38 @@ std::shared_ptr Graph::convert() { return create_function(); } +OutputVector Graph::make_framework_nodes(const Node& onnx_node) { + std::shared_ptr framework_node; + if (onnx_node.has_subgraphs()) { + const auto& subgraphs = onnx_node.get_subgraphs(); + auto inputs = onnx_node.get_ng_inputs(); + std::vector> functions; + for (const auto& kv : subgraphs) { + auto& subgraph = kv.second; + functions.push_back(subgraph->decode()); + for (const auto& input : subgraph->get_inputs_from_parent()) { + const auto& name = input.get_node()->get_friendly_name(); + if (std::find_if(inputs.begin(), inputs.end(), [&name](const Output& n) -> bool { + return name == n.get_node()->get_friendly_name(); + }) == inputs.end()) { + inputs.push_back(input); + } + } + } + framework_node = std::make_shared(onnx_node, functions, inputs); + } else { + framework_node = std::make_shared(onnx_node); + } + return framework_node->outputs(); +} + void Graph::decode_to_framework_nodes() { const float total = static_cast(m_model->get_graph().node().size()); unsigned int completed = 0u; // Process ONNX graph nodes, convert to nGraph nodes for (const auto& node_proto : m_model->get_graph().node()) { const Node node{node_proto, *this}; - std::shared_ptr framework_node; - if (node.has_subgraphs()) { - const auto& subgraphs = node.get_subgraphs(); - auto inputs = node.get_ng_inputs(); - std::vector> functions; - for (const auto& kv : subgraphs) { - auto& subgraph = kv.second; - functions.push_back(subgraph->decode()); - for (const auto& input : subgraph->get_inputs_from_parent()) { - const auto& name = input.get_node()->get_friendly_name(); - if (std::find_if(inputs.begin(), inputs.end(), [&name](const Output& n) -> bool { - return name == n.get_node()->get_friendly_name(); - }) == inputs.end()) { - inputs.push_back(input); - } - } - } - framework_node = std::make_shared(node, functions, inputs); - } else { - framework_node = std::make_shared(node); - } - OutputVector ng_nodes{framework_node->outputs()}; + OutputVector ng_nodes{make_framework_nodes(node)}; set_friendly_names(node, ng_nodes); // Iterate over the number of outputs for given node in graph. // Some of them may be optional and trimmed. See: @@ -265,7 +269,7 @@ OutputVector Graph::get_ng_outputs() const { return results; } -OutputVector Graph::make_ng_nodes(const Node& onnx_node) const { +OutputVector Graph::make_ng_nodes(const Node& onnx_node) { const auto ng_node_factory = m_model->get_operator(onnx_node.op_type(), onnx_node.domain()); // contains outputs of nG subgraph implementing a particular ONNX node (possibly a single output of a single node) OutputVector ng_subgraph_outputs; @@ -349,80 +353,16 @@ Output Subgraph::get_ng_node_from_cache(const std::string& name) c return m_parent_graph->get_ng_node_from_cache(name); } -void Subgraph::replace_input_from_parent_scope_with_parameter(const std::string& in_name, - const Output& from_parent_node, - Input&& node_to_replace_input) { - auto new_param = std::make_shared(from_parent_node.get_element_type(), - from_parent_node.get_partial_shape()); - node_to_replace_input.replace_source_output(new_param); - m_parameter_to_parent_node_map.insert({new_param, in_name}); - m_cache->emplace_node(in_name, new_param); - m_parameters.push_back(new_param); - m_inputs_from_parent.push_back(in_name); -} - -void Subgraph::find_inputs_from_parent() { - // find all nodes on edge parent graph-subgraph - // (it means input of node from parent graph, output from subgraph) - for (const auto& node_proto : m_model->get_graph().node()) { - int input_index = 0; - for (const auto& in_name : node_proto.input()) { - if (m_parent_graph->is_ng_node_in_cache(in_name)) { - const auto& from_parent_node = m_parent_graph->get_ng_node_from_cache(in_name); - // constants are skipped - if (!ngraph::is_type(from_parent_node.get_node_shared_ptr())) { - for (const auto& out_name : node_proto.output()) { - if (m_cache->contains(out_name)) { - auto node_to_replace_input = m_cache->get_node(out_name); - replace_input_from_parent_scope_with_parameter( - in_name, - from_parent_node, - node_to_replace_input.get_node()->input(input_index)); - } - } - } - } - ++input_index; - } - // Nodes with subgraphs (like Loop or If) can have implicit inputs (so their subgraphs depend on nodes from - // parent) Those implicit inputs are not present in `node_proto.input()` list so to get them, we need to fetch - // node's nGraph representation and then we can match those inputs with parent nodes - for (const auto& out_name : node_proto.output()) { - if (m_cache->contains(out_name)) { - auto node_to_replace_input = m_cache->get_node(out_name).get_node(); - if (!ov::is_type(node_to_replace_input) && - !ov::is_type(node_to_replace_input)) - continue; - auto inputs = node_to_replace_input->input_values(); - for (size_t i = 0; i < inputs.size(); i++) { - const auto& input = inputs.at(i); - auto input_node = input.get_node(); - if (op::is_constant(input_node)) - continue; - const auto& in_name = input_node->get_friendly_name(); - if (m_parent_graph->is_ng_node_in_cache(in_name)) { - const auto& from_parent_node = m_parent_graph->get_ng_node_from_cache(in_name); - replace_input_from_parent_scope_with_parameter(in_name, - from_parent_node, - node_to_replace_input->input(i)); - } - } - } - } - } +OutputVector Subgraph::make_ng_nodes(const Node& onnx_node) { + replace_input_from_parent_scope_with_parameter(onnx_node); + return Graph::make_ng_nodes(onnx_node); } std::shared_ptr Subgraph::convert() { convert_to_ngraph_nodes(); - find_inputs_from_parent(); return create_function(); } -void Subgraph::decode_to_framework_nodes() { - Graph::decode_to_framework_nodes(); - find_inputs_from_parent(); -} - const std::vector> Subgraph::get_inputs_from_parent() const { OutputVector result; for (const auto& name : m_inputs_from_parent) { @@ -440,6 +380,30 @@ void Subgraph::infer_inputs_from_parent() { } } +OutputVector Subgraph::make_framework_nodes(const Node& onnx_node) { + replace_input_from_parent_scope_with_parameter(onnx_node); + return Graph::make_framework_nodes(onnx_node); +} + +void Subgraph::replace_input_from_parent_scope_with_parameter(const Node& onnx_node) { + for (std::size_t i = 0; i < onnx_node.get_inputs_size(); ++i) { + const auto& in_name = onnx_node.input(i); + if (m_parent_graph->is_ng_node_in_cache(in_name) && + std::find(m_inputs_from_parent.begin(), m_inputs_from_parent.end(), in_name) == + m_inputs_from_parent.end()) { + const auto& from_parent_node = m_parent_graph->get_ng_node_from_cache(in_name); + if (op::is_constant(from_parent_node.get_node())) + continue; + auto new_param = std::make_shared(from_parent_node.get_element_type(), + from_parent_node.get_partial_shape()); + m_parameter_to_parent_node_map.insert({new_param, in_name}); + m_cache->emplace_node(in_name, new_param); + m_parameters.push_back(new_param); + m_inputs_from_parent.push_back(in_name); + } + } +} + } // namespace onnx_import } // namespace ngraph diff --git a/src/frontends/onnx/frontend/src/core/graph.hpp b/src/frontends/onnx/frontend/src/core/graph.hpp index 8734f873122..5b668ffac8d 100644 --- a/src/frontends/onnx/frontend/src/core/graph.hpp +++ b/src/frontends/onnx/frontend/src/core/graph.hpp @@ -41,7 +41,7 @@ public: } virtual bool is_ng_node_in_cache(const std::string& name) const; virtual Output get_ng_node_from_cache(const std::string& name) const; - OutputVector make_ng_nodes(const Node& onnx_node) const; + virtual OutputVector make_ng_nodes(const Node& onnx_node); const OpsetImports& get_opset_imports() const; virtual ~Graph() = default; @@ -57,7 +57,8 @@ protected: void set_friendly_names(const Node& onnx_node, const OutputVector& ng_subgraph_outputs) const; protected: - virtual void decode_to_framework_nodes(); + virtual OutputVector make_framework_nodes(const Node& onnx_node); + void decode_to_framework_nodes(); void convert_to_ngraph_nodes(); void remove_dangling_parameters(); std::shared_ptr create_function(); @@ -98,19 +99,13 @@ public: bool is_ng_node_in_cache(const std::string& name) const override; Output get_ng_node_from_cache(const std::string& name) const override; + OutputVector make_ng_nodes(const Node& onnx_node) override; void infer_inputs_from_parent(); private: - void decode_to_framework_nodes() override; - void find_inputs_from_parent(); - /// \brief Replaces current node's input with Parameter if that input comes from parent graph scope - /// - /// \param[in] in_name input node name - /// \param[in] from_parent_node nGraph node from parent scope - /// \param[in] node_to_replace_input nGraph input node to be replaced - void replace_input_from_parent_scope_with_parameter(const std::string& in_name, - const Output& from_parent_node, - Input&& node_to_replace_input); + OutputVector make_framework_nodes(const Node& onnx_node) override; + /// \brief Checks if onnx_node has inputs from parent graph and replaces those inputs with Parameters + void replace_input_from_parent_scope_with_parameter(const Node& onnx_node); const Graph* m_parent_graph; std::vector m_inputs_from_parent; diff --git a/src/frontends/onnx/frontend/src/core/node.cpp b/src/frontends/onnx/frontend/src/core/node.cpp index 024c8f0c922..0084d7e0b90 100644 --- a/src/frontends/onnx/frontend/src/core/node.cpp +++ b/src/frontends/onnx/frontend/src/core/node.cpp @@ -50,6 +50,8 @@ public: const std::string& description() const; const std::vector>& get_output_names() const; + const std::string& input(int index) const; + std::size_t get_inputs_size() const; const std::string& output(int index) const; std::size_t get_outputs_size() const; @@ -103,6 +105,14 @@ const std::vector>& Node::Impl::get_ou return m_output_names; } +const std::string& Node::Impl::input(int index) const { + return m_node_proto->input(index); +} + +std::size_t Node::Impl::get_inputs_size() const { + return m_node_proto->input_size(); +} + const std::string& Node::Impl::output(int index) const { return m_node_proto->output(index); } @@ -110,6 +120,7 @@ const std::string& Node::Impl::output(int index) const { std::size_t Node::Impl::get_outputs_size() const { return m_output_names.size(); } + bool Node::Impl::has_attribute(const std::string& name) const { auto it = std::find_if(std::begin(m_attributes), std::end(m_attributes), [&](const Attribute& attribute) { return attribute.get_name() == name; @@ -223,12 +234,22 @@ const std::vector>& Node::get_output_n return m_pimpl->get_output_names(); } +const std::string& Node::input(int index) const { + return m_pimpl->input(index); +} + +std::size_t Node::get_inputs_size() const { + return m_pimpl->get_inputs_size(); +} + const std::string& Node::output(int index) const { return m_pimpl->output(index); } + std::size_t Node::get_outputs_size() const { return m_pimpl->get_outputs_size(); } + bool Node::has_attribute(const std::string& name) const { return m_pimpl->has_attribute(name); }