From de4ceba375dab5d372f2f29c0efbe1bf550b21d0 Mon Sep 17 00:00:00 2001 From: Tomasz Jankowski Date: Tue, 2 Nov 2021 09:35:04 +0100 Subject: [PATCH] [ONNX FE] Add implementation of Frontend API methods for naming and annotation (#8026) --- .../core/include/openvino/core/function.hpp | 5 +- .../include/frontend_manager/input_model.hpp | 4 +- ngraph/frontend/onnx/frontend/src/editor.cpp | 122 ++++++++++++- ngraph/frontend/onnx/frontend/src/editor.hpp | 41 ++++- .../onnx/frontend/src/editor_types.hpp | 6 +- .../onnx/frontend/src/input_model.cpp | 33 +++- .../onnx/frontend/src/input_model.hpp | 13 +- ngraph/frontend/onnx/frontend/src/place.cpp | 18 +- ngraph/frontend/onnx/frontend/src/place.hpp | 7 +- .../pyngraph/frontend/inputmodel.cpp | 2 +- .../test_frontend_onnx_editor.py | 172 ++++++++++++++++++ 11 files changed, 402 insertions(+), 21 deletions(-) diff --git a/ngraph/core/include/openvino/core/function.hpp b/ngraph/core/include/openvino/core/function.hpp index 7741b3efa14..f8eed3450b1 100644 --- a/ngraph/core/include/openvino/core/function.hpp +++ b/ngraph/core/include/openvino/core/function.hpp @@ -282,11 +282,12 @@ public: return m_rt_info; } -private: Function(const Function&) = delete; - Function(const Function&&) = delete; + Function(Function&&) = delete; Function& operator=(const Function&) = delete; + Function& operator=(Function&&) = delete; +private: /// \brief Depending on the options selected, /// checks all the Parameter/Variables are registered in the list of Function /// parameters/variables or finds all Parameters/Variables in a function and registers them. diff --git a/ngraph/frontend/frontend_manager/include/frontend_manager/input_model.hpp b/ngraph/frontend/frontend_manager/include/frontend_manager/input_model.hpp index 40b329abecf..6e930e4ecfb 100644 --- a/ngraph/frontend/frontend_manager/include/frontend_manager/input_model.hpp +++ b/ngraph/frontend/frontend_manager/include/frontend_manager/input_model.hpp @@ -93,12 +93,12 @@ public: ///// Naming and annotation ///// /// \brief Sets name for tensor. Overwrites existing names of this place - /// \param operation Tensor place + /// \param tensor Tensor place /// \param new_name New name for this tensor virtual void set_name_for_tensor(Place::Ptr tensor, const std::string& new_name); /// \brief Adds new name for tensor - /// \param operation Tensor place + /// \param tensor Tensor place /// \param new_name New name to be added to this place virtual void add_name_for_tensor(Place::Ptr tensor, const std::string& new_name); diff --git a/ngraph/frontend/onnx/frontend/src/editor.cpp b/ngraph/frontend/onnx/frontend/src/editor.cpp index f3761ebd42c..641f43b144a 100644 --- a/ngraph/frontend/onnx/frontend/src/editor.cpp +++ b/ngraph/frontend/onnx/frontend/src/editor.cpp @@ -57,6 +57,16 @@ TensorProto* find_graph_initializer(GraphProto& graph, const std::string& name) return nullptr; } +ValueInfoProto* find_graph_value_info(GraphProto& graph, const std::string& name) { + for (int i = 0; i < graph.value_info_size(); ++i) { + auto value_info = graph.mutable_value_info(i); + if (value_info->name() == name) { + return value_info; + } + } + return nullptr; +} + void modify_input_type(ValueInfoProto& onnx_input, const element::Type_t elem_type) { if (!onnx_input.has_type()) { throw ngraph_error("The input is malformed - it doesn't contain the 'type' field. 
Cannot change the " @@ -272,12 +282,17 @@ void onnx_editor::ONNXModelEditor::set_input_shapes(const std::mapm_model_proto->mutable_graph(); + const TensorProto* tensor = nullptr; + const auto onnx_graph = m_pimpl->m_model_proto->mutable_graph(); InferShapesAutoRelease onnx_shapes(m_pimpl->m_model_proto); - if (const auto* input = find_graph_input(*onnx_graph, tensor_name)) { + if (const auto input = find_graph_input(*onnx_graph, tensor_name)) { value_info = input; - } else if (const auto* output = find_graph_output(*onnx_graph, tensor_name)) { + } else if (const auto output = find_graph_output(*onnx_graph, tensor_name)) { value_info = output; + } else if (const auto val_info = find_graph_value_info(*onnx_graph, tensor_name)) { + value_info = val_info; + } else if (const auto initializer = find_graph_initializer(*onnx_graph, tensor_name)) { + tensor = initializer; } else { try { onnx_shapes.infer_shapes(); @@ -301,6 +316,8 @@ PartialShape onnx_editor::ONNXModelEditor::get_tensor_shape(const std::string& t } else { return PartialShape::dynamic(); } + } else if (tensor) { + return PartialShape{Shape{tensor->dims().cbegin(), tensor->dims().cend()}}; } else { throw ngraph_error("The tensor: " + tensor_name + " was not found in the graph"); } @@ -412,6 +429,105 @@ void onnx_editor::ONNXModelEditor::set_input_values( } } +void onnx_editor::ONNXModelEditor::set_tensor_name(const std::string& current_name, const std::string& new_name) { + OPENVINO_ASSERT(!new_name.empty(), "New name must not be empty."); + + const auto graph = m_pimpl->m_model_proto->mutable_graph(); + + OPENVINO_ASSERT(!(find_graph_input(*graph, new_name) || find_graph_output(*graph, new_name) || + find_graph_initializer(*graph, new_name) || find_graph_value_info(*graph, new_name) || + m_pimpl->m_edge_mapper.is_correct_tensor_name(new_name)), + "The name '", + new_name, + "' is already used by another tensor."); + + m_pimpl->m_is_mapper_updated = false; + + // the same tensor can be multiplied in any or all of below arrays + if (const auto initializer = find_graph_initializer(*graph, current_name)) + *initializer->mutable_name() = new_name; + if (const auto input = find_graph_input(*graph, current_name)) + *input->mutable_name() = new_name; + if (const auto output = find_graph_output(*graph, current_name)) + *output->mutable_name() = new_name; + if (const auto value_info = find_graph_value_info(*graph, current_name)) + *value_info->mutable_name() = new_name; + + for (size_t i = 0; i < graph->node().size(); ++i) { + const auto node = graph->mutable_node(i); + + bool output_found = false; + for (size_t j = 0; j < node->output().size(); ++j) + if (node->output(j) == current_name) { + *node->mutable_output(j) = new_name; + output_found = true; + break; + } + if (output_found) + continue; + + for (size_t j = 0; j < node->input().size(); ++j) + if (node->input(j) == current_name) + *node->mutable_input(j) = new_name; + } +} + +void onnx_editor::ONNXModelEditor::set_node_name(const EditorNode& node, const std::string& new_name) { + const auto node_idx = m_pimpl->m_edge_mapper.get_node_index(node); + const auto graph = m_pimpl->m_model_proto->mutable_graph(); + + m_pimpl->m_is_mapper_updated = false; + + *graph->mutable_node(node_idx)->mutable_name() = new_name; +} + +void onnx_editor::ONNXModelEditor::clear_nodes_name(const std::string& name) { + const auto graph = m_pimpl->m_model_proto->mutable_graph(); + + m_pimpl->m_is_mapper_updated = false; + + for (size_t i = 0; i < graph->node().size(); ++i) { + const auto node = 
graph->mutable_node(i); + if (node->has_name() && node->name() == name) + node->clear_name(); + } +} + +void onnx_editor::ONNXModelEditor::set_name_for_dimension(const std::string& node_name, + size_t shape_dim_index, + const std::string& dim_name) { + OPENVINO_ASSERT(!dim_name.empty(), "Dimension name must not be empty."); + + const auto graph = m_pimpl->m_model_proto->mutable_graph(); + + OPENVINO_ASSERT(!find_graph_initializer(*graph, node_name), "ONNX initializer shape dimension cannot be dynamic."); + + // the same tensor can be multiplied in any or all of below arrays + const auto input = find_graph_input(*graph, node_name); + const auto output = find_graph_output(*graph, node_name); + const auto value_info = find_graph_value_info(*graph, node_name); + OPENVINO_ASSERT(input || output || value_info, "There is no tensor named '", node_name, "' in the graph."); + + const auto set_dim_param = [&shape_dim_index, &dim_name](ValueInfoProto* tensor) { + const auto shape = tensor->mutable_type()->mutable_tensor_type()->mutable_shape(); + auto shape_dim_size = shape->dim_size(); + + for (; shape_dim_size <= shape_dim_index; ++shape_dim_size) + add_dim_to_onnx_shape(Dimension::dynamic(), *shape); + + shape->mutable_dim(shape_dim_index)->set_dim_param(dim_name.c_str()); + }; + + m_pimpl->m_is_mapper_updated = false; + + if (input) + set_dim_param(input); + if (output) + set_dim_param(output); + if (value_info) + set_dim_param(value_info); +} + void onnx_editor::ONNXModelEditor::update_mapper_if_needed() const { if (!m_pimpl->m_is_mapper_updated) { m_pimpl->m_edge_mapper = EdgeMapper(m_pimpl->m_model_proto->graph()); diff --git a/ngraph/frontend/onnx/frontend/src/editor.hpp b/ngraph/frontend/onnx/frontend/src/editor.hpp index a1ba6bf810e..c288fa4ccd0 100644 --- a/ngraph/frontend/onnx/frontend/src/editor.hpp +++ b/ngraph/frontend/onnx/frontend/src/editor.hpp @@ -24,8 +24,6 @@ namespace onnx_editor { /// model's input types and shapes, extract a subgraph and more. class ONNX_IMPORTER_API ONNXModelEditor final { public: - ONNXModelEditor() = delete; - /// \brief Creates an editor from a model file located on a storage device. The file /// is parsed and loaded into the m_model_proto member variable. /// @@ -71,7 +69,7 @@ public: /// the underlying ModelProto is modified - obsolete inputs, initializers, nodes /// and outputs are removed from the in-memory model. /// - /// \node Please look at the declaration of InputEdge and OutputEdge for explanation + /// \note Please look at the declaration of InputEdge and OutputEdge for explanation /// how those objects can be created. If the outputs parameter is empty /// this method keeps all of the original outputs of the model. /// @@ -92,6 +90,41 @@ public: /// overwritten. void set_input_values(const std::map>& input_values); + /// \brief Changes the name of given tensor. + /// + /// \note It changes input, output, initializer and value_info proto repeated fields as well as + /// all nodes which refer to the tensor. + /// + /// \param current_name Name of tensor to be changed. + /// \param new_name New name of tensor. Must not be empty nor point to existing tensor (including self). + void set_tensor_name(const std::string& current_name, const std::string& new_name); + + /// \brief Sets node's name. + /// + /// \note Empty name is accepted. + /// + /// \param node Handle to node. + /// \param new_name New name of the node. 
+ void set_node_name(const EditorNode& node, const std::string& new_name); + + /// \brief Removes node name for all nodes with given name. + /// + /// \note Empty and not present names are accepted. + /// + /// \param name Name to clear + void clear_nodes_name(const std::string& name); + + /// \brief Overrides or creates name for tensor shape dimension (numeric dimension is erased). + /// + /// \note It changes input, output and value_info proto repeated fields. + /// If rank of the tensor is too low the shape is expanded with dynamic dimensions so + /// the name can be set at specified position. + /// + /// \param node_name Tensor name to change its shape. Must not point to initializer. + /// \param shape_dim_index Index of dimension to change. + /// \param dim_name New name of the dimension. Must not be empty. + void set_name_for_dimension(const std::string& node_name, size_t shape_dim_index, const std::string& dim_name); + /// \brief Returns a serialized ONNX model, possibly modified by the editor. std::string model_string() const; @@ -126,7 +159,7 @@ public: /// std::string get_target_tensor_name(const OutputEdge& edge) const; - /// \brief Returns true if output edge is input of the model. Otherwise false. + /// \brief Returns true if output edge is output of the model. Otherwise false. bool is_output(const OutputEdge& edge) const; /// \brief Returns the path to the original model file diff --git a/ngraph/frontend/onnx/frontend/src/editor_types.hpp b/ngraph/frontend/onnx/frontend/src/editor_types.hpp index febbbb01449..c9c7e02408f 100644 --- a/ngraph/frontend/onnx/frontend/src/editor_types.hpp +++ b/ngraph/frontend/onnx/frontend/src/editor_types.hpp @@ -119,9 +119,9 @@ struct EditorNode { EditorNode(std::string node_name) : m_node_name{std::move(node_name)} {} EditorNode(EditorOutput output) : m_output_name{std::move(output.m_output_name)} {} EditorNode(const int node_index) : m_node_index{node_index} {} - const std::string m_node_name = ""; - const std::string m_output_name = ""; - const int m_node_index = -1; + std::string m_node_name = ""; + std::string m_output_name = ""; + int m_node_index = -1; }; } // namespace onnx_editor } // namespace ngraph diff --git a/ngraph/frontend/onnx/frontend/src/input_model.cpp b/ngraph/frontend/onnx/frontend/src/input_model.cpp index ac136915bcf..b61e8fe19c6 100644 --- a/ngraph/frontend/onnx/frontend/src/input_model.cpp +++ b/ngraph/frontend/onnx/frontend/src/input_model.cpp @@ -60,7 +60,8 @@ Place::Ptr InputModelONNX::get_place_by_tensor_name(const std::string& tensor_na Place::Ptr InputModelONNX::get_place_by_operation_name(const std::string& operation_name) const { if (m_editor->is_correct_and_unambiguous_node(operation_name)) { - return std::make_shared(onnx_editor::EditorNode{operation_name}, m_editor); + const auto node_index = m_editor->get_node_index(onnx_editor::EditorNode{operation_name}); + return std::make_shared(onnx_editor::EditorNode{node_index}, m_editor); } return nullptr; } @@ -83,6 +84,36 @@ Place::Ptr InputModelONNX::get_place_by_operation_name_and_output_port(const std return nullptr; } +void InputModelONNX::set_name_for_tensor(Place::Ptr tensor, const std::string& new_name) { + const auto onnx_tensor = std::dynamic_pointer_cast(tensor); + FRONT_END_GENERAL_CHECK(onnx_tensor, __FUNCTION__, " expects a pointer to place of ONNX tensor type."); + onnx_tensor->set_name(new_name); +} + +void InputModelONNX::set_name_for_operation(Place::Ptr operation, const std::string& new_name) { + const auto onnx_operation = 
std::dynamic_pointer_cast(operation); + FRONT_END_GENERAL_CHECK(onnx_operation, __FUNCTION__, " expects a pointer to place of ONNX operation type."); + onnx_operation->set_name(new_name); +} + +void InputModelONNX::free_name_for_operation(const std::string& name) { + m_editor->clear_nodes_name(name); +} + +void InputModelONNX::set_name_for_dimension(Place::Ptr tensor, size_t shape_dim_index, const std::string& dim_name) { + const auto onnx_tensor = std::dynamic_pointer_cast(tensor); + FRONT_END_GENERAL_CHECK(onnx_tensor, __FUNCTION__, " expects a pointer to place of ONNX tensor type."); + onnx_tensor->set_name_for_dimension(shape_dim_index, dim_name); +} + +void InputModelONNX::add_name_for_tensor(Place::Ptr, const std::string&) { + FRONT_END_THROW("Method add_name_for_tensor is not applicable for ONNX model. ONNX tensor has just one name."); +} + +void InputModelONNX::free_name_for_tensor(const std::string&) { + FRONT_END_THROW("Method free_name_for_tensor is not applicable for ONNX model. ONNX tensor name is an identifier."); +} + void InputModelONNX::set_partial_shape(Place::Ptr place, const ngraph::PartialShape& shape) { std::map m; m[place->get_names()[0]] = shape; diff --git a/ngraph/frontend/onnx/frontend/src/input_model.hpp b/ngraph/frontend/onnx/frontend/src/input_model.hpp index 4d54ab95167..59bf6cb2e53 100644 --- a/ngraph/frontend/onnx/frontend/src/input_model.hpp +++ b/ngraph/frontend/onnx/frontend/src/input_model.hpp @@ -30,6 +30,17 @@ public: int input_port_index) override; Place::Ptr get_place_by_operation_name_and_output_port(const std::string& operation_name, int output_port_index) override; + + void set_name_for_tensor(Place::Ptr tensor, const std::string& new_name) override; + void set_name_for_operation(Place::Ptr operation, const std::string& new_name) override; + void free_name_for_operation(const std::string& name) override; + void set_name_for_dimension(Place::Ptr place, size_t shape_dim_index, const std::string& dim_name) override; + + /// \brief Not applicable for ONNX model. Throws immediately + void add_name_for_tensor(Place::Ptr tensor, const std::string& new_name) override; + /// \brief Not applicable for ONNX model. 
Throws immediately + void free_name_for_tensor(const std::string& name) override; + void set_partial_shape(Place::Ptr place, const ngraph::PartialShape& shape) override; ngraph::PartialShape get_partial_shape(Place::Ptr place) const override; void set_element_type(Place::Ptr place, const ngraph::element::Type& type) override; @@ -45,7 +56,5 @@ public: private: std::shared_ptr m_editor; }; - } // namespace frontend - } // namespace ngraph diff --git a/ngraph/frontend/onnx/frontend/src/place.cpp b/ngraph/frontend/onnx/frontend/src/place.cpp index 15a6bb74172..c8f20e65add 100644 --- a/ngraph/frontend/onnx/frontend/src/place.cpp +++ b/ngraph/frontend/onnx/frontend/src/place.cpp @@ -181,6 +181,17 @@ std::vector PlaceTensorONNX::get_consuming_operations() const { return consuming_ops; } +void PlaceTensorONNX::set_name(const std::string& new_name) { + if (m_name == new_name) + return; + m_editor->set_tensor_name(m_name, new_name); + m_name = new_name; +} + +void PlaceTensorONNX::set_name_for_dimension(size_t shape_dim_index, const std::string& dim_name) { + m_editor->set_name_for_dimension(m_name, shape_dim_index, dim_name); +} + PlaceOpONNX::PlaceOpONNX(const onnx_editor::EditorNode& node, std::shared_ptr editor) : m_node{node}, m_editor{std::move(editor)} {} @@ -193,7 +204,7 @@ std::vector PlaceOpONNX::get_names() const { return {m_node.m_node_name}; } -onnx_editor::EditorNode PlaceOpONNX::get_editor_node() const { +const onnx_editor::EditorNode& PlaceOpONNX::get_editor_node() const { return m_node; } @@ -378,3 +389,8 @@ bool PlaceOpONNX::is_input() const { bool PlaceOpONNX::is_output() const { return false; } + +void PlaceOpONNX::set_name(const std::string& new_name) { + m_editor->set_node_name(m_node, new_name); + m_node.m_node_name = new_name; +} diff --git a/ngraph/frontend/onnx/frontend/src/place.hpp b/ngraph/frontend/onnx/frontend/src/place.hpp index 82d961998a4..a9d7b2a78dd 100644 --- a/ngraph/frontend/onnx/frontend/src/place.hpp +++ b/ngraph/frontend/onnx/frontend/src/place.hpp @@ -73,6 +73,9 @@ public: bool is_equal_data(Place::Ptr another) const override; std::vector get_consuming_operations() const override; + void set_name(const std::string& new_name); + void set_name_for_dimension(size_t shape_dim_index, const std::string& dim_name); + private: std::string m_name; std::shared_ptr m_editor; @@ -85,7 +88,8 @@ public: std::vector get_names() const override; // internal usage - onnx_editor::EditorNode get_editor_node() const; + const onnx_editor::EditorNode& get_editor_node() const; + void set_name(const std::string& new_name); // external usage Place::Ptr get_output_port() const override; @@ -122,5 +126,4 @@ private: std::shared_ptr m_editor; }; } // namespace frontend - } // namespace ngraph diff --git a/runtime/bindings/python/src/compatibility/pyngraph/frontend/inputmodel.cpp b/runtime/bindings/python/src/compatibility/pyngraph/frontend/inputmodel.cpp index a8cb5f30557..94e5628a917 100644 --- a/runtime/bindings/python/src/compatibility/pyngraph/frontend/inputmodel.cpp +++ b/runtime/bindings/python/src/compatibility/pyngraph/frontend/inputmodel.cpp @@ -181,7 +181,7 @@ void regclass_pyngraph_InputModel(py::module m) { place : Place Model's place. - shapeDimIndex : int + dimIndex : int Dimension index. 
dimName : str diff --git a/runtime/bindings/python/tests_compatibility/test_frontend/test_frontend_onnx_editor.py b/runtime/bindings/python/tests_compatibility/test_frontend/test_frontend_onnx_editor.py index cd8bdabd3e5..a52260909bf 100644 --- a/runtime/bindings/python/tests_compatibility/test_frontend/test_frontend_onnx_editor.py +++ b/runtime/bindings/python/tests_compatibility/test_frontend/test_frontend_onnx_editor.py @@ -224,6 +224,33 @@ def create_test_onnx_models(): models["test_partial_shape.onnx"] = make_model(graph, producer_name="ONNX Importer", opset_imports=[onnx.helper.make_opsetid("", 13)]) + # test place names model + add = onnx.helper.make_node("Add", inputs=["in1", "in2"], outputs=["add_out"]) + sub = onnx.helper.make_node("Sub", inputs=["in1", "in2"], outputs=["sub_out"]) + split = onnx.helper.make_node("Split", inputs=["add_out"], outputs=["out1", "out2"], + name="split1", axis=0) + mul = onnx.helper.make_node("Mul", inputs=["one_const", "sub_out"], outputs=["out3"]) + + input_tensors = [ + make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)), + make_tensor_value_info("in2", onnx.TensorProto.FLOAT, (2, 2)), + ] + output_tensors = [ + make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (1, 2)), + make_tensor_value_info("out2", onnx.TensorProto.FLOAT, (1, 2)), + make_tensor_value_info("out3", onnx.TensorProto.FLOAT, (2, 2)), + ] + value_infos = [ + make_tensor_value_info("sub_out", onnx.TensorProto.FLOAT, (2, 2)), + ] + initializers = [ + onnx.helper.make_tensor("one_const", 1, [1], [1]) + ] + graph = make_graph([add, sub, split, mul], "test_graph", input_tensors, output_tensors, + value_info=value_infos, initializer=initializers) + models["test_place_names.onnx"] = make_model(graph, producer_name="ONNX Importer", + opset_imports=[onnx.helper.make_opsetid("", 13)]) + return models @@ -1033,3 +1060,148 @@ def test_get_place_by_operation_name_and_output_port(): place2 = sp_out1_tensor.get_producing_operation().get_output_port(outputPortIndex=0) assert place1.is_equal(place2) + + +def test_not_supported_methods(): + skip_if_onnx_frontend_is_disabled() + fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME) + model = fe.load("test_place_names.onnx") + tensor = model.get_place_by_tensor_name(tensorName="add_out") + + with pytest.raises(Exception) as e: + model.add_name_for_tensor(tensor=tensor, newName="new_name") + assert "not applicable for ONNX model" in str(e) + + with pytest.raises(Exception) as e: + model.free_name_for_tensor("add_out") + assert "not applicable for ONNX model" in str(e) + + +def test_set_name_for_tensor(): + skip_if_onnx_frontend_is_disabled() + fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME) + model = fe.load("test_place_names.onnx") + old_name = "add_out" + new_name = "add_out_new" + + tensor = model.get_place_by_tensor_name(tensorName=old_name) + + # ignore rename to own name (expect no exception) + model.set_name_for_tensor(tensor=tensor, newName=old_name) + + with pytest.raises(Exception) as e: + model.set_name_for_tensor(tensor=tensor, newName="") + assert "name must not be empty" in str(e) + + # ONNX model stores tensor info separately for inputs, outputs and between nodes tensors + with pytest.raises(Exception) as e: + model.set_name_for_tensor(tensor=tensor, newName="in1") + assert "already used by another tensor" in str(e) + with pytest.raises(Exception) as e: + model.set_name_for_tensor(tensor=tensor, newName="out1") + assert "already used by another tensor" in str(e) + with pytest.raises(Exception) as e: + 
model.set_name_for_tensor(tensor=tensor, newName="sub_out") + assert "already used by another tensor" in str(e) + + # actual rename + model.set_name_for_tensor(tensor=tensor, newName=new_name) + + new_tensor = model.get_place_by_tensor_name(tensorName=new_name) + assert new_tensor + assert new_tensor.is_equal(tensor) # previous Place object holds the handle + + old_tensor = model.get_place_by_tensor_name(tensorName=old_name) + assert old_tensor is None + + +def test_set_name_for_operation_with_name(): + skip_if_onnx_frontend_is_disabled() + fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME) + model = fe.load("test_place_names.onnx") + old_name = "split1" + new_name = "split1_new" + + operation = model.get_place_by_operation_name(operationName=old_name) + + # ignore rename to own name (expect no exception) + model.set_name_for_operation(operation=operation, newName=old_name) + + # actual rename + model.set_name_for_operation(operation=operation, newName=new_name) + + new_operation = model.get_place_by_operation_name(operationName=new_name) + assert new_operation + assert new_operation.is_equal(operation) # previous Place object holds the handle + + # Below test passes for models with unique operation names, what is not required by ONNX standard + # If there were more that one nodes with "split1" name, this test would fail. + old_operation = model.get_place_by_operation_name(operationName=old_name) + assert old_operation is None + + +def test_set_name_for_operation_without_name(): + skip_if_onnx_frontend_is_disabled() + fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME) + model = fe.load("test_place_names.onnx") + output_name = "add_out" + new_name = "Add_new" + + operation = model.get_place_by_tensor_name(tensorName=output_name).get_producing_operation() + # assure the test is performed on node with empty name + assert not operation.get_names() or len(operation.get_names()) == 0 or not operation.get_names()[0] + + # actual rename + model.set_name_for_operation(operation=operation, newName=new_name) + + new_operation = model.get_place_by_tensor_name(tensorName=output_name).get_producing_operation() + assert new_operation + assert new_operation.is_equal(operation) # previous Place object holds the handle + + +def test_free_name_for_operation(): + skip_if_onnx_frontend_is_disabled() + fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME) + model = fe.load("test_place_names.onnx") + name = "split1" + + # assure non existent names are ignored (expect no exception) + model.free_name_for_operation("non existent name") + + split1 = model.get_place_by_operation_name(operationName=name) + assert split1 + model.free_name_for_operation(name) + operation = model.get_place_by_operation_name(operationName=name) + assert not operation + + new_split1 = model.get_place_by_tensor_name(tensorName="out1").get_producing_operation() + assert split1.is_equal(new_split1) + + +def test_set_name_for_dimension(): + skip_if_onnx_frontend_is_disabled() + fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME) + model = fe.load("test_place_names.onnx") + dim_name = "batch_size" + + input1 = model.get_place_by_tensor_name(tensorName="in1") + model.set_name_for_dimension(input1, 0, dim_name) + assert model.get_partial_shape(input1) == PartialShape([-1, 2]) + + output1 = model.get_place_by_tensor_name(tensorName="out1") + model.set_name_for_dimension(output1, 1, dim_name) + assert model.get_partial_shape(output1) == PartialShape([1, -1]) + + # sub_output rank is 2 so setting dim_name at index 3 extends its 
rank to 4 + sub_output = model.get_place_by_tensor_name(tensorName="sub_out") + model.set_name_for_dimension(sub_output, 3, dim_name) + assert model.get_partial_shape(sub_output) == PartialShape([2, 2, -1, -1]) + + with pytest.raises(Exception) as e: + model.set_name_for_dimension(input1, 0, "") + assert "name must not be empty" in str(e) + + one_const = model.get_place_by_tensor_name(tensorName="one_const") + with pytest.raises(Exception) as e: + model.set_name_for_dimension(one_const, 0, dim_name) + assert "ONNX initializer shape dimension cannot be dynamic." in str(e)
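
Taken together, the new InputModel methods cover the typical pre-conversion renaming workflow (rename a tensor, rename or clear a node name, name a shape dimension). Below is a minimal usage sketch in the style of the tests above; it is illustrative only and not part of the patch, and it assumes the same test-module fixtures (`fem`, `ONNX_FRONTEND_NAME`, `PartialShape`, and the generated "test_place_names.onnx" model from create_test_onnx_models()).

# Illustrative sketch only (not part of the patch). Assumes the test module's
# fixtures: `fem` (FrontEndManager), ONNX_FRONTEND_NAME, PartialShape and the
# generated "test_place_names.onnx" model created by create_test_onnx_models().
def demo_naming_and_annotation_workflow():
    fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
    model = fe.load("test_place_names.onnx")

    # Rename an intermediate tensor; ONNXModelEditor::set_tensor_name updates the
    # graph inputs/outputs/initializers/value_info and every node that refers to it.
    add_out = model.get_place_by_tensor_name(tensorName="add_out")
    model.set_name_for_tensor(tensor=add_out, newName="add_result")

    # Rename a node, then clear its name again (ONNX node names are optional,
    # so both empty names and name removal are allowed).
    split_node = model.get_place_by_operation_name(operationName="split1")
    model.set_name_for_operation(operation=split_node, newName="split_main")
    model.free_name_for_operation("split_main")

    # Give a symbolic name to the first dimension of an input; the numeric
    # dimension is replaced by a dynamic, named one.
    in1 = model.get_place_by_tensor_name(tensorName="in1")
    model.set_name_for_dimension(in1, 0, "batch_size")
    assert model.get_partial_shape(in1) == PartialShape([-1, 2])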