diff --git a/ngraph/frontend/frontend_manager/include/frontend_manager/place.hpp b/ngraph/frontend/frontend_manager/include/frontend_manager/place.hpp index c81a3a01f58..033bc01a6cb 100644 --- a/ngraph/frontend/frontend_manager/include/frontend_manager/place.hpp +++ b/ngraph/frontend/frontend_manager/include/frontend_manager/place.hpp @@ -103,7 +103,7 @@ public: /// /// \return A vector with all operation node references that consumes data from this /// place - virtual std::vector get_consuming_operations(const std::string& outputPortName) const; + virtual std::vector get_consuming_operations(const std::string& outputName) const; /// \brief Returns references to all operation nodes that consume data from this place /// for specified output port @@ -127,16 +127,14 @@ public: /// \return A tensor place which hold the resulting value for this place virtual Ptr get_target_tensor() const; - /// \brief Returns a tensor place that gets data from this place; applicable for - /// operations, output ports and output edges which have only one output port + /// \brief Returns a tensor place that gets data from this place; applicable for operations /// /// \param outputName Name of output port group /// /// \return A tensor place which hold the resulting value for this place virtual Ptr get_target_tensor(const std::string& outputName) const; - /// \brief Returns a tensor place that gets data from this place; applicable for - /// operations, output ports and output edges which have only one output port + /// \brief Returns a tensor place that gets data from this place; applicable for operations /// /// \param outputName Name of output port group, each group can have multiple ports /// @@ -146,8 +144,7 @@ public: /// \return A tensor place which hold the resulting value for this place virtual Ptr get_target_tensor(const std::string& outputName, int outputPortIndex) const; - /// \brief Returns a tensor place that gets data from this place; applicable for - /// operations, output ports and output edges + /// \brief Returns a tensor place that gets data from this place; applicable for operations /// /// \param output_port_index Output port index if the current place is an operation node /// and has multiple output ports @@ -161,24 +158,21 @@ public: /// \return A tensor place which supplies data for this place virtual Ptr get_source_tensor() const; - /// \brief Returns a tensor place that supplies data for this place; applicable for - /// operations, input ports and input edges + /// \brief Returns a tensor place that supplies data for this place; applicable for operations /// /// \param input_port_index Input port index for operational nodes. /// /// \return A tensor place which supplies data for this place virtual Ptr get_source_tensor(int input_port_index) const; - /// \brief Returns a tensor place that supplies data for this place; applicable for - /// operations, input ports and input edges + /// \brief Returns a tensor place that supplies data for this place; applicable for operations /// /// \param inputName Name of input port group /// /// \return A tensor place which supplies data for this place virtual Ptr get_source_tensor(const std::string& inputName) const; - /// \brief Returns a tensor place that supplies data for this place; applicable for - /// operations, input ports and input edges + /// \brief Returns a tensor place that supplies data for this place; applicable for operations /// /// \param inputName If a given place is itself an operation node, this specifies name /// of output port group, each group can have multiple ports diff --git a/ngraph/frontend/onnx/frontend/src/edge_mapper.cpp b/ngraph/frontend/onnx/frontend/src/edge_mapper.cpp index 4c6132dd78e..4602a2db623 100644 --- a/ngraph/frontend/onnx/frontend/src/edge_mapper.cpp +++ b/ngraph/frontend/onnx/frontend/src/edge_mapper.cpp @@ -95,37 +95,41 @@ std::vector onnx_editor::EdgeMapper::get_node_input_indexes(int node_index, } InputEdge onnx_editor::EdgeMapper::find_input_edge(const EditorNode& node, const EditorInput& in) const { - // identification can be both based on node name and output name - const auto& node_indexes = find_node_indexes(node.m_node_name, node.m_output_name); - int node_index = -1; - if (node_indexes.size() == 1) { - node_index = node_indexes[0]; - } else if (node_indexes.empty()) { - throw ngraph_error("Node with name: " + (node.m_node_name.empty() ? "not_given" : node.m_node_name) + - " and output_name: " + (node.m_output_name.empty() ? "not_given" : node.m_output_name) + - " was not found"); - } else if (!in.m_input_name.empty()) // input indexes are not deterministic if a node name is ambiguous - { - // many nodes with the same name - // check if some of found index matches input name - int matched_inputs_number = 0; - for (const auto& index : node_indexes) { - if (std::count(std::begin(m_node_inputs[index]), std::end(m_node_inputs[index]), in.m_input_name) > 0) { - node_index = index; - ++matched_inputs_number; - } - } - if (matched_inputs_number == 0) { - throw ngraph_error("Input edge described by: " + node.m_node_name + " and input name: " + in.m_input_name + + int node_index = node.m_node_index; + if (node_index == -1) { // the node index is not provided + // identification can be both based on node name and output name (if the node index is not provided) + const auto& node_indexes = find_node_indexes(node.m_node_name, node.m_output_name); + if (node_indexes.size() == 1) { + node_index = node_indexes[0]; + } else if (node_indexes.empty()) { + throw ngraph_error("Node with name: " + (node.m_node_name.empty() ? "not_given" : node.m_node_name) + + " and output_name: " + (node.m_output_name.empty() ? "not_given" : node.m_output_name) + " was not found"); + } else if (!in.m_input_name.empty()) // input indexes are not deterministic if a node name is ambiguous + { + // many nodes with the same name + // check if some of found index matches input name + int matched_inputs_number = 0; + for (const auto& index : node_indexes) { + if (std::count(std::begin(m_node_inputs[index]), std::end(m_node_inputs[index]), in.m_input_name) > 0) { + node_index = index; + ++matched_inputs_number; + } + } + if (matched_inputs_number == 0) { + throw ngraph_error("Input edge described by: " + node.m_node_name + + " and input name: " + in.m_input_name + " was not found"); + } + if (matched_inputs_number > 1) { + throw ngraph_error("Given node name: " + node.m_node_name + " and input name: " + in.m_input_name + + " are ambiguous to determine input edge"); + } + } else { + throw ngraph_error("Given node name: " + node.m_node_name + " and input index: " + + std::to_string(in.m_input_index) + " are ambiguous to determine input edge"); } - if (matched_inputs_number > 1) { - throw ngraph_error("Given node name: " + node.m_node_name + " and input name: " + in.m_input_name + - " are ambiguous to determine input edge"); - } - } else { - throw ngraph_error("Given node name: " + node.m_node_name + " and input index: " + - std::to_string(in.m_input_index) + " are ambiguous to determine input edge"); + } else { // the node index is provided + check_node_index(node_index); } if (in.m_input_index != -1) // input index is set { @@ -146,33 +150,38 @@ InputEdge onnx_editor::EdgeMapper::find_input_edge(const EditorNode& node, const } OutputEdge onnx_editor::EdgeMapper::find_output_edge(const EditorNode& node, const EditorOutput& out) const { - // identification can be both based on node name and output name - const auto& node_indexes = find_node_indexes(node.m_node_name, node.m_output_name); - int node_index = -1; - if (node_indexes.size() == 1) { - node_index = node_indexes[0]; - } else if (node_indexes.empty()) { - throw ngraph_error("Node with name: " + (node.m_node_name.empty() ? "not_given" : node.m_node_name) + - " and output_name: " + (node.m_output_name.empty() ? "not_given" : node.m_output_name) + - " was not found"); - } else if (!out.m_output_name.empty()) // output indexes are not deterministic if a node name is ambiguous - { - // many nodes with the same name - // check if some of found index matches output name - int matched_outputs_number = 0; - for (const auto& index : node_indexes) { - if (std::count(std::begin(m_node_outputs[index]), std::end(m_node_outputs[index]), out.m_output_name) > 0) { - node_index = index; - ++matched_outputs_number; + int node_index = node_index = node.m_node_index; + if (node_index == -1) { // the node index is not provided + // identification can be both based on node name and output name (if the node index is not provided) + const auto& node_indexes = find_node_indexes(node.m_node_name, node.m_output_name); + if (node_indexes.size() == 1) { + node_index = node_indexes[0]; + } else if (node_indexes.empty()) { + throw ngraph_error("Node with name: " + (node.m_node_name.empty() ? "not_given" : node.m_node_name) + + " and output_name: " + (node.m_output_name.empty() ? "not_given" : node.m_output_name) + + " was not found"); + } else if (!out.m_output_name.empty()) // output indexes are not deterministic if a node name is ambiguous + { + // many nodes with the same name + // check if some of found index matches output name + int matched_outputs_number = 0; + for (const auto& index : node_indexes) { + if (std::count(std::begin(m_node_outputs[index]), std::end(m_node_outputs[index]), out.m_output_name) > + 0) { + node_index = index; + ++matched_outputs_number; + } } + if (matched_outputs_number == 0) { + throw ngraph_error("Output edge described by: " + node.m_node_name + + " and output name: " + out.m_output_name + " was not found"); + } + } else { + throw ngraph_error("Given node name: " + node.m_node_name + " and output index: " + + std::to_string(out.m_output_index) + " are ambiguous to determine output edge"); } - if (matched_outputs_number == 0) { - throw ngraph_error("Output edge described by: " + node.m_node_name + - " and output name: " + out.m_output_name + " was not found"); - } - } else { - throw ngraph_error("Given node name: " + node.m_node_name + " and output index: " + - std::to_string(out.m_output_index) + " are ambiguous to determine output edge"); + } else { // the node index is provided + check_node_index(node_index); } if (out.m_output_index != -1) // output index is set { @@ -197,16 +206,45 @@ std::vector onnx_editor::EdgeMapper::find_output_consumers(const std: const auto node_idx = it->second; const auto port_indexes = get_node_input_indexes(node_idx, output_name); for (const auto& idx : port_indexes) { - input_edges.push_back(InputEdge{node_idx, idx}); + const auto consumer_edge = InputEdge{node_idx, idx}; + if (std::find_if(std::begin(input_edges), std::end(input_edges), [&consumer_edge](const InputEdge& edge) { + return edge.m_node_idx == consumer_edge.m_node_idx && edge.m_port_idx == consumer_edge.m_port_idx; + }) == std::end(input_edges)) { + // only unique + input_edges.push_back(consumer_edge); + } } } return input_edges; } bool onnx_editor::EdgeMapper::is_correct_and_unambiguous_node(const EditorNode& node) const { + if (node.m_node_index >= 0 && node.m_node_index < static_cast(m_node_inputs.size())) { + return true; + } return find_node_indexes(node.m_node_name, node.m_output_name).size() == 1; } +namespace { +void check_node(bool condition, const EditorNode& node) { + NGRAPH_CHECK(condition, + "The node with name: " + (node.m_node_name.empty() ? "not_given" : node.m_node_name) + + ", output_name: " + (node.m_output_name.empty() ? "not_given" : node.m_output_name) + + ", node_index: " + (node.m_node_index == -1 ? "not_given" : std::to_string(node.m_node_index)) + + " is ambiguous"); +} +} // namespace + +int onnx_editor::EdgeMapper::get_node_index(const EditorNode& node) const { + if (node.m_node_index != -1) { // the node index provided + check_node_index(node.m_node_index); + return node.m_node_index; + } + const auto indexes = find_node_indexes(node.m_node_name, node.m_output_name); + check_node(indexes.size() == 1, node); + return indexes[0]; +} + bool onnx_editor::EdgeMapper::is_correct_tensor_name(const std::string& name) const { if (m_node_output_name_to_index.find(name) != std::end(m_node_output_name_to_index)) { return true; @@ -218,20 +256,25 @@ bool onnx_editor::EdgeMapper::is_correct_tensor_name(const std::string& name) co } std::vector onnx_editor::EdgeMapper::get_input_ports(const EditorNode& node) const { - NGRAPH_CHECK(is_correct_and_unambiguous_node(node), - "The node with name: " + (node.m_node_name.empty() ? "not_given" : node.m_node_name) + - ", output_name: " + (node.m_output_name.empty() ? "not_given" : node.m_output_name) + - " is ambiguous"); - const auto node_index = find_node_indexes(node.m_node_name, node.m_output_name)[0]; + check_node(is_correct_and_unambiguous_node(node), node); + auto node_index = node.m_node_index; + if (node_index == -1) { // the node index is provided + node_index = find_node_indexes(node.m_node_name, node.m_output_name)[0]; + } else { + check_node_index(node_index); + } return m_node_inputs[node_index]; } std::vector onnx_editor::EdgeMapper::get_output_ports(const EditorNode& node) const { - NGRAPH_CHECK(is_correct_and_unambiguous_node(node), - "The node with name: " + (node.m_node_name.empty() ? "not_given" : node.m_node_name) + - ", output_name: " + (node.m_output_name.empty() ? "not_given" : node.m_output_name) + - " is ambiguous"); - const auto node_index = find_node_indexes(node.m_node_name, node.m_output_name)[0]; + check_node(is_correct_and_unambiguous_node(node), node); + auto node_index = node.m_node_index; + if (node_index == -1) // the node index is provided + { + node_index = find_node_indexes(node.m_node_name, node.m_output_name)[0]; + } else { + check_node_index(node_index); + } return m_node_outputs[node_index]; } @@ -250,3 +293,8 @@ std::string onnx_editor::EdgeMapper::get_target_tensor_name(const OutputEdge& ed } return ""; } + +void onnx_editor::EdgeMapper::check_node_index(int node_index) const { + NGRAPH_CHECK(node_index >= 0 && node_index < static_cast(m_node_inputs.size()), + "Provided node index: " + std::to_string(node_index) + " is out of scope"); +} diff --git a/ngraph/frontend/onnx/frontend/src/edge_mapper.hpp b/ngraph/frontend/onnx/frontend/src/edge_mapper.hpp index fb16d147ade..d9fe93fcaf0 100644 --- a/ngraph/frontend/onnx/frontend/src/edge_mapper.hpp +++ b/ngraph/frontend/onnx/frontend/src/edge_mapper.hpp @@ -92,6 +92,16 @@ public: /// bool is_correct_and_unambiguous_node(const EditorNode& node) const; + /// \brief Returns index (position) of provided node in the graph + /// in topological order. + /// + /// \param node An EditorNode helper structure created based on a node name + /// or a node output name. + /// + /// \note The exception will be thrown if the provided node is ambiguous. + /// + int get_node_index(const EditorNode& node) const; + /// \brief Returns true if a provided tensor name is correct (exists in a graph). /// /// \param name The name of tensor in a graph. @@ -130,6 +140,7 @@ private: // note: a single node can have more than one inputs with the same name std::vector get_node_input_indexes(int node_index, const std::string& input_name) const; int get_node_output_idx(int node_index, const std::string& output_name) const; + void check_node_index(int node_index) const; std::vector> m_node_inputs; std::vector> m_node_outputs; diff --git a/ngraph/frontend/onnx/frontend/src/editor.cpp b/ngraph/frontend/onnx/frontend/src/editor.cpp index 6551543d5f1..c2741e68f46 100644 --- a/ngraph/frontend/onnx/frontend/src/editor.cpp +++ b/ngraph/frontend/onnx/frontend/src/editor.cpp @@ -444,6 +444,11 @@ bool onnx_editor::ONNXModelEditor::is_correct_and_unambiguous_node(const EditorN return m_pimpl->m_edge_mapper.is_correct_and_unambiguous_node(node); } +int onnx_editor::ONNXModelEditor::get_node_index(const EditorNode& node) const { + update_mapper_if_needed(); + return m_pimpl->m_edge_mapper.get_node_index(node); +} + bool onnx_editor::ONNXModelEditor::is_correct_tensor_name(const std::string& name) const { update_mapper_if_needed(); return m_pimpl->m_edge_mapper.is_correct_tensor_name(name); diff --git a/ngraph/frontend/onnx/frontend/src/editor.hpp b/ngraph/frontend/onnx/frontend/src/editor.hpp index 66d4d7e5ff0..bfda345efe3 100644 --- a/ngraph/frontend/onnx/frontend/src/editor.hpp +++ b/ngraph/frontend/onnx/frontend/src/editor.hpp @@ -198,6 +198,16 @@ public: /// bool is_correct_and_unambiguous_node(const EditorNode& node) const; + /// \brief Returns index (position) of provided node in the graph + /// in topological order. + /// + /// \param node An EditorNode helper structure created based on a node name + /// or a node output name. + /// + /// \note The exception will be thrown if the provided node is ambiguous. + /// + int get_node_index(const EditorNode& node) const; + /// \brief Returns true if a provided tensor name is correct (exists in a graph). /// /// \param name The name of tensor in a graph. diff --git a/ngraph/frontend/onnx/frontend/src/editor_types.hpp b/ngraph/frontend/onnx/frontend/src/editor_types.hpp index 0d3db82f588..febbbb01449 100644 --- a/ngraph/frontend/onnx/frontend/src/editor_types.hpp +++ b/ngraph/frontend/onnx/frontend/src/editor_types.hpp @@ -114,11 +114,14 @@ struct EditorOutput { /// You can indicate test_node by name as EditorNode("test_node") /// or by assigned output as EditorNode(EditorOutput("out1")) /// or EditorNode(EditorOutput("out2")) +/// or you can determine the node by postition of a node in an ONNX graph (in topological order). struct EditorNode { EditorNode(std::string node_name) : m_node_name{std::move(node_name)} {} EditorNode(EditorOutput output) : m_output_name{std::move(output.m_output_name)} {} + EditorNode(const int node_index) : m_node_index{node_index} {} const std::string m_node_name = ""; const std::string m_output_name = ""; + const int m_node_index = -1; }; } // namespace onnx_editor } // namespace ngraph diff --git a/ngraph/frontend/onnx/frontend/src/place.cpp b/ngraph/frontend/onnx/frontend/src/place.cpp index 915b6a401f5..15a6bb74172 100644 --- a/ngraph/frontend/onnx/frontend/src/place.cpp +++ b/ngraph/frontend/onnx/frontend/src/place.cpp @@ -47,6 +47,18 @@ Place::Ptr PlaceInputEdgeONNX::get_source_tensor() const { return std::make_shared(m_editor->get_source_tensor_name(m_edge), m_editor); } +std::vector PlaceInputEdgeONNX::get_consuming_operations() const { + return {std::make_shared(onnx_editor::EditorNode{m_edge.m_node_idx}, m_editor)}; +} + +Place::Ptr PlaceInputEdgeONNX::get_producing_operation() const { + return get_source_tensor()->get_producing_operation(); +} + +Place::Ptr PlaceInputEdgeONNX::get_producing_port() const { + return get_source_tensor()->get_producing_port(); +} + PlaceOutputEdgeONNX::PlaceOutputEdgeONNX(const onnx_editor::OutputEdge& edge, std::shared_ptr editor) : m_edge{edge}, @@ -85,6 +97,18 @@ Place::Ptr PlaceOutputEdgeONNX::get_target_tensor() const { return std::make_shared(m_editor->get_target_tensor_name(m_edge), m_editor); } +std::vector PlaceOutputEdgeONNX::get_consuming_ports() const { + return get_target_tensor()->get_consuming_ports(); +} + +Place::Ptr PlaceOutputEdgeONNX::get_producing_operation() const { + return std::make_shared(onnx_editor::EditorNode{m_edge.m_node_idx}, m_editor); +} + +std::vector PlaceOutputEdgeONNX::get_consuming_operations() const { + return get_target_tensor()->get_consuming_operations(); +} + PlaceTensorONNX::PlaceTensorONNX(const std::string& name, std::shared_ptr editor) : m_name{name}, m_editor{std::move(editor)} {} @@ -112,10 +136,8 @@ std::vector PlaceTensorONNX::get_consuming_ports() const { return ret; } -Place::Ptr PlaceTensorONNX::get_input_port(int input_port_index) const { - return std::make_shared( - m_editor->find_input_edge(onnx_editor::EditorOutput(m_name), onnx_editor::EditorInput(input_port_index)), - m_editor); +Place::Ptr PlaceTensorONNX::get_producing_operation() const { + return get_producing_port()->get_producing_operation(); } bool PlaceTensorONNX::is_input() const { @@ -146,6 +168,19 @@ bool PlaceTensorONNX::is_equal_data(Place::Ptr another) const { eq_to_consuming_port(another); } +std::vector PlaceTensorONNX::get_consuming_operations() const { + std::vector consuming_ports = get_consuming_ports(); + std::vector consuming_ops; + std::transform(std::begin(consuming_ports), + std::end(consuming_ports), + std::back_inserter(consuming_ops), + [](const Place::Ptr& place) { + return place->get_consuming_operations().at(0); + }); + + return consuming_ops; +} + PlaceOpONNX::PlaceOpONNX(const onnx_editor::EditorNode& node, std::shared_ptr editor) : m_node{node}, m_editor{std::move(editor)} {} @@ -158,6 +193,10 @@ std::vector PlaceOpONNX::get_names() const { return {m_node.m_node_name}; } +onnx_editor::EditorNode PlaceOpONNX::get_editor_node() const { + return m_node; +} + Place::Ptr PlaceOpONNX::get_output_port() const { if (m_editor->get_output_ports(m_node).size() == 1) { return get_output_port(0); @@ -209,3 +248,133 @@ Place::Ptr PlaceOpONNX::get_input_port(const std::string& input_name) const { } return nullptr; } + +std::vector PlaceOpONNX::get_consuming_ports() const { + std::vector consuming_ports; + const auto out_ports_size = m_editor->get_output_ports(m_node).size(); + for (int out_idx = 0; out_idx < out_ports_size; ++out_idx) { + auto consuming_ops_out = get_output_port(out_idx)->get_consuming_ports(); + consuming_ports.insert(consuming_ports.end(), consuming_ops_out.begin(), consuming_ops_out.end()); + } + return consuming_ports; +} + +namespace { +std::vector get_consuming_ops(std::vector input_ports) { + std::vector consuming_ops; + std::transform(std::begin(input_ports), + std::end(input_ports), + std::back_inserter(consuming_ops), + [](const Place::Ptr place) { + return place->get_consuming_operations().at(0); + }); + + return consuming_ops; +} +} // namespace + +std::vector PlaceOpONNX::get_consuming_operations() const { + std::vector consuming_ports = get_consuming_ports(); + return get_consuming_ops(consuming_ports); +} + +std::vector PlaceOpONNX::get_consuming_operations(int output_port_index) const { + std::vector consuming_ports = get_output_port(output_port_index)->get_consuming_ports(); + return get_consuming_ops(consuming_ports); +} + +std::vector PlaceOpONNX::get_consuming_operations(const std::string& output_port_name) const { + std::vector consuming_ports = get_output_port(output_port_name)->get_consuming_ports(); + return get_consuming_ops(consuming_ports); +} + +Place::Ptr PlaceOpONNX::get_producing_operation() const { + const auto input_port = get_input_port(); + if (input_port != nullptr) { + return input_port->get_producing_operation(); + } + return nullptr; +} + +Place::Ptr PlaceOpONNX::get_producing_operation(int input_port_index) const { + const auto input_port = get_input_port(input_port_index); + if (input_port != nullptr) { + return input_port->get_producing_operation(); + } + return nullptr; +} + +Place::Ptr PlaceOpONNX::get_producing_operation(const std::string& input_port_name) const { + const auto input_port = get_input_port(input_port_name); + if (input_port != nullptr) { + return input_port->get_producing_operation(); + } + return nullptr; +} + +bool PlaceOpONNX::is_equal(Place::Ptr another) const { + if (const auto place_op = std::dynamic_pointer_cast(another)) { + const auto& another_node = place_op->get_editor_node(); + if (m_editor->is_correct_and_unambiguous_node(m_node) || + m_editor->is_correct_and_unambiguous_node(another_node)) { + return m_editor->get_node_index(m_node) == m_editor->get_node_index(another_node); + } + } + return false; +} + +Place::Ptr PlaceOpONNX::get_target_tensor() const { + const auto output_port = get_output_port(); + if (output_port != nullptr) { + return output_port->get_target_tensor(); + } + return nullptr; +} + +Place::Ptr PlaceOpONNX::get_target_tensor(int output_port_index) const { + const auto output_port = get_output_port(output_port_index); + if (output_port != nullptr) { + return output_port->get_target_tensor(); + } + return nullptr; +} + +Place::Ptr PlaceOpONNX::get_target_tensor(const std::string& output_name) const { + const auto output_port = get_output_port(output_name); + if (output_port != nullptr) { + return output_port->get_target_tensor(); + } + return nullptr; +} + +Place::Ptr PlaceOpONNX::get_source_tensor() const { + const auto input_port = get_input_port(); + if (input_port != nullptr) { + return input_port->get_source_tensor(); + } + return nullptr; +} + +Place::Ptr PlaceOpONNX::get_source_tensor(int input_port_index) const { + const auto input_port = get_input_port(input_port_index); + if (input_port != nullptr) { + return input_port->get_source_tensor(); + } + return nullptr; +} + +Place::Ptr PlaceOpONNX::get_source_tensor(const std::string& input_name) const { + const auto input_port = get_input_port(input_name); + if (input_port != nullptr) { + return input_port->get_source_tensor(); + } + return nullptr; +} + +bool PlaceOpONNX::is_input() const { + return false; +} + +bool PlaceOpONNX::is_output() const { + return false; +} diff --git a/ngraph/frontend/onnx/frontend/src/place.hpp b/ngraph/frontend/onnx/frontend/src/place.hpp index efc7e7a8155..82d961998a4 100644 --- a/ngraph/frontend/onnx/frontend/src/place.hpp +++ b/ngraph/frontend/onnx/frontend/src/place.hpp @@ -7,6 +7,7 @@ #include #include #include +#include namespace ngraph { namespace frontend { @@ -15,17 +16,18 @@ public: PlaceInputEdgeONNX(const onnx_editor::InputEdge& edge, std::shared_ptr editor); PlaceInputEdgeONNX(onnx_editor::InputEdge&& edge, std::shared_ptr editor); + // internal usage onnx_editor::InputEdge get_input_edge() const; + // external usage bool is_input() const override; - bool is_output() const override; - bool is_equal(Place::Ptr another) const override; - bool is_equal_data(Place::Ptr another) const override; - Place::Ptr get_source_tensor() const override; + std::vector get_consuming_operations() const override; + Place::Ptr get_producing_operation() const override; + Place::Ptr get_producing_port() const override; private: onnx_editor::InputEdge m_edge; @@ -37,17 +39,18 @@ public: PlaceOutputEdgeONNX(const onnx_editor::OutputEdge& edge, std::shared_ptr editor); PlaceOutputEdgeONNX(onnx_editor::OutputEdge&& edge, std::shared_ptr editor); + // internal usage onnx_editor::OutputEdge get_output_edge() const; + // external usage bool is_input() const override; - bool is_output() const override; - bool is_equal(Place::Ptr another) const override; - bool is_equal_data(Place::Ptr another) const override; - Place::Ptr get_target_tensor() const override; + std::vector get_consuming_ports() const override; + Place::Ptr get_producing_operation() const override; + std::vector get_consuming_operations() const override; private: onnx_editor::OutputEdge m_edge; @@ -59,21 +62,16 @@ public: PlaceTensorONNX(const std::string& name, std::shared_ptr editor); PlaceTensorONNX(std::string&& name, std::shared_ptr editor); + // external usage std::vector get_names() const override; - Place::Ptr get_producing_port() const override; - std::vector get_consuming_ports() const override; - - Ptr get_input_port(int input_port_index) const override; - + Place::Ptr get_producing_operation() const override; bool is_input() const override; - bool is_output() const override; - bool is_equal(Place::Ptr another) const override; - bool is_equal_data(Place::Ptr another) const override; + std::vector get_consuming_operations() const override; private: std::string m_name; @@ -86,6 +84,10 @@ public: PlaceOpONNX(onnx_editor::EditorNode&& node, std::shared_ptr editor); std::vector get_names() const override; + // internal usage + onnx_editor::EditorNode get_editor_node() const; + + // external usage Place::Ptr get_output_port() const override; Place::Ptr get_output_port(int output_port_index) const override; Place::Ptr get_output_port(const std::string& output_port_name) const override; @@ -94,6 +96,27 @@ public: Place::Ptr get_input_port(int input_port_index) const override; Place::Ptr get_input_port(const std::string& input_name) const override; + std::vector get_consuming_ports() const override; + std::vector get_consuming_operations() const override; + std::vector get_consuming_operations(int output_port_index) const override; + std::vector get_consuming_operations(const std::string& output_port_name) const override; + + Place::Ptr get_producing_operation() const override; + Place::Ptr get_producing_operation(int input_port_index) const override; + Place::Ptr get_producing_operation(const std::string& input_port_name) const override; + + Place::Place::Ptr get_target_tensor() const override; + Place::Ptr get_target_tensor(int output_port_index) const override; + Place::Ptr get_target_tensor(const std::string& output_name) const override; + + Place::Place::Ptr get_source_tensor() const override; + Place::Ptr get_source_tensor(int input_port_index) const override; + Place::Ptr get_source_tensor(const std::string& input_name) const override; + + bool is_equal(Place::Ptr another) const override; + bool is_input() const override; + bool is_output() const override; + private: onnx_editor::EditorNode m_node; std::shared_ptr m_editor; diff --git a/ngraph/test/onnx/onnx_editor.cpp b/ngraph/test/onnx/onnx_editor.cpp index 96034029677..21c7be215b6 100644 --- a/ngraph/test/onnx/onnx_editor.cpp +++ b/ngraph/test/onnx/onnx_editor.cpp @@ -693,6 +693,26 @@ NGRAPH_TEST(onnx_editor, editor_api_select_input_edge_by_node_name_and_input_ind EXPECT_EQ(edge2.m_new_input_name, "custom_input_name_2"); } +NGRAPH_TEST(onnx_editor, editor_api_select_input_edge_by_node_index) { + ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.onnx")}; + + const InputEdge edge = editor.find_input_edge(EditorNode{0}, EditorInput{0, "custom_input_name_1"}); + EXPECT_EQ(edge.m_node_idx, 0); + EXPECT_EQ(edge.m_port_idx, 0); + EXPECT_EQ(edge.m_new_input_name, "custom_input_name_1"); + + const InputEdge edge2 = editor.find_input_edge(EditorNode{5}, EditorInput{0}); + EXPECT_EQ(edge2.m_node_idx, 5); + EXPECT_EQ(edge2.m_port_idx, 0); + + try { + editor.find_input_edge(EditorNode{99}, EditorInput{"conv1/7x7_s2_1"}); + } catch (const std::exception& e) { + std::string msg{e.what()}; + EXPECT_TRUE(msg.find("Provided node index: 99 is out of scope") != std::string::npos); + } +} + NGRAPH_TEST(onnx_editor, editor_api_select_input_edge_empty_node_name) { ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.onnx")}; @@ -767,6 +787,21 @@ NGRAPH_TEST(onnx_editor, editor_api_select_output_edge_by_node_name_and_output_i EXPECT_EQ(edge2.m_port_idx, 1); } +NGRAPH_TEST(onnx_editor, editor_api_select_output_edge_by_node_index) { + ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.onnx")}; + + const OutputEdge edge = editor.find_output_edge(EditorNode{5}, EditorOutput{1}); + EXPECT_EQ(edge.m_node_idx, 5); + EXPECT_EQ(edge.m_port_idx, 1); + + try { + editor.find_output_edge(EditorNode{99}, EditorOutput{"conv1/7x7_s2_1"}); + } catch (const std::exception& e) { + std::string msg{e.what()}; + EXPECT_TRUE(msg.find("Provided node index: 99 is out of scope") != std::string::npos); + } +} + NGRAPH_TEST(onnx_editor, editor_api_select_edge_const_network) { ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests_2.onnx")}; @@ -1044,6 +1079,12 @@ NGRAPH_TEST(onnx_editor, editor_api_is_correct_and_unambiguous_node) { is_correct_node = editor.is_correct_and_unambiguous_node(EditorNode{"relu1_name"}); EXPECT_EQ(is_correct_node, true); + is_correct_node = editor.is_correct_and_unambiguous_node(EditorNode{2}); + EXPECT_EQ(is_correct_node, true); + + is_correct_node = editor.is_correct_and_unambiguous_node(EditorNode{99}); + EXPECT_EQ(is_correct_node, false); + is_correct_node = editor.is_correct_and_unambiguous_node(EditorNode{EditorOutput{"in3"}}); EXPECT_EQ(is_correct_node, false); @@ -1054,6 +1095,32 @@ NGRAPH_TEST(onnx_editor, editor_api_is_correct_and_unambiguous_node) { EXPECT_EQ(is_correct_node, false); } +NGRAPH_TEST(onnx_editor, editor_api_get_node_index) { + ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/subgraph_extraction_tests.onnx")}; + + EXPECT_EQ(editor.get_node_index(EditorNode{2}), 2); + EXPECT_EQ(editor.get_node_index(EditorNode{EditorOutput{"relu1"}}), 0); + EXPECT_EQ(editor.get_node_index(EditorNode{EditorOutput{"split2"}}), 5); + EXPECT_EQ(editor.get_node_index(EditorNode{"relu1_name"}), 0); + + try { + editor.get_node_index(EditorNode{99}); + } catch (const std::exception& e) { + std::string msg{e.what()}; + EXPECT_TRUE(msg.find("Provided node index: 99 is out of scope") != std::string::npos); + } + + try { + editor.get_node_index(EditorNode{"add_ambiguous_name"}); + } catch (const std::exception& e) { + std::string msg{e.what()}; + EXPECT_TRUE( + msg.find( + "The node with name: add_ambiguous_name, output_name: not_given, node_index: not_given is ambiguous") != + std::string::npos); + } +} + NGRAPH_TEST(onnx_editor, editor_api_input_edge_from_tensor_with_single_consumer) { ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/add_ab.onnx")}; @@ -1365,15 +1432,18 @@ NGRAPH_TEST(onnx_editor, get_input_ports) { editor.get_input_ports(EditorNode{"add_ambiguous_name"}); } catch (const std::exception& e) { std::string msg{e.what()}; - EXPECT_TRUE(msg.find("The node with name: add_ambiguous_name, output_name: not_given is ambiguous") != - std::string::npos); + EXPECT_TRUE( + msg.find( + "The node with name: add_ambiguous_name, output_name: not_given, node_index: not_given is ambiguous") != + std::string::npos); } try { editor.get_input_ports(EditorNode{""}); } catch (const std::exception& e) { std::string msg{e.what()}; - EXPECT_TRUE(msg.find("The node with name: not_given, output_name: not_given is ambiguous") != - std::string::npos); + EXPECT_TRUE( + msg.find("The node with name: not_given, output_name: not_given, node_index: not_given is ambiguous") != + std::string::npos); } } NGRAPH_TEST(onnx_editor, get_output_ports) { @@ -1393,14 +1463,17 @@ NGRAPH_TEST(onnx_editor, get_output_ports) { editor.get_output_ports(EditorNode{"add_ambiguous_name"}); } catch (const std::exception& e) { std::string msg{e.what()}; - EXPECT_TRUE(msg.find("The node with name: add_ambiguous_name, output_name: not_given is ambiguous") != - std::string::npos); + EXPECT_TRUE( + msg.find( + "The node with name: add_ambiguous_name, output_name: not_given, node_index: not_given is ambiguous") != + std::string::npos); } try { editor.get_output_ports(EditorNode{""}); } catch (const std::exception& e) { std::string msg{e.what()}; - EXPECT_TRUE(msg.find("The node with name: not_given, output_name: not_given is ambiguous") != - std::string::npos); + EXPECT_TRUE( + msg.find("The node with name: not_given, output_name: not_given, node_index: not_given is ambiguous") != + std::string::npos); } } diff --git a/runtime/bindings/python/src/compatibility/pyngraph/frontend/place.cpp b/runtime/bindings/python/src/compatibility/pyngraph/frontend/place.cpp index 79c88a76bf5..60bf74528ab 100644 --- a/runtime/bindings/python/src/compatibility/pyngraph/frontend/place.cpp +++ b/runtime/bindings/python/src/compatibility/pyngraph/frontend/place.cpp @@ -89,13 +89,23 @@ void regclass_pyngraph_Place(py::module m) { place.def( "get_consuming_operations", - [](const ngraph::frontend::Place& self, py::object outputPortIndex) { - if (outputPortIndex == py::none()) { - return self.get_consuming_operations(); + [](const ngraph::frontend::Place& self, py::object outputName, py::object outputPortIndex) { + if (outputName == py::none()) { + if (outputPortIndex == py::none()) { + return self.get_consuming_operations(); + } else { + return self.get_consuming_operations(py::cast(outputPortIndex)); + } } else { - return self.get_consuming_operations(py::cast(outputPortIndex)); + if (outputPortIndex == py::none()) { + return self.get_consuming_operations(py::cast(outputName)); + } else { + return self.get_consuming_operations(py::cast(outputName), + py::cast(outputPortIndex)); + } } }, + py::arg("outputName") = py::none(), py::arg("outputPortIndex") = py::none(), R"( Returns references to all operation nodes that consume data from this place for specified output port. @@ -103,6 +113,8 @@ void regclass_pyngraph_Place(py::module m) { Parameters ---------- + outputName : str + Name of output port group. May not be set if node has one output port group. outputPortIndex : int If place is an operational node it specifies which output port should be considered May not be set if node has only one output port. @@ -115,13 +127,22 @@ void regclass_pyngraph_Place(py::module m) { place.def( "get_target_tensor", - [](const ngraph::frontend::Place& self, py::object outputPortIndex) { - if (outputPortIndex == py::none()) { - return self.get_target_tensor(); + [](const ngraph::frontend::Place& self, py::object outputName, py::object outputPortIndex) { + if (outputName == py::none()) { + if (outputPortIndex == py::none()) { + return self.get_target_tensor(); + } else { + return self.get_target_tensor(py::cast(outputPortIndex)); + } } else { - return self.get_target_tensor(py::cast(outputPortIndex)); + if (outputPortIndex == py::none()) { + return self.get_target_tensor(py::cast(outputName)); + } else { + return self.get_target_tensor(py::cast(outputName), py::cast(outputPortIndex)); + } } }, + py::arg("outputName") = py::none(), py::arg("outputPortIndex") = py::none(), R"( Returns a tensor place that gets data from this place; applicable for operations, @@ -129,6 +150,8 @@ void regclass_pyngraph_Place(py::module m) { Parameters ---------- + outputName : str + Name of output port group. May not be set if node has one output port group. outputPortIndex : int Output port index if the current place is an operation node and has multiple output ports. May not be set if place has only one output port. @@ -141,19 +164,31 @@ void regclass_pyngraph_Place(py::module m) { place.def( "get_producing_operation", - [](const ngraph::frontend::Place& self, py::object inputPortIndex) { - if (inputPortIndex == py::none()) { - return self.get_producing_operation(); + [](const ngraph::frontend::Place& self, py::object inputName, py::object inputPortIndex) { + if (inputName == py::none()) { + if (inputPortIndex == py::none()) { + return self.get_producing_operation(); + } else { + return self.get_producing_operation(py::cast(inputPortIndex)); + } } else { - return self.get_producing_operation(py::cast(inputPortIndex)); + if (inputPortIndex == py::none()) { + return self.get_producing_operation(py::cast(inputName)); + } else { + return self.get_producing_operation(py::cast(inputName), + py::cast(inputPortIndex)); + } } }, + py::arg("inputName") = py::none(), py::arg("inputPortIndex") = py::none(), R"( Get an operation node place that immediately produces data for this place. Parameters ---------- + inputName : str + Name of port group. May not be set if node has one input port group. inputPortIndex : int If a given place is itself an operation node, this specifies a port index. May not be set if place has only one input port. @@ -260,13 +295,22 @@ void regclass_pyngraph_Place(py::module m) { place.def( "get_source_tensor", - [](const ngraph::frontend::Place& self, py::object inputPortIndex) { - if (inputPortIndex == py::none()) { - return self.get_source_tensor(); + [](const ngraph::frontend::Place& self, py::object inputName, py::object inputPortIndex) { + if (inputName == py::none()) { + if (inputPortIndex == py::none()) { + return self.get_source_tensor(); + } else { + return self.get_source_tensor(py::cast(inputPortIndex)); + } } else { - return self.get_source_tensor(py::cast(inputPortIndex)); + if (inputPortIndex == py::none()) { + return self.get_source_tensor(py::cast(inputName)); + } else { + return self.get_source_tensor(py::cast(inputName), py::cast(inputPortIndex)); + } } }, + py::arg("inputName") = py::none(), py::arg("inputPortIndex") = py::none(), R"( Returns a tensor place that supplies data for this place; applicable for operations, @@ -274,6 +318,8 @@ void regclass_pyngraph_Place(py::module m) { Parameters ---------- + inputName : str + Name of port group. May not be set if node has one input port group. inputPortIndex : int Input port index for operational node. May not be specified if place has only one input port. diff --git a/runtime/bindings/python/tests/mock/mock_py_ngraph_frontend/mock_py_frontend.hpp b/runtime/bindings/python/tests/mock/mock_py_ngraph_frontend/mock_py_frontend.hpp index 5e089770ec2..3143aa96fda 100644 --- a/runtime/bindings/python/tests/mock/mock_py_ngraph_frontend/mock_py_frontend.hpp +++ b/runtime/bindings/python/tests/mock/mock_py_ngraph_frontend/mock_py_frontend.hpp @@ -107,12 +107,29 @@ public: std::vector get_consuming_operations() const override { m_stat.m_get_consuming_operations++; m_stat.m_lastArgInt = -1; + m_stat.m_lastArgString = ""; return {std::make_shared()}; } std::vector get_consuming_operations(int outputPortIndex) const override { m_stat.m_get_consuming_operations++; m_stat.m_lastArgInt = outputPortIndex; + m_stat.m_lastArgString = ""; + return {std::make_shared()}; + } + + std::vector get_consuming_operations(const std::string& outputName) const override { + m_stat.m_get_consuming_operations++; + m_stat.m_lastArgInt = -1; + m_stat.m_lastArgString = outputName; + return {std::make_shared()}; + } + + std::vector get_consuming_operations(const std::string& outputName, + int outputPortIndex) const override { + m_stat.m_get_consuming_operations++; + m_stat.m_lastArgInt = outputPortIndex; + m_stat.m_lastArgString = outputName; return {std::make_shared()}; } @@ -128,6 +145,20 @@ public: return std::make_shared(); } + Place::Ptr get_target_tensor(const std::string& outputName) const override { + m_stat.m_get_target_tensor++; + m_stat.m_lastArgInt = -1; + m_stat.m_lastArgString = outputName; + return {std::make_shared()}; + } + + Place::Ptr get_target_tensor(const std::string& outputName, int outputPortIndex) const override { + m_stat.m_get_target_tensor++; + m_stat.m_lastArgInt = outputPortIndex; + m_stat.m_lastArgString = outputName; + return {std::make_shared()}; + } + Place::Ptr get_producing_operation() const override { m_stat.m_get_producing_operation++; m_stat.m_lastArgInt = -1; @@ -140,6 +171,20 @@ public: return std::make_shared(); } + Place::Ptr get_producing_operation(const std::string& inputName) const override { + m_stat.m_get_producing_operation++; + m_stat.m_lastArgInt = -1; + m_stat.m_lastArgString = inputName; + return {std::make_shared()}; + } + + Place::Ptr get_producing_operation(const std::string& inputName, int inputPortIndex) const override { + m_stat.m_get_producing_operation++; + m_stat.m_lastArgInt = inputPortIndex; + m_stat.m_lastArgString = inputName; + return {std::make_shared()}; + } + Place::Ptr get_producing_port() const override { m_stat.m_get_producing_port++; return std::make_shared(); @@ -236,6 +281,20 @@ public: return {std::make_shared()}; } + Place::Ptr get_source_tensor(const std::string& inputName) const override { + m_stat.m_get_source_tensor++; + m_stat.m_lastArgInt = -1; + m_stat.m_lastArgString = inputName; + return {std::make_shared()}; + } + + Place::Ptr get_source_tensor(const std::string& inputName, int inputPortIndex) const override { + m_stat.m_get_source_tensor++; + m_stat.m_lastArgInt = inputPortIndex; + m_stat.m_lastArgString = inputName; + return {std::make_shared()}; + } + //---------------Stat-------------------- PlaceStat get_stat() const { return m_stat; diff --git a/runtime/bindings/python/tests/test_frontend/test_frontend_onnx_editor.py b/runtime/bindings/python/tests/test_frontend/test_frontend_onnx_editor.py index e241945dd0c..586cdd577a5 100644 --- a/runtime/bindings/python/tests/test_frontend/test_frontend_onnx_editor.py +++ b/runtime/bindings/python/tests/test_frontend/test_frontend_onnx_editor.py @@ -9,6 +9,7 @@ from ngraph import PartialShape from ngraph.frontend import FrontEndManager +# ------Test input model 1------ # in1 in2 in3 # | | | # \ / | @@ -24,9 +25,32 @@ from ngraph.frontend import FrontEndManager # / \ | # out1 out2 out4 # +# +# ------Test input model 2------ +# in1 in2 +# | | +# \ / +# +--------+ +# | Add | +# +--------+ +# +# | +# +--------+ +# | Split | +# |(split2)| +# +--------+ +# / \ +# +# +-------+ +-------+ +# | Abs | | Sin | +# | (abs1)| | | +# +------ + +-------+ +# | | +# out1 out2 +# def create_test_onnx_models(): models = {} - # Input model + # Input model 1 add = onnx.helper.make_node("Add", inputs=["in1", "in2"], outputs=["add_out"]) split = onnx.helper.make_node("Split", inputs=["add_out"], outputs=["out1", "out2"], name="split1", axis=0) @@ -48,6 +72,24 @@ def create_test_onnx_models(): models["input_model.onnx"] = make_model(graph, producer_name="ONNX Importer", opset_imports=[onnx.helper.make_opsetid("", 13)]) + # Input model 2 + split_2 = onnx.helper.make_node("Split", inputs=["add_out"], + outputs=["sp_out1", "sp_out2"], name="split2", axis=0) + abs = onnx.helper.make_node("Abs", inputs=["sp_out1"], outputs=["out1"], name="abs1") + sin = onnx.helper.make_node("Sin", inputs=["sp_out2"], outputs=["out2"]) + + input_tensors = [ + make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)), + make_tensor_value_info("in2", onnx.TensorProto.FLOAT, (2, 2)), + ] + output_tensors = [ + make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (1, 2)), + make_tensor_value_info("out2", onnx.TensorProto.FLOAT, (1, 2)), + ] + graph = make_graph([add, split_2, abs, sin], "test_graph_2", input_tensors, output_tensors) + models["input_model_2.onnx"] = make_model(graph, producer_name="ONNX Importer", + opset_imports=[onnx.helper.make_opsetid("", 13)]) + # Expected for extract_subgraph input_tensors = [ make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)), @@ -312,8 +354,9 @@ def test_extract_subgraph_4(): model = fe.load("input_model.onnx") assert model - place1 = model.get_place_by_tensor_name(tensorName="out4").get_input_port(inputPortIndex=0) - place2 = model.get_place_by_tensor_name(tensorName="out4").get_input_port(inputPortIndex=1) + out4_tensor = model.get_place_by_tensor_name(tensorName="out4") + place1 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=0) + place2 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=1) place3 = model.get_place_by_operation_name_and_input_port(operationName="split1", inputPortIndex=0) place4 = model.get_place_by_tensor_name(tensorName="out1") place5 = model.get_place_by_tensor_name(tensorName="out2") @@ -377,8 +420,9 @@ def test_override_all_inputs(): place1 = model.get_place_by_operation_name_and_input_port( operationName="split1", inputPortIndex=0) - place2 = model.get_place_by_tensor_name(tensorName="out4").get_input_port(inputPortIndex=0) - place3 = model.get_place_by_tensor_name(tensorName="out4").get_input_port(inputPortIndex=1) + out4_tensor = model.get_place_by_tensor_name(tensorName="out4") + place2 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=0) + place3 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=1) place4 = model.get_place_by_tensor_name(tensorName="in3") model.override_all_inputs(inputs=[place1, place2, place3, place4]) result_func = fe.convert(model) @@ -432,11 +476,15 @@ def test_is_input_output(): assert not place3.is_input() assert not place3.is_output() - place4 = place1 = model.get_place_by_operation_name_and_input_port( + place4 = model.get_place_by_operation_name_and_input_port( operationName="split1", inputPortIndex=0) assert not place4.is_input() assert not place4.is_output() + place5 = model.get_place_by_operation_name(operationName="split1") + assert not place5.is_input() + assert not place5.is_output() + def test_set_partial_shape(): skip_if_onnx_frontend_is_disabled() @@ -520,12 +568,14 @@ def test_is_equal(): place2 = model.get_place_by_tensor_name(tensorName="out2") assert place2.is_equal(place2) - place3 = model.get_place_by_tensor_name(tensorName="out4").get_input_port(inputPortIndex=0) - place4 = model.get_place_by_tensor_name(tensorName="out4").get_input_port(inputPortIndex=0) + out4_tensor = model.get_place_by_tensor_name(tensorName="out4") + place3 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=0) + place4 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=0) assert place3.is_equal(place4) + out1_tensor = model.get_place_by_tensor_name(tensorName="out1") place5 = model.get_place_by_operation_name_and_input_port(operationName="split1", inputPortIndex=0) - place6 = model.get_place_by_tensor_name(tensorName="out1").get_input_port(inputPortIndex=0) + place6 = out1_tensor.get_producing_operation().get_input_port(inputPortIndex=0) assert place5.is_equal(place6) place7 = model.get_place_by_tensor_name(tensorName="out4").get_producing_port() @@ -538,6 +588,10 @@ def test_is_equal(): assert not place6.is_equal(place7) assert not place8.is_equal(place2) + place9 = model.get_place_by_operation_name(operationName="split1") + assert place2.get_producing_operation().is_equal(place9) + assert not place9.is_equal(place2) + def test_is_equal_data(): skip_if_onnx_frontend_is_disabled() @@ -560,11 +614,12 @@ def test_is_equal_data(): place4 = place2.get_producing_port() assert place2.is_equal_data(place4) - place5 = model.get_place_by_tensor_name(tensorName="out4").get_input_port(inputPortIndex=0) + out4_tensor = model.get_place_by_tensor_name(tensorName="out4") + place5 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=0) assert place2.is_equal_data(place5) assert place4.is_equal_data(place5) - place6 = model.get_place_by_tensor_name(tensorName="out4").get_input_port(inputPortIndex=1) + place6 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=1) assert place6.is_equal_data(place5) place7 = model.get_place_by_operation_name_and_input_port(operationName="split1", inputPortIndex=0) @@ -648,3 +703,198 @@ def test_get_input_port(): assert not split_op.get_input_port(inputPortIndex=1) assert not split_op.get_input_port(inputName="not_existed") + + +def test_get_consuming_ports(): + skip_if_onnx_frontend_is_disabled() + fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME) + assert fe + model = fe.load("input_model.onnx") + assert model + + place1 = model.get_place_by_tensor_name(tensorName="add_out") + add_tensor_consuming_ports = place1.get_consuming_ports() + assert len(add_tensor_consuming_ports) == 3 + place2 = model.get_place_by_operation_name_and_input_port(operationName="split1", inputPortIndex=0) + assert add_tensor_consuming_ports[0].is_equal(place2) + out4_tensor = model.get_place_by_tensor_name(tensorName="out4") + place3 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=0) + assert add_tensor_consuming_ports[1].is_equal(place3) + place4 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=1) + assert add_tensor_consuming_ports[2].is_equal(place4) + + add_op_consuming_ports = place1.get_producing_operation().get_consuming_ports() + assert len(add_op_consuming_ports) == len(add_tensor_consuming_ports) + for i in range(len(add_op_consuming_ports)): + assert add_op_consuming_ports[i].is_equal(add_tensor_consuming_ports[i]) + + +def test_get_consuming_ports_2(): + skip_if_onnx_frontend_is_disabled() + fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME) + assert fe + model = fe.load("input_model_2.onnx") + assert model + + split_op = model.get_place_by_operation_name(operationName="split2") + split_op_consuming_ports = split_op.get_consuming_ports() + assert len(split_op_consuming_ports) == 2 + abs_input_port = model.get_place_by_operation_name(operationName="abs1").get_input_port(inputPortIndex=0) + assert split_op_consuming_ports[0].is_equal(abs_input_port) + out2_tensor = model.get_place_by_tensor_name(tensorName="out2") + sin_input_port = out2_tensor.get_producing_operation().get_input_port(inputPortIndex=0) + assert split_op_consuming_ports[1].is_equal(sin_input_port) + + split_out_port_0 = split_op.get_output_port(outputPortIndex=0) + split_out_port_0_consuming_ports = split_out_port_0.get_consuming_ports() + assert len(split_out_port_0_consuming_ports) == 1 + assert split_out_port_0_consuming_ports[0].is_equal(abs_input_port) + + +def test_get_producing_operation(): + skip_if_onnx_frontend_is_disabled() + fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME) + assert fe + model = fe.load("input_model_2.onnx") + assert model + + split_tensor_out_2 = model.get_place_by_tensor_name(tensorName="sp_out2") + split_op = model.get_place_by_operation_name(operationName="split2") + assert split_tensor_out_2.get_producing_operation().is_equal(split_op) + + split_op = model.get_place_by_operation_name(operationName="split2") + split_out_port_2 = split_op.get_output_port(outputPortIndex=1) + assert split_out_port_2.get_producing_operation().is_equal(split_op) + + +def test_get_producing_operation_2(): + skip_if_onnx_frontend_is_disabled() + fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME) + assert fe + model = fe.load("input_model_2.onnx") + assert model + + abs_op = model.get_place_by_operation_name(operationName="abs1") + abs_port_0 = abs_op.get_input_port() + split_op = model.get_place_by_operation_name(operationName="split2") + assert abs_port_0.get_producing_operation().is_equal(split_op) + assert abs_op.get_producing_operation().is_equal(split_op) + + add_out_tensor = model.get_place_by_tensor_name(tensorName="add_out") + add_op = add_out_tensor.get_producing_operation() + assert not add_op.get_producing_operation() + + split_op_producing_op = split_op.get_producing_operation(inputName="add_out") + assert split_op_producing_op.is_equal(add_op) + + out2_tensor = model.get_place_by_tensor_name(tensorName="out2") + sin_op = out2_tensor.get_producing_operation() + assert sin_op.get_producing_operation(inputPortIndex=0).is_equal(split_op) + + +def test_get_consuming_operations(): + skip_if_onnx_frontend_is_disabled() + fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME) + assert fe + model = fe.load("input_model_2.onnx") + assert model + + split_op = model.get_place_by_operation_name(operationName="split2") + split_op_consuming_ops = split_op.get_consuming_operations() + abs_op = model.get_place_by_operation_name(operationName="abs1") + sin_op = model.get_place_by_tensor_name(tensorName="out2").get_producing_operation() + + assert len(split_op_consuming_ops) == 2 + assert split_op_consuming_ops[0].is_equal(abs_op) + assert split_op_consuming_ops[1].is_equal(sin_op) + + split_op_port = split_op.get_input_port(inputPortIndex=0) + split_op_port_consuming_ops = split_op_port.get_consuming_operations() + + assert len(split_op_port_consuming_ops) == 1 + assert split_op_port_consuming_ops[0].is_equal(split_op) + + add_out_port = model.get_place_by_tensor_name(tensorName="add_out").get_producing_port() + add_out_port_consuming_ops = add_out_port.get_consuming_operations() + assert len(add_out_port_consuming_ops) == 1 + assert add_out_port_consuming_ops[0].is_equal(split_op) + + sp_out2_tensor = model.get_place_by_tensor_name(tensorName="sp_out2") + sp_out2_tensor_consuming_ops = sp_out2_tensor.get_consuming_operations() + assert len(sp_out2_tensor_consuming_ops) == 1 + assert sp_out2_tensor_consuming_ops[0].is_equal(sin_op) + + out2_tensor = model.get_place_by_tensor_name(tensorName="out2") + out2_tensor_consuming_ops = out2_tensor.get_consuming_operations() + assert len(out2_tensor_consuming_ops) == 0 + out2_port_consuming_ops = out2_tensor.get_producing_port().get_consuming_operations() + assert len(out2_port_consuming_ops) == 0 + + split_out_1_consuming_ops = split_op.get_consuming_operations(outputPortIndex=1) + assert len(split_out_1_consuming_ops) == 1 + split_out_sp_out_2_consuming_ops = split_op.get_consuming_operations(outputName="sp_out2") + assert len(split_out_sp_out_2_consuming_ops) == 1 + assert split_out_1_consuming_ops[0].is_equal(split_out_sp_out_2_consuming_ops[0]) + assert split_out_1_consuming_ops[0].is_equal(sin_op) + + +def test_get_target_tensor(): + skip_if_onnx_frontend_is_disabled() + fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME) + assert fe + model = fe.load("input_model_2.onnx") + assert model + + split_op = model.get_place_by_operation_name(operationName="split2") + assert not split_op.get_target_tensor() + + split_op_tensor_1 = split_op.get_target_tensor(outputPortIndex=1) + sp_out2_tensor = model.get_place_by_tensor_name(tensorName="sp_out2") + assert split_op_tensor_1.is_equal(sp_out2_tensor) + + split_tensor_sp_out2 = split_op.get_target_tensor(outputName="sp_out2") + assert split_tensor_sp_out2.is_equal(split_op_tensor_1) + + abs_op = model.get_place_by_operation_name(operationName="abs1") + out1_tensor = model.get_place_by_tensor_name(tensorName="out1") + assert abs_op.get_target_tensor().is_equal(out1_tensor) + + +def test_get_source_tensor(): + skip_if_onnx_frontend_is_disabled() + fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME) + assert fe + model = fe.load("input_model_2.onnx") + assert model + + add_out_tensor = model.get_place_by_tensor_name(tensorName="add_out") + add_op = add_out_tensor.get_producing_operation() + assert not add_op.get_source_tensor() + + add_op_in_tensor_1 = add_op.get_source_tensor(inputPortIndex=1) + in2_tensor = model.get_place_by_tensor_name(tensorName="in2") + assert add_op_in_tensor_1.is_equal(in2_tensor) + + add_op_in_tensor_in2 = add_op.get_source_tensor(inputName="in2") + assert add_op_in_tensor_in2.is_equal(in2_tensor) + + split_op = model.get_place_by_operation_name(operationName="split2") + assert split_op.get_source_tensor().is_equal(add_out_tensor) + + +def test_get_producing_port(): + skip_if_onnx_frontend_is_disabled() + fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME) + assert fe + model = fe.load("input_model_2.onnx") + assert model + + split_op = model.get_place_by_operation_name(operationName="split2") + split_op_in_port = split_op.get_input_port() + split_op_in_port_prod_port = split_op_in_port.get_producing_port() + + add_out_tensor = model.get_place_by_tensor_name(tensorName="add_out") + add_op = add_out_tensor.get_producing_operation() + add_op_out_port = add_op.get_output_port() + + assert split_op_in_port_prod_port.is_equal(add_op_out_port) diff --git a/runtime/bindings/python/tests/test_frontend/test_frontendmanager.py b/runtime/bindings/python/tests/test_frontend/test_frontendmanager.py index 920746e200a..8ecece7404a 100644 --- a/runtime/bindings/python/tests/test_frontend/test_frontendmanager.py +++ b/runtime/bindings/python/tests/test_frontend/test_frontendmanager.py @@ -440,6 +440,16 @@ def test_place_get_consuming_operations(): stat = get_place_stat(place) assert stat.get_consuming_operations == 2 assert stat.lastArgInt == -1 + assert place.get_consuming_operations(outputName="2") is not None + stat = get_place_stat(place) + assert stat.get_consuming_operations == 3 + assert stat.lastArgInt == -1 + assert stat.lastArgString == "2" + assert place.get_consuming_operations(outputName="3", outputPortIndex=33) is not None + stat = get_place_stat(place) + assert stat.get_consuming_operations == 4 + assert stat.lastArgInt == 33 + assert stat.lastArgString == "3" @mock_needed @@ -453,6 +463,16 @@ def test_place_get_target_tensor(): stat = get_place_stat(place) assert stat.get_target_tensor == 2 assert stat.lastArgInt == -1 + assert place.get_target_tensor(outputName="2") is not None + stat = get_place_stat(place) + assert stat.get_target_tensor == 3 + assert stat.lastArgInt == -1 + assert stat.lastArgString == "2" + assert place.get_target_tensor(outputName="3", outputPortIndex=33) is not None + stat = get_place_stat(place) + assert stat.get_target_tensor == 4 + assert stat.lastArgInt == 33 + assert stat.lastArgString == "3" @mock_needed @@ -466,6 +486,16 @@ def test_place_get_producing_operation(): stat = get_place_stat(place) assert stat.get_producing_operation == 2 assert stat.lastArgInt == -1 + assert place.get_producing_operation(inputName="2") is not None + stat = get_place_stat(place) + assert stat.get_producing_operation == 3 + assert stat.lastArgInt == -1 + assert stat.lastArgString == "2" + assert place.get_producing_operation(inputName="3", inputPortIndex=33) is not None + stat = get_place_stat(place) + assert stat.get_producing_operation == 4 + assert stat.lastArgInt == 33 + assert stat.lastArgString == "3" @mock_needed @@ -551,3 +581,13 @@ def test_place_get_source_tensor(): stat = get_place_stat(place) assert stat.get_source_tensor == 2 assert stat.lastArgInt == 22 + assert place.get_source_tensor(inputName="2") is not None + stat = get_place_stat(place) + assert stat.get_source_tensor == 3 + assert stat.lastArgInt == -1 + assert stat.lastArgString == "2" + assert place.get_source_tensor(inputName="3", inputPortIndex=33) is not None + stat = get_place_stat(place) + assert stat.get_source_tensor == 4 + assert stat.lastArgInt == 33 + assert stat.lastArgString == "3"