diff --git a/ngraph/frontend/onnx_editor/include/onnx_editor/edge_mapper.hpp b/ngraph/frontend/onnx_editor/include/onnx_editor/edge_mapper.hpp new file mode 100644 index 00000000000..8b91f0c778e --- /dev/null +++ b/ngraph/frontend/onnx_editor/include/onnx_editor/edge_mapper.hpp @@ -0,0 +1,112 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include + +#include "onnx_editor/editor_types.hpp" + +namespace ONNX_NAMESPACE +{ + // Forward declaration to avoid the necessity of including paths in components + // that don't directly depend on the ONNX library + class GraphProto; +} // namespace ONNX_NAMESPACE + +namespace ngraph +{ + namespace onnx_editor + { + /// \brief A class which allows specifying InputEdge and OutputEdge by user-friendly ONNX + /// names. + class EdgeMapper + { + public: + EdgeMapper() = default; + + /// \brief Creates an edge mapper based on a GraphProto object. + /// + /// \note If state of graph_proto will be changed, the information from edge mapper + /// is outdated. In such a case the update method should be called. + /// + /// \param graph_proto Reference to a GraphProto object. + EdgeMapper(const ONNX_NAMESPACE::GraphProto& graph_proto); + + /// \brief Returns the InputEdge based on a node (node name or output name) + /// and an input (input name or input index). + /// + /// \note The node name can be ambiguous (many ONNX nodes can have the same name). + /// In such a case the algorthim tries to match the given node name + /// with the input name (providing an input index is not enough). + /// If a unique edge is found, it will be returned. + /// If InputEdge cannot be determined based on parameter values an ngraph_error + /// exception will be thrown. + /// + /// \param node An EditorNode helper structure created based on a node name + /// or a node output name. + /// + /// \param input An EditorInput helper structure created based on a input name + /// or a input index. + InputEdge find_input_edge(const EditorNode& node, const EditorInput& input) const; + + /// \brief Returns an OutputEdge based on a node (node name or output name) + /// and an output (output name or output index). + /// + /// \note The node name can be ambiguous (many ONNX nodes can have the same name). + /// In such a case the algorthim will try to match the given node name + /// with the output name (providing an output index is not enough). + /// If after such operation a found edge is unique, it is returned. + /// If OutputEdge cannot be determined based on given params the ngraph_error + /// exception is thrown. + /// + /// \param node An EditorNode helper structure created based on a node name + /// or a node output name. + /// + /// \param output An EditorOutput helper structure created based on a output name + /// or a output index. + OutputEdge find_output_edge(const EditorNode& node, const EditorOutput& output) const; + + /// \brief Returns an OutputEdge based on a output name. + /// + /// \note The output name guarantees the uniqueness of the edge. + /// + /// \param output_name A node output name. + /// + OutputEdge find_output_edge(const std::string& output_name) const; + + /// \brief Returns a vector of InputEdges which consume an output of a node + /// determined by provided output name. + /// + /// \note The output name is deterministic in the ONNX standard. + /// + /// \param output_name A node output name. + /// + std::vector find_output_consumers(const std::string& output_name) const; + + /// \brief Returns true if a provided node is correct (exists in a graph) + /// and is not ambiguous (identification of an ONNX node can be ambiguous + /// if an only tensor name is provided). + /// + /// \param node An EditorNode helper structure created based on a node name + /// or a node output name. + /// + bool is_correct_and_unambiguous_node(const EditorNode& node) const; + + private: + std::vector find_node_indexes(const std::string& node_name, + const std::string& output_name) const; + std::string get_node_input_name(int node_index, int input_index) const; + std::string get_node_output_name(int node_index, int output_index) const; + + std::vector> m_node_inputs; + std::vector> m_node_outputs; + std::multimap m_node_name_to_index; + std::map m_node_output_name_to_index; + std::multimap m_output_consumers_index; + }; + } // namespace onnx_editor +} // namespace ngraph diff --git a/ngraph/frontend/onnx_editor/include/onnx_editor/editor.hpp b/ngraph/frontend/onnx_editor/include/onnx_editor/editor.hpp index 465890cab11..b93b568141e 100644 --- a/ngraph/frontend/onnx_editor/include/onnx_editor/editor.hpp +++ b/ngraph/frontend/onnx_editor/include/onnx_editor/editor.hpp @@ -108,7 +108,69 @@ namespace ngraph /// \param out_file_path A path to the file where the modified model should be dumped. void serialize(const std::string& out_file_path) const; + /// \brief Returns the InputEdge based on a node (node name or output name) + /// and an input (input name or input index). + /// + /// \note The node name can be ambiguous (many ONNX nodes can have the same name). + /// In such a case the algorthim tries to match the given node name + /// with the input name (providing an input index is not enough). + /// If a unique edge is found, it will be returned. + /// If InputEdge cannot be determined based on parameter values an ngraph_error + /// exception will be thrown. + /// + /// \param node A node helper structure created based on a node name + /// or a node output name. + /// + /// \param input An input helper structure created based on a input name + /// or a input index. + InputEdge find_input_edge(const EditorNode& node, const EditorInput& input) const; + + /// \brief Returns an OutputEdge based on a node (node name or output name) + /// and an output (output name or output index). + /// + /// \note The node name can be ambiguous (many ONNX nodes can have the same name). + /// In such a case the algorthim will try to match the given node name + /// with the output name (providing an output index is not enough). + /// If after such operation a found edge is unique, it is returned. + /// If OutputEdge cannot be determined based on given params the ngraph_error + /// exception is thrown. + /// + /// \param node A node helper structure created based on a node name + /// or a node output name. + /// + /// \param output A output helper structure created based on a output name + /// or a output index. + OutputEdge find_output_edge(const EditorNode& node, const EditorOutput& output) const; + + /// \brief Returns an OutputEdge based on a output name. + /// + /// \note The output name guarantees the uniqueness of the edge. + /// + /// \param output_name A node output name. + /// + OutputEdge find_output_edge(const std::string& output_name) const; + + /// \brief Returns a vector of InputEdges which consume an output of a node + /// determined by provided output name. + /// + /// \note The output name is deterministic in the ONNX standard. + /// + /// \param output_name A node output name. + /// + std::vector find_output_consumers(const std::string& output_name) const; + + /// \brief Returns a vector of InputEdges which consume an output of a node + /// determined by provided output name. + /// + /// \note The output name is deterministic in the ONNX standard. + /// + /// \param output_name A node output name. + /// + bool is_correct_and_unambiguous_node(const EditorNode& node) const; + private: + void update_mapper_if_needed() const; + const std::string m_model_path; struct Impl; diff --git a/ngraph/frontend/onnx_editor/include/onnx_editor/editor_types.hpp b/ngraph/frontend/onnx_editor/include/onnx_editor/editor_types.hpp index 56afa34af32..6f941d90869 100644 --- a/ngraph/frontend/onnx_editor/include/onnx_editor/editor_types.hpp +++ b/ngraph/frontend/onnx_editor/include/onnx_editor/editor_types.hpp @@ -60,5 +60,74 @@ namespace ngraph /// OutputEdge(5, "out1") /// OutputEdge(5, "out2") using OutputEdge = Edge; + + /// \brief Specifies a single node input by the name or index. + /// + /// For a node test_node, with 3 inputs: + /// + /// ----(in_A)----> +-----------+ + /// ----(in_B)----> | test_node | ----(out)----> + /// ----(in_C)----> +-----------+ + /// You can indicate in_B as EditorInput("in_B") or EditorInput(1) + struct EditorInput + { + EditorInput() = delete; + EditorInput(std::string input_name) + : m_input_name{std::move(input_name)} + { + } + EditorInput(const int input_index) + : m_input_index{input_index} + { + } + const std::string m_input_name = ""; + const int m_input_index = -1; + }; + + /// \brief Specifies a single node output by the name or index. + /// For a node test_node, with 2 outputs: + /// + /// +-----------+ ---(out1)---> + /// ----(in_A)----> | test_node | + /// +-----------+ ---(out2)---> + /// You can indicate out2 as EditorOutput("out2") or EditorOutput(1) + struct EditorOutput + { + EditorOutput() = delete; + EditorOutput(std::string output_name) + : m_output_name{std::move(output_name)} + { + } + EditorOutput(const int output_index) + : m_output_index{output_index} + { + } + const std::string m_output_name = ""; + const int m_output_index = -1; + }; + + /// \brief Specifies a single node by output name which is determinitic + /// or node name which can be ambiguous. + /// For a node test_node, with 2 outputs: + /// + /// +-----------+ ---(out1)---> + /// ----(in_A)----> | test_node | + /// +-----------+ ---(out2)---> + /// You can indicate test_node by name as EditorNode("test_node") + /// or by assigned output as EditorNode(EditorOutput("out1")) + /// or EditorNode(EditorOutput("out2")) + struct EditorNode + { + EditorNode(std::string node_name) + : m_node_name{std::move(node_name)} + { + } + EditorNode(EditorOutput output) + : m_output_name{std::move(output.m_output_name)} + { + } + const std::string m_node_name = ""; + const std::string m_output_name = ""; + }; } // namespace onnx_editor } // namespace ngraph diff --git a/ngraph/frontend/onnx_editor/src/edge_mapper.cpp b/ngraph/frontend/onnx_editor/src/edge_mapper.cpp new file mode 100644 index 00000000000..eeb244b9842 --- /dev/null +++ b/ngraph/frontend/onnx_editor/src/edge_mapper.cpp @@ -0,0 +1,243 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include + +#include "ngraph/except.hpp" +#include "onnx_editor/edge_mapper.hpp" + +using namespace ngraph; +using namespace ngraph::onnx_editor; + +onnx_editor::EdgeMapper::EdgeMapper(const ONNX_NAMESPACE::GraphProto& graph_proto) + : m_node_inputs(graph_proto.node().size()) + , m_node_outputs(graph_proto.node().size()) +{ + int topological_index = 0; + for (const auto& node_proto : graph_proto.node()) + { + for (const auto& out_name : node_proto.output()) + { + // node output name is unique + m_node_output_name_to_index.emplace(out_name, topological_index); + m_node_outputs[topological_index].push_back(out_name); + } + for (const auto& in_name : node_proto.input()) + { + m_node_inputs[topological_index].push_back(in_name); + m_output_consumers_index.emplace(in_name, topological_index); + } + if (!node_proto.name().empty()) + { + // node name can identify node, but it can be ambiguous + m_node_name_to_index.emplace(node_proto.name(), topological_index); + } + ++topological_index; + } +} + +std::vector onnx_editor::EdgeMapper::find_node_indexes(const std::string& node_name, + const std::string& output_name) const +{ + if (!output_name.empty()) + { + const auto& index_iter = m_node_output_name_to_index.find(output_name); + if (index_iter != std::end(m_node_output_name_to_index)) + { + return std::vector{index_iter->second}; + } + } + std::vector result; + if (!node_name.empty()) + { + const auto matched_nodes_range = m_node_name_to_index.equal_range(node_name); + std::transform(matched_nodes_range.first, + matched_nodes_range.second, + std::back_inserter(result), + [](const std::pair& iter) { return iter.second; }); + } + return result; +}; + +std::string onnx_editor::EdgeMapper::get_node_output_name(int node_index, int output_index) const +{ + if (node_index >= static_cast(m_node_outputs.size())) + { + throw ngraph_error("Node with index: " + std::to_string(node_index) + + "is out of scope outputs list"); + } + if (output_index >= static_cast(m_node_outputs[node_index].size())) + { + throw ngraph_error("Node with index: " + std::to_string(node_index) + + " has not output with index: " + std::to_string(output_index)); + } + const auto output_name = m_node_outputs[node_index][output_index]; + return output_name; +} + +std::string onnx_editor::EdgeMapper::get_node_input_name(int node_index, int input_index) const +{ + if (node_index >= static_cast(m_node_inputs.size())) + { + throw ngraph_error("Node with index: " + std::to_string(node_index) + + "is out of scope inputs list"); + } + if (input_index >= static_cast(m_node_inputs[node_index].size())) + { + throw ngraph_error("Node with index: " + std::to_string(node_index) + + " has not input with index: " + std::to_string(input_index)); + } + const auto input_name = m_node_inputs[node_index][input_index]; + return input_name; +} + +InputEdge onnx_editor::EdgeMapper::find_input_edge(const EditorNode& node, + const EditorInput& in) const +{ + // identification can be both based on node name and output name + const auto& node_indexes = find_node_indexes(node.m_node_name, node.m_output_name); + int node_index = -1; + if (node_indexes.size() == 1) + { + node_index = node_indexes[0]; + } + else if (node_indexes.empty()) + { + throw ngraph_error( + "Node with name: " + (node.m_node_name.empty() ? "not_given" : node.m_node_name) + + " and output_name: " + (node.m_output_name.empty() ? "not_given" : node.m_output_name) + + " was not found"); + } + else if (!in.m_input_name + .empty()) // input indexes are not deterministic if a node name is ambiguous + { + // many nodes with the same name + // check if some of found index matches input name + int matched_inputs_number = 0; + for (const auto& index : node_indexes) + { + if (std::count(std::begin(m_node_inputs[index]), + std::end(m_node_inputs[index]), + in.m_input_name) > 0) + { + node_index = index; + ++matched_inputs_number; + } + } + if (matched_inputs_number == 0) + { + throw ngraph_error("Input edge described by: " + node.m_node_name + + " and input name: " + in.m_input_name + " was not found"); + } + if (matched_inputs_number > 1) + { + throw ngraph_error("Given node name: " + node.m_node_name + " and input name: " + + in.m_input_name + " are ambiguous to determine input edge"); + } + } + else + { + throw ngraph_error("Given node name: " + node.m_node_name + + " and input index: " + std::to_string(in.m_input_index) + + " are ambiguous to determine input edge"); + } + if (!in.m_input_name.empty()) + { + return InputEdge{node_index, in.m_input_name}; + } + else if (in.m_input_index != -1) // input index is set + { + const auto& input_name = get_node_input_name(node_index, in.m_input_index); + return InputEdge{node_index, input_name}; + } + else + { + throw ngraph_error("Not enough information to determine input edge"); + } +} + +OutputEdge onnx_editor::EdgeMapper::find_output_edge(const EditorNode& node, + const EditorOutput& out) const +{ + // identification can be both based on node name and output name + const auto& node_indexes = find_node_indexes(node.m_node_name, node.m_output_name); + int node_index = -1; + if (node_indexes.size() == 1) + { + node_index = node_indexes[0]; + } + else if (node_indexes.empty()) + { + throw ngraph_error( + "Node with name: " + (node.m_node_name.empty() ? "not_given" : node.m_node_name) + + " and output_name: " + (node.m_output_name.empty() ? "not_given" : node.m_output_name) + + " was not found"); + } + else if (!out.m_output_name + .empty()) // output indexes are not deterministic if a node name is ambiguous + { + // many nodes with the same name + // check if some of found index matches output name + int matched_outputs_number = 0; + for (const auto& index : node_indexes) + { + if (std::count(std::begin(m_node_outputs[index]), + std::end(m_node_outputs[index]), + out.m_output_name) > 0) + { + node_index = index; + ++matched_outputs_number; + } + } + if (matched_outputs_number == 0) + { + throw ngraph_error("Output edge described by: " + node.m_node_name + + " and output name: " + out.m_output_name + " was not found"); + } + } + else + { + throw ngraph_error("Given node name: " + node.m_node_name + + " and output index: " + std::to_string(out.m_output_index) + + " are ambiguous to determine output edge"); + } + if (!out.m_output_name.empty()) + { + return OutputEdge{node_index, out.m_output_name}; + } + else if (out.m_output_index != -1) // output index is set + { + const auto& output_name = get_node_output_name(node_index, out.m_output_index); + return OutputEdge{node_index, output_name}; + } + else + { + throw ngraph_error("Not enough information to determine output edge"); + } +} + +OutputEdge onnx_editor::EdgeMapper::find_output_edge(const std::string& output_name) const +{ + return find_output_edge(EditorNode{EditorOutput{output_name}}, EditorOutput{output_name}); +} + +std::vector + onnx_editor::EdgeMapper::find_output_consumers(const std::string& output_name) const +{ + const auto matched_nodes_range = m_output_consumers_index.equal_range(output_name); + std::vector input_edges; + std::transform(matched_nodes_range.first, + matched_nodes_range.second, + std::back_inserter(input_edges), + [&output_name](const std::pair& iter) { + return InputEdge{iter.second, output_name}; + }); + return input_edges; +} + +bool onnx_editor::EdgeMapper::is_correct_and_unambiguous_node(const EditorNode& node) const +{ + return find_node_indexes(node.m_node_name, node.m_output_name).size() == 1; +} diff --git a/ngraph/frontend/onnx_editor/src/editor.cpp b/ngraph/frontend/onnx_editor/src/editor.cpp index 566659e6633..778b8261cee 100644 --- a/ngraph/frontend/onnx_editor/src/editor.cpp +++ b/ngraph/frontend/onnx_editor/src/editor.cpp @@ -10,10 +10,12 @@ #include "ngraph/log.hpp" #include "onnx_common/parser.hpp" #include "onnx_common/utils.hpp" +#include "onnx_editor/edge_mapper.hpp" #include "onnx_editor/editor.hpp" #include "onnx_import/utils/onnx_internal.hpp" using namespace ngraph; +using namespace ngraph::onnx_editor; namespace { @@ -186,6 +188,8 @@ namespace struct onnx_editor::ONNXModelEditor::Impl { ONNX_NAMESPACE::ModelProto m_model_proto; + EdgeMapper m_edge_mapper; + bool m_is_mapper_updated = false; Impl() = delete; @@ -285,6 +289,7 @@ void onnx_editor::ONNXModelEditor::cut_graph_fragment(const std::vectorremove_shape_inference_info(); + m_pimpl->m_is_mapper_updated = false; } std::vector onnx_editor::ONNXModelEditor::model_inputs() const @@ -344,3 +349,45 @@ void onnx_editor::ONNXModelEditor::set_input_values( modify_initializer(*onnx_initializer, name, values, onnx_input); } } + +void onnx_editor::ONNXModelEditor::update_mapper_if_needed() const +{ + if (!m_pimpl->m_is_mapper_updated) + { + m_pimpl->m_edge_mapper = EdgeMapper(m_pimpl->m_model_proto.graph()); + } + m_pimpl->m_is_mapper_updated = true; +} + +InputEdge onnx_editor::ONNXModelEditor::find_input_edge(const EditorNode& node, + const EditorInput& input) const +{ + update_mapper_if_needed(); + return m_pimpl->m_edge_mapper.find_input_edge(node, input); +} + +OutputEdge onnx_editor::ONNXModelEditor::find_output_edge(const EditorNode& node, + const EditorOutput& input) const +{ + update_mapper_if_needed(); + return m_pimpl->m_edge_mapper.find_output_edge(node, input); +} + +OutputEdge onnx_editor::ONNXModelEditor::find_output_edge(const std::string& output_name) const +{ + update_mapper_if_needed(); + return m_pimpl->m_edge_mapper.find_output_edge(output_name); +} + +std::vector + onnx_editor::ONNXModelEditor::find_output_consumers(const std::string& output_name) const +{ + update_mapper_if_needed(); + return m_pimpl->m_edge_mapper.find_output_consumers(output_name); +} + +bool onnx_editor::ONNXModelEditor::is_correct_and_unambiguous_node(const EditorNode& node) const +{ + update_mapper_if_needed(); + return m_pimpl->m_edge_mapper.is_correct_and_unambiguous_node(node); +} diff --git a/ngraph/test/models/onnx/model_editor/subgraph__inception_head.prototxt b/ngraph/test/models/onnx/model_editor/subgraph__inception_head.prototxt index f41a21428ab..1b024156b0e 100644 --- a/ngraph/test/models/onnx/model_editor/subgraph__inception_head.prototxt +++ b/ngraph/test/models/onnx/model_editor/subgraph__inception_head.prototxt @@ -8,7 +8,7 @@ graph { input: "conv1/7x7_s2_w_0" input: "conv1/7x7_s2_b_0" output: "conv1/7x7_s2_1" - name: "" + name: "conv1" op_type: "Conv" attribute { name: "strides" @@ -34,13 +34,13 @@ graph { node { input: "conv1/7x7_s2_1" output: "conv1/7x7_s2_2" - name: "" + name: "relu1" op_type: "Relu" } node { input: "conv1/7x7_s2_2" output: "pool1/3x3_s2_1" - name: "" + name: "maxpool1" op_type: "MaxPool" attribute { name: "strides" diff --git a/ngraph/test/models/onnx/model_editor/subgraph_extraction_tests.prototxt b/ngraph/test/models/onnx/model_editor/subgraph_extraction_tests.prototxt index 9e79acc9cf9..fe8972df10c 100644 --- a/ngraph/test/models/onnx/model_editor/subgraph_extraction_tests.prototxt +++ b/ngraph/test/models/onnx/model_editor/subgraph_extraction_tests.prototxt @@ -5,12 +5,14 @@ graph { input: "in1" output: "relu1" op_type: "Relu" + name: "relu1_name" } node { input: "relu1" input: "in2" output: "add1" op_type: "Add" + name: "add_ambiguous_name" } node { input: "in3" @@ -23,12 +25,14 @@ graph { input: "add1" output: "add2" op_type: "Add" + name: "add_ambiguous_name" } node { input: "add1" input: "conv1" output: "mul2" op_type: "Mul" + name: "" } node { input: "add2" @@ -40,6 +44,7 @@ graph { i: 1 type: INT } + name: "split_name" } node { input: "relu1" diff --git a/ngraph/test/models/onnx/model_editor/subgraph_extraction_tests_2.prototxt b/ngraph/test/models/onnx/model_editor/subgraph_extraction_tests_2.prototxt index 76d8587d9e8..a5fbee597db 100644 --- a/ngraph/test/models/onnx/model_editor/subgraph_extraction_tests_2.prototxt +++ b/ngraph/test/models/onnx/model_editor/subgraph_extraction_tests_2.prototxt @@ -5,6 +5,7 @@ graph { input: "in1" output: "relu1" op_type: "Relu" + name: "add1" } node { input: "in1" @@ -20,12 +21,14 @@ graph { input: "in2" output: "relu4" op_type: "Relu" + name: "relu4_name" } node { input: "relu1" input: "relu2" output: "add1" op_type: "Add" + name: "add1_name" } node { input: "relu2" diff --git a/ngraph/test/onnx/onnx_editor.cpp b/ngraph/test/onnx/onnx_editor.cpp index 84aa9333b3c..7dbe406166f 100644 --- a/ngraph/test/onnx/onnx_editor.cpp +++ b/ngraph/test/onnx/onnx_editor.cpp @@ -21,8 +21,7 @@ NGRAPH_SUPPRESS_DEPRECATED_START using namespace ngraph; -using namespace ngraph::onnx_import; -using namespace ngraph::onnx_editor; +using namespace onnx_editor; using namespace ngraph::test; static std::string s_manifest = "${MANIFEST}"; @@ -658,6 +657,478 @@ NGRAPH_TEST(onnx_editor, subgraph__inputs_getter) EXPECT_EQ(editor.model_inputs(), (std::vector{"conv1/7x7_s2_1"})); } +// HIGHT LEVEL API TESTS +// INPUT EDGES TEST +NGRAPH_TEST(onnx_editor, editor_api_select_input_edge_by_output_name_and_input_name) +{ + ONNXModelEditor editor{file_util::path_join( + SERIALIZED_ZOO, "onnx/model_editor/subgraph__inception_head.prototxt")}; + + const InputEdge edge = editor.find_input_edge(EditorNode{EditorOutput{"conv1/7x7_s2_2"}}, + EditorInput{"conv1/7x7_s2_1"}); + EXPECT_EQ(edge.m_node_idx, 1); + EXPECT_EQ(edge.m_tensor_name, "conv1/7x7_s2_1"); + + const InputEdge edge2 = editor.find_input_edge(EditorNode{EditorOutput{"conv1/7x7_s2_1"}}, + EditorInput{"data_0"}); + EXPECT_EQ(edge2.m_node_idx, 0); + EXPECT_EQ(edge2.m_tensor_name, "data_0"); +} + +NGRAPH_TEST(onnx_editor, editor_api_select_input_edge_by_output_name_and_input_index) +{ + ONNXModelEditor editor{file_util::path_join( + SERIALIZED_ZOO, "onnx/model_editor/subgraph__inception_head.prototxt")}; + + const InputEdge edge = + editor.find_input_edge(EditorNode{EditorOutput{"conv1/7x7_s2_2"}}, EditorInput{0}); + EXPECT_EQ(edge.m_node_idx, 1); + EXPECT_EQ(edge.m_tensor_name, "conv1/7x7_s2_1"); + + const InputEdge edge2 = + editor.find_input_edge(EditorNode{EditorOutput{"conv1/7x7_s2_1"}}, EditorInput{1}); + EXPECT_EQ(edge2.m_node_idx, 0); + EXPECT_EQ(edge2.m_tensor_name, "conv1/7x7_s2_w_0"); + + const InputEdge edge3 = + editor.find_input_edge(EditorNode{EditorOutput{"conv1/7x7_s2_1"}}, EditorInput{2}); + EXPECT_EQ(edge3.m_node_idx, 0); + EXPECT_EQ(edge3.m_tensor_name, "conv1/7x7_s2_b_0"); +} + +NGRAPH_TEST(onnx_editor, editor_api_select_input_edge_by_node_name_and_input_name) +{ + ONNXModelEditor editor{file_util::path_join( + SERIALIZED_ZOO, "onnx/model_editor/subgraph__inception_head.prototxt")}; + + const InputEdge edge = + editor.find_input_edge(EditorNode{"relu1"}, EditorInput{"conv1/7x7_s2_1"}); + EXPECT_EQ(edge.m_node_idx, 1); + EXPECT_EQ(edge.m_tensor_name, "conv1/7x7_s2_1"); + + const InputEdge edge2 = + editor.find_input_edge(EditorNode{"conv1"}, EditorInput{"conv1/7x7_s2_w_0"}); + EXPECT_EQ(edge2.m_node_idx, 0); + EXPECT_EQ(edge2.m_tensor_name, "conv1/7x7_s2_w_0"); +} + +NGRAPH_TEST(onnx_editor, editor_api_select_input_edge_by_node_name_and_input_index) +{ + ONNXModelEditor editor{file_util::path_join( + SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")}; + + const InputEdge edge = editor.find_input_edge(EditorNode{"relu1_name"}, EditorInput{0}); + EXPECT_EQ(edge.m_node_idx, 0); + EXPECT_EQ(edge.m_tensor_name, "in1"); + + const InputEdge edge2 = editor.find_input_edge(EditorNode{"split_name"}, EditorInput{0}); + EXPECT_EQ(edge2.m_node_idx, 5); + EXPECT_EQ(edge2.m_tensor_name, "add2"); +} + +NGRAPH_TEST(onnx_editor, editor_api_select_input_edge_empty_node_name) +{ + ONNXModelEditor editor{file_util::path_join( + SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")}; + + try + { + const InputEdge edge = + editor.find_input_edge(EditorNode{""}, EditorInput{"conv1/7x7_s2_1"}); + } + catch (const std::exception& e) + { + std::string msg{e.what()}; + EXPECT_TRUE( + msg.find("Node with name: not_given and output_name: not_given was not found") != + std::string::npos); + } +} + +// OUTPUT EDGES TEST +NGRAPH_TEST(onnx_editor, editor_api_select_output_edge_by_output_name) +{ + ONNXModelEditor editor{file_util::path_join( + SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")}; + + const OutputEdge edge = + editor.find_output_edge(EditorNode{EditorOutput{"mul2"}}, EditorOutput{"mul2"}); + EXPECT_EQ(edge.m_node_idx, 4); + EXPECT_EQ(edge.m_tensor_name, "mul2"); + + const OutputEdge edge2 = + editor.find_output_edge(EditorNode{EditorOutput{"split1"}}, EditorOutput{"split2"}); + EXPECT_EQ(edge2.m_node_idx, 5); + EXPECT_EQ(edge2.m_tensor_name, "split2"); + + // simplified overload + const OutputEdge edge3 = + editor.find_output_edge("mul2"); + EXPECT_EQ(edge3.m_node_idx, 4); + EXPECT_EQ(edge3.m_tensor_name, "mul2"); + + const OutputEdge edge4 = + editor.find_output_edge("split2"); + EXPECT_EQ(edge4.m_node_idx, 5); + EXPECT_EQ(edge4.m_tensor_name, "split2"); +} + +NGRAPH_TEST(onnx_editor, editor_api_select_output_edge_by_output_name_and_output_index) +{ + ONNXModelEditor editor{file_util::path_join( + SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")}; + + const OutputEdge edge = + editor.find_output_edge(EditorNode{EditorOutput{"add2"}}, EditorOutput{0}); + EXPECT_EQ(edge.m_node_idx, 3); + EXPECT_EQ(edge.m_tensor_name, "add2"); + + const OutputEdge edge2 = + editor.find_output_edge(EditorNode{EditorOutput{"split1"}}, EditorOutput{1}); + EXPECT_EQ(edge2.m_node_idx, 5); + EXPECT_EQ(edge2.m_tensor_name, "split2"); + + const OutputEdge edge3 = + editor.find_output_edge(EditorNode{EditorOutput{"split2"}}, EditorOutput{0}); + EXPECT_EQ(edge3.m_node_idx, 5); + EXPECT_EQ(edge3.m_tensor_name, "split1"); +} + +NGRAPH_TEST(onnx_editor, editor_api_select_output_edge_by_node_name_and_output_name) +{ + ONNXModelEditor editor{file_util::path_join( + SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")}; + + const OutputEdge edge = + editor.find_output_edge(EditorNode{"relu1_name"}, EditorOutput{"relu1"}); + EXPECT_EQ(edge.m_node_idx, 0); + EXPECT_EQ(edge.m_tensor_name, "relu1"); + + const OutputEdge edge2 = + editor.find_output_edge(EditorNode{"split_name"}, EditorOutput{"split2"}); + EXPECT_EQ(edge2.m_node_idx, 5); + EXPECT_EQ(edge2.m_tensor_name, "split2"); +} + +NGRAPH_TEST(onnx_editor, editor_api_select_output_edge_by_node_name_and_output_index) +{ + ONNXModelEditor editor{file_util::path_join( + SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")}; + + const OutputEdge edge = editor.find_output_edge(EditorNode{"relu1_name"}, EditorOutput{0}); + EXPECT_EQ(edge.m_node_idx, 0); + EXPECT_EQ(edge.m_tensor_name, "relu1"); + + const OutputEdge edge2 = editor.find_output_edge(EditorNode{"split_name"}, EditorOutput{1}); + EXPECT_EQ(edge2.m_node_idx, 5); + EXPECT_EQ(edge2.m_tensor_name, "split2"); +} + +NGRAPH_TEST(onnx_editor, editor_api_select_edge_const_network) +{ + ONNXModelEditor editor{file_util::path_join( + SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests_2.prototxt")}; + + const InputEdge edge = + editor.find_input_edge(EditorNode{EditorOutput{"relu4"}}, EditorInput{0}); + EXPECT_EQ(edge.m_node_idx, 3); + EXPECT_EQ(edge.m_tensor_name, "in2"); + + const OutputEdge edge2 = editor.find_output_edge(EditorNode{"relu4_name"}, EditorOutput{0}); + EXPECT_EQ(edge2.m_node_idx, 3); + EXPECT_EQ(edge2.m_tensor_name, "relu4"); + + const OutputEdge edge3 = editor.find_output_edge(EditorNode{"add1_name"}, EditorOutput{0}); + EXPECT_EQ(edge3.m_node_idx, 4); + EXPECT_EQ(edge3.m_tensor_name, "add1"); +} + +NGRAPH_TEST(onnx_editor, editor_api_select_edge_error_handling) +{ + ONNXModelEditor editor{file_util::path_join( + SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests_2.prototxt")}; + + // node with given output name not found + try + { + const InputEdge edge = + editor.find_input_edge(EditorNode{EditorOutput{"not_existed"}}, EditorInput{0}); + } + catch (const std::exception& e) + { + std::string msg{e.what()}; + EXPECT_TRUE( + msg.find("Node with name: not_given and output_name: not_existed was not found") != + std::string::npos); + } + + // node with given name not found + try + { + const InputEdge edge = editor.find_input_edge(EditorNode{"not_existed"}, EditorInput{0}); + } + catch (const std::exception& e) + { + std::string msg{e.what()}; + EXPECT_TRUE( + msg.find("Node with name: not_existed and output_name: not_given was not found") != + std::string::npos); + } + + // input index out of scope + try + { + const InputEdge edge = editor.find_input_edge(EditorNode{"relu4_name"}, EditorInput{1}); + } + catch (const std::exception& e) + { + std::string msg{e.what()}; + EXPECT_TRUE(msg.find("Node with index: 3 has not input with index: 1") != + std::string::npos); + } + + // output index out of scope + try + { + const OutputEdge edge = + editor.find_output_edge(EditorNode{"relu4_name"}, EditorOutput{1}); + } + catch (const std::exception& e) + { + std::string msg{e.what()}; + EXPECT_TRUE(msg.find("Node with index: 3 has not output with index: 1") != + std::string::npos); + } +} + +// Nodes with ambiguous node names tests +NGRAPH_TEST(onnx_editor, editor_api_select_input_edge_by_ambiguous_node_name_but_matched_input) +{ + ONNXModelEditor editor{file_util::path_join( + SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")}; + + InputEdge edge = editor.find_input_edge(EditorNode{"add_ambiguous_name"}, EditorInput{"in2"}); + EXPECT_EQ(edge.m_node_idx, 1); + EXPECT_EQ(edge.m_tensor_name, "in2"); + + const InputEdge edge2 = editor.find_input_edge(EditorNode{"add_ambiguous_name"}, EditorInput{"add1"}); + EXPECT_EQ(edge2.m_node_idx, 3); + EXPECT_EQ(edge2.m_tensor_name, "add1"); +} + +NGRAPH_TEST(onnx_editor, editor_api_select_input_edge_by_ambiguous_node_name_and_not_matched_input) +{ + ONNXModelEditor editor{file_util::path_join( + SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")}; + + try + { + const InputEdge edge = editor.find_input_edge(EditorNode{"add_ambiguous_name"}, EditorInput{"in3"}); + } + catch (const std::exception& e) + { + std::string msg{e.what()}; + EXPECT_TRUE(msg.find("Input edge described by: add_ambiguous_name and input name: in3 was not found") != + std::string::npos); + } + + try + { + const InputEdge edge = editor.find_input_edge(EditorNode{"add_ambiguous_name"}, EditorInput{"relu1"}); + } + catch (const std::exception& e) + { + std::string msg{e.what()}; + EXPECT_TRUE(msg.find("Given node name: add_ambiguous_name and input name: relu1 are ambiguous to determine input edge") != + std::string::npos); + } +} + +NGRAPH_TEST(onnx_editor, editor_api_select_input_edge_by_ambiguous_node_name_and_input_index) +{ + ONNXModelEditor editor{file_util::path_join( + SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")}; + + try + { + const InputEdge edge = editor.find_input_edge(EditorNode{"add_ambiguous_name"}, EditorInput{0}); + } + catch (const std::exception& e) + { + std::string msg{e.what()}; + EXPECT_TRUE(msg.find("Given node name: add_ambiguous_name and input index: 0 are ambiguous to determine input edge") != + std::string::npos); + } +} + +NGRAPH_TEST(onnx_editor, editor_api_select_output_edge_by_ambiguous_node_name_but_matched_output) +{ + ONNXModelEditor editor{file_util::path_join( + SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")}; + + const OutputEdge edge = editor.find_output_edge(EditorNode{"add_ambiguous_name"}, EditorOutput{"add1"}); + EXPECT_EQ(edge.m_node_idx, 1); + EXPECT_EQ(edge.m_tensor_name, "add1"); + + const OutputEdge edge2 = editor.find_output_edge(EditorNode{"add_ambiguous_name"}, EditorOutput{"add2"}); + EXPECT_EQ(edge2.m_node_idx, 3); + EXPECT_EQ(edge2.m_tensor_name, "add2"); +} + +NGRAPH_TEST(onnx_editor, editor_api_select_output_edge_by_the_same_node_name_and_output_name) +{ + ONNXModelEditor editor{file_util::path_join( + SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests_2.prototxt")}; + + const OutputEdge edge = editor.find_output_edge(EditorNode{"add1"}, EditorOutput{0}); + EXPECT_EQ(edge.m_node_idx, 0); + EXPECT_EQ(edge.m_tensor_name, "relu1"); + + const OutputEdge edge2 = editor.find_output_edge(EditorNode{EditorOutput{"add1"}}, EditorOutput{0}); + EXPECT_EQ(edge2.m_node_idx, 4); + EXPECT_EQ(edge2.m_tensor_name, "add1"); +} + +NGRAPH_TEST(onnx_editor, editor_api_select_output_edge_by_ambiguous_node_name_and_not_matched_output) +{ + ONNXModelEditor editor{file_util::path_join( + SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")}; + + try + { + const OutputEdge edge = editor.find_output_edge(EditorNode{"add_ambiguous_name"}, EditorOutput{"split2"}); + } + catch (const std::exception& e) + { + std::string msg{e.what()}; + EXPECT_TRUE(msg.find("Output edge described by: add_ambiguous_name and output name: split2 was not found") != + std::string::npos); + } +} + +NGRAPH_TEST(onnx_editor, editor_api_select_output_edge_by_ambiguous_node_name_and_output_index) +{ + ONNXModelEditor editor{file_util::path_join( + SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")}; + + try + { + const OutputEdge edge = editor.find_output_edge(EditorNode{"add_ambiguous_name"}, EditorOutput{0}); + } + catch (const std::exception& e) + { + std::string msg{e.what()}; + EXPECT_TRUE(msg.find("Given node name: add_ambiguous_name and output index: 0 are ambiguous to determine output edge") != + std::string::npos); + } +} + +NGRAPH_TEST(onnx_editor, editor_api_use_edge_mapper_with_graph_cutter) +{ + ONNXModelEditor editor{file_util::path_join( + SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")}; + + // InputEdge{1, "in2"} + const auto input_edge_1 = editor.find_input_edge( + EditorNode(EditorOutput("add1")), EditorInput(1)); + // InputEdge{2, "in3"} + const auto input_edge_2 = editor.find_input_edge( + EditorNode(EditorOutput("conv1")), EditorInput(0)); + + + const auto output_edge = editor.find_output_edge( + EditorNode(EditorOutput("mul2")), EditorOutput(0)); + // OutputEdge{4, "mul2"} + editor.cut_graph_fragment({input_edge_1, input_edge_2}, {output_edge}); + + const auto ref_model = + file_util::path_join(SERIALIZED_ZOO, + "onnx/model_editor/reference/" + "subgraph__existing_inputs_and_outputs_based_extraction.prototxt"); + + const auto result = compare_onnx_models(editor.model_string(), ref_model); + + EXPECT_TRUE(result.is_ok) << result.error_message; + + // check if mapper was updated after the model changed + const auto input_edge_4 = editor.find_input_edge( + EditorNode(EditorOutput("relu1")), EditorInput(0)); + EXPECT_EQ(input_edge_4.m_node_idx, 0); + EXPECT_EQ(input_edge_4.m_tensor_name, "in1"); + + const auto input_edge_5 = editor.find_input_edge( + EditorNode(EditorOutput("add1")), EditorInput(1)); + EXPECT_EQ(input_edge_5.m_node_idx, 1); + EXPECT_EQ(input_edge_5.m_tensor_name, "in2"); + + const auto output_edge_3 = editor.find_output_edge("mul2"); + EXPECT_EQ(output_edge_3.m_node_idx, 3); + EXPECT_EQ(output_edge_3.m_tensor_name, "mul2"); +} + +NGRAPH_TEST(onnx_editor, editor_api_find_output_consumers) +{ + ONNXModelEditor editor{file_util::path_join( + SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")}; + + std::vector output_consumers = editor.find_output_consumers("relu1"); + EXPECT_EQ(output_consumers.size(), 3); + EXPECT_EQ(output_consumers[0].m_node_idx, 1); + EXPECT_EQ(output_consumers[0].m_tensor_name, "relu1"); + EXPECT_EQ(output_consumers[1].m_node_idx, 3); + EXPECT_EQ(output_consumers[1].m_tensor_name, "relu1"); + EXPECT_EQ(output_consumers[2].m_node_idx, 6); + EXPECT_EQ(output_consumers[2].m_tensor_name, "relu1"); + + output_consumers = editor.find_output_consumers("add1"); + EXPECT_EQ(output_consumers.size(), 2); + EXPECT_EQ(output_consumers[0].m_node_idx, 3); + EXPECT_EQ(output_consumers[0].m_tensor_name, "add1"); + EXPECT_EQ(output_consumers[1].m_node_idx, 4); + EXPECT_EQ(output_consumers[1].m_tensor_name, "add1"); + + output_consumers = editor.find_output_consumers("in3"); + EXPECT_EQ(output_consumers.size(), 1); + EXPECT_EQ(output_consumers[0].m_node_idx, 2); + EXPECT_EQ(output_consumers[0].m_tensor_name, "in3"); +} + +NGRAPH_TEST(onnx_editor, editor_api_find_output_consumers_empty_result) +{ + ONNXModelEditor editor{file_util::path_join( + SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")}; + + const std::vector output_consumers = editor.find_output_consumers("not_existed"); + EXPECT_EQ(output_consumers.size(), 0); +} + +NGRAPH_TEST(onnx_editor, editor_api_is_correct_and_unambiguous_node) +{ + ONNXModelEditor editor{file_util::path_join( + SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.prototxt")}; + + bool is_correct_node = editor.is_correct_and_unambiguous_node(EditorNode{EditorOutput{"relu1"}}); + EXPECT_EQ(is_correct_node, true); + + is_correct_node = editor.is_correct_and_unambiguous_node(EditorNode{EditorOutput{"mul2"}}); + EXPECT_EQ(is_correct_node, true); + + is_correct_node = editor.is_correct_and_unambiguous_node(EditorNode{EditorOutput{"split2"}}); + EXPECT_EQ(is_correct_node, true); + + is_correct_node = editor.is_correct_and_unambiguous_node(EditorNode{"relu1_name"}); + EXPECT_EQ(is_correct_node, true); + + is_correct_node = editor.is_correct_and_unambiguous_node(EditorNode{EditorOutput{"in3"}}); + EXPECT_EQ(is_correct_node, false); + + is_correct_node = editor.is_correct_and_unambiguous_node(EditorNode{"add_ambiguous_name"}); + EXPECT_EQ(is_correct_node, false); + + is_correct_node = editor.is_correct_and_unambiguous_node(EditorNode{"not_exist"}); + EXPECT_EQ(is_correct_node, false); +} + using TestEngine = test::INTERPRETER_Engine; NGRAPH_TEST(onnx_editor, values__append_one_initializer)