[ONNX FE] Add implementation of Frontend API methods for naming and annotation (#8026)

This commit is contained in:
Tomasz Jankowski 2021-11-02 09:35:04 +01:00 committed by GitHub
parent 75c6af7af5
commit de4ceba375
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 402 additions and 21 deletions

View File

@ -282,11 +282,12 @@ public:
return m_rt_info; return m_rt_info;
} }
private:
Function(const Function&) = delete; Function(const Function&) = delete;
Function(const Function&&) = delete; Function(Function&&) = delete;
Function& operator=(const Function&) = delete; Function& operator=(const Function&) = delete;
Function& operator=(Function&&) = delete;
private:
/// \brief Depending on the options selected, /// \brief Depending on the options selected,
/// checks all the Parameter/Variables are registered in the list of Function /// checks all the Parameter/Variables are registered in the list of Function
/// parameters/variables or finds all Parameters/Variables in a function and registers them. /// parameters/variables or finds all Parameters/Variables in a function and registers them.

View File

@ -93,12 +93,12 @@ public:
///// Naming and annotation ///// ///// Naming and annotation /////
/// \brief Sets name for tensor. Overwrites existing names of this place /// \brief Sets name for tensor. Overwrites existing names of this place
/// \param operation Tensor place /// \param tensor Tensor place
/// \param new_name New name for this tensor /// \param new_name New name for this tensor
virtual void set_name_for_tensor(Place::Ptr tensor, const std::string& new_name); virtual void set_name_for_tensor(Place::Ptr tensor, const std::string& new_name);
/// \brief Adds new name for tensor /// \brief Adds new name for tensor
/// \param operation Tensor place /// \param tensor Tensor place
/// \param new_name New name to be added to this place /// \param new_name New name to be added to this place
virtual void add_name_for_tensor(Place::Ptr tensor, const std::string& new_name); virtual void add_name_for_tensor(Place::Ptr tensor, const std::string& new_name);

View File

@ -57,6 +57,16 @@ TensorProto* find_graph_initializer(GraphProto& graph, const std::string& name)
return nullptr; return nullptr;
} }
/// \brief Looks up an entry of the graph's value_info list by tensor name.
/// \param graph ONNX graph to search (mutable access is required by the callers).
/// \param name  Tensor name to match.
/// \return Pointer to the matching ValueInfoProto, or nullptr when absent.
ValueInfoProto* find_graph_value_info(GraphProto& graph, const std::string& name) {
    for (auto& value_info : *graph.mutable_value_info()) {
        if (value_info.name() == name) {
            return &value_info;
        }
    }
    return nullptr;
}
void modify_input_type(ValueInfoProto& onnx_input, const element::Type_t elem_type) { void modify_input_type(ValueInfoProto& onnx_input, const element::Type_t elem_type) {
if (!onnx_input.has_type()) { if (!onnx_input.has_type()) {
throw ngraph_error("The input is malformed - it doesn't contain the 'type' field. Cannot change the " throw ngraph_error("The input is malformed - it doesn't contain the 'type' field. Cannot change the "
@ -272,12 +282,17 @@ void onnx_editor::ONNXModelEditor::set_input_shapes(const std::map<std::string,
PartialShape onnx_editor::ONNXModelEditor::get_tensor_shape(const std::string& tensor_name) const { PartialShape onnx_editor::ONNXModelEditor::get_tensor_shape(const std::string& tensor_name) const {
const ValueInfoProto* value_info = nullptr; const ValueInfoProto* value_info = nullptr;
auto* onnx_graph = m_pimpl->m_model_proto->mutable_graph(); const TensorProto* tensor = nullptr;
const auto onnx_graph = m_pimpl->m_model_proto->mutable_graph();
InferShapesAutoRelease onnx_shapes(m_pimpl->m_model_proto); InferShapesAutoRelease onnx_shapes(m_pimpl->m_model_proto);
if (const auto* input = find_graph_input(*onnx_graph, tensor_name)) { if (const auto input = find_graph_input(*onnx_graph, tensor_name)) {
value_info = input; value_info = input;
} else if (const auto* output = find_graph_output(*onnx_graph, tensor_name)) { } else if (const auto output = find_graph_output(*onnx_graph, tensor_name)) {
value_info = output; value_info = output;
} else if (const auto val_info = find_graph_value_info(*onnx_graph, tensor_name)) {
value_info = val_info;
} else if (const auto initializer = find_graph_initializer(*onnx_graph, tensor_name)) {
tensor = initializer;
} else { } else {
try { try {
onnx_shapes.infer_shapes(); onnx_shapes.infer_shapes();
@ -301,6 +316,8 @@ PartialShape onnx_editor::ONNXModelEditor::get_tensor_shape(const std::string& t
} else { } else {
return PartialShape::dynamic(); return PartialShape::dynamic();
} }
} else if (tensor) {
return PartialShape{Shape{tensor->dims().cbegin(), tensor->dims().cend()}};
} else { } else {
throw ngraph_error("The tensor: " + tensor_name + " was not found in the graph"); throw ngraph_error("The tensor: " + tensor_name + " was not found in the graph");
} }
@ -412,6 +429,105 @@ void onnx_editor::ONNXModelEditor::set_input_values(
} }
} }
/// \brief Renames the tensor 'current_name' to 'new_name' across the whole graph.
///
/// Updates the initializer, input, output and value_info entries as well as every
/// node input/output that refers to the tensor.
///
/// \throws if new_name is empty or already used by another tensor.
void onnx_editor::ONNXModelEditor::set_tensor_name(const std::string& current_name, const std::string& new_name) {
    OPENVINO_ASSERT(!new_name.empty(), "New name must not be empty.");
    const auto graph = m_pimpl->m_model_proto->mutable_graph();
    // Reject a name that collides with any existing tensor: inputs, outputs,
    // initializers, value_info entries or tensors known to the edge mapper.
    OPENVINO_ASSERT(!(find_graph_input(*graph, new_name) || find_graph_output(*graph, new_name) ||
                      find_graph_initializer(*graph, new_name) || find_graph_value_info(*graph, new_name) ||
                      m_pimpl->m_edge_mapper.is_correct_tensor_name(new_name)),
                    "The name '",
                    new_name,
                    "' is already used by another tensor.");
    // invalidate the cached edge mapper - it is rebuilt lazily from the graph
    m_pimpl->m_is_mapper_updated = false;
    // the same tensor may be duplicated in any or all of the arrays below
    if (const auto initializer = find_graph_initializer(*graph, current_name))
        *initializer->mutable_name() = new_name;
    if (const auto input = find_graph_input(*graph, current_name))
        *input->mutable_name() = new_name;
    if (const auto output = find_graph_output(*graph, current_name))
        *output->mutable_name() = new_name;
    if (const auto value_info = find_graph_value_info(*graph, current_name))
        *value_info->mutable_name() = new_name;
    // use protobuf's typed *_size() accessors and int indices to avoid
    // signed/unsigned comparisons (repeated field sizes are int)
    for (int i = 0; i < graph->node_size(); ++i) {
        const auto node = graph->mutable_node(i);
        // once the producing output is renamed, this node's inputs are not
        // inspected (mirrors the original single-producer assumption)
        bool output_found = false;
        for (int j = 0; j < node->output_size(); ++j)
            if (node->output(j) == current_name) {
                *node->mutable_output(j) = new_name;
                output_found = true;
                break;
            }
        if (output_found)
            continue;
        for (int j = 0; j < node->input_size(); ++j)
            if (node->input(j) == current_name)
                *node->mutable_input(j) = new_name;
    }
}
/// \brief Sets the name of the node pointed to by the EditorNode handle.
///
/// \note An empty new_name is accepted - ONNX does not require node names.
void onnx_editor::ONNXModelEditor::set_node_name(const EditorNode& node, const std::string& new_name) {
    // resolve the handle to a node index first - get_node_index may validate the handle
    const auto node_idx = m_pimpl->m_edge_mapper.get_node_index(node);
    const auto graph = m_pimpl->m_model_proto->mutable_graph();
    // invalidate the cached edge mapper - it is rebuilt lazily from the graph
    m_pimpl->m_is_mapper_updated = false;
    *graph->mutable_node(node_idx)->mutable_name() = new_name;
}
/// \brief Removes the name from every node whose name equals 'name'.
///
/// \note Node names are not unique in ONNX, so all matching nodes are cleared.
///       Names that are absent from the graph are silently ignored.
void onnx_editor::ONNXModelEditor::clear_nodes_name(const std::string& name) {
    const auto graph = m_pimpl->m_model_proto->mutable_graph();
    // invalidate the cached edge mapper - it is rebuilt lazily from the graph
    m_pimpl->m_is_mapper_updated = false;
    // node_size() returns int; use an int index to avoid a signed/unsigned comparison
    for (int i = 0; i < graph->node_size(); ++i) {
        const auto node = graph->mutable_node(i);
        if (node->has_name() && node->name() == name)
            node->clear_name();
    }
}
/// \brief Assigns a symbolic name to one dimension of a tensor's shape
///        (the numeric value of that dimension is replaced).
///
/// If the current rank is smaller than shape_dim_index, the shape is first
/// padded with dynamic dimensions so the index exists.
///
/// \param node_name       Name of the tensor to modify. Must not be an initializer.
/// \param shape_dim_index Zero-based index of the dimension to name.
/// \param dim_name        New symbolic name. Must not be empty.
void onnx_editor::ONNXModelEditor::set_name_for_dimension(const std::string& node_name,
                                                          size_t shape_dim_index,
                                                          const std::string& dim_name) {
    OPENVINO_ASSERT(!dim_name.empty(), "Dimension name must not be empty.");
    const auto graph = m_pimpl->m_model_proto->mutable_graph();
    OPENVINO_ASSERT(!find_graph_initializer(*graph, node_name), "ONNX initializer shape dimension cannot be dynamic.");
    // the same tensor may be duplicated in any or all of the arrays below
    const auto input = find_graph_input(*graph, node_name);
    const auto output = find_graph_output(*graph, node_name);
    const auto value_info = find_graph_value_info(*graph, node_name);
    OPENVINO_ASSERT(input || output || value_info, "There is no tensor named '", node_name, "' in the graph.");
    const auto set_dim_param = [&shape_dim_index, &dim_name](ValueInfoProto* tensor) {
        const auto shape = tensor->mutable_type()->mutable_tensor_type()->mutable_shape();
        // pad with dynamic dimensions until the requested index is valid;
        // dim_size() returns int, so cast once to keep the comparison unsigned
        for (auto dims = static_cast<size_t>(shape->dim_size()); dims <= shape_dim_index; ++dims)
            add_dim_to_onnx_shape(Dimension::dynamic(), *shape);
        // set_dim_param accepts const std::string& directly - no c_str() needed
        shape->mutable_dim(static_cast<int>(shape_dim_index))->set_dim_param(dim_name);
    };
    m_pimpl->m_is_mapper_updated = false;
    if (input)
        set_dim_param(input);
    if (output)
        set_dim_param(output);
    if (value_info)
        set_dim_param(value_info);
}
void onnx_editor::ONNXModelEditor::update_mapper_if_needed() const { void onnx_editor::ONNXModelEditor::update_mapper_if_needed() const {
if (!m_pimpl->m_is_mapper_updated) { if (!m_pimpl->m_is_mapper_updated) {
m_pimpl->m_edge_mapper = EdgeMapper(m_pimpl->m_model_proto->graph()); m_pimpl->m_edge_mapper = EdgeMapper(m_pimpl->m_model_proto->graph());

View File

@ -24,8 +24,6 @@ namespace onnx_editor {
/// model's input types and shapes, extract a subgraph and more. /// model's input types and shapes, extract a subgraph and more.
class ONNX_IMPORTER_API ONNXModelEditor final { class ONNX_IMPORTER_API ONNXModelEditor final {
public: public:
ONNXModelEditor() = delete;
/// \brief Creates an editor from a model file located on a storage device. The file /// \brief Creates an editor from a model file located on a storage device. The file
/// is parsed and loaded into the m_model_proto member variable. /// is parsed and loaded into the m_model_proto member variable.
/// ///
@ -71,7 +69,7 @@ public:
/// the underlying ModelProto is modified - obsolete inputs, initializers, nodes /// the underlying ModelProto is modified - obsolete inputs, initializers, nodes
/// and outputs are removed from the in-memory model. /// and outputs are removed from the in-memory model.
/// ///
/// \node Please look at the declaration of InputEdge and OutputEdge for explanation /// \note Please look at the declaration of InputEdge and OutputEdge for explanation
/// how those objects can be created. If the outputs parameter is empty /// how those objects can be created. If the outputs parameter is empty
/// this method keeps all of the original outputs of the model. /// this method keeps all of the original outputs of the model.
/// ///
@ -92,6 +90,41 @@ public:
/// overwritten. /// overwritten.
void set_input_values(const std::map<std::string, std::shared_ptr<ngraph::op::Constant>>& input_values); void set_input_values(const std::map<std::string, std::shared_ptr<ngraph::op::Constant>>& input_values);
/// \brief Changes the name of given tensor.
///
/// \note It changes input, output, initializer and value_info proto repeated fields as well as
/// all nodes which refer to the tensor.
///
/// \param current_name Name of tensor to be changed.
/// \param new_name New name of tensor. Must not be empty nor point to existing tensor (including self).
void set_tensor_name(const std::string& current_name, const std::string& new_name);
/// \brief Sets node's name.
///
/// \note Empty name is accepted.
///
/// \param node Handle to node.
/// \param new_name New name of the node.
void set_node_name(const EditorNode& node, const std::string& new_name);
/// \brief Removes node name for all nodes with given name.
///
/// \note Empty and not present names are accepted.
///
/// \param name Name to clear
void clear_nodes_name(const std::string& name);
/// \brief Overrides or creates name for tensor shape dimension (numeric dimension is erased).
///
/// \note It changes input, output and value_info proto repeated fields.
/// If rank of the tensor is too low the shape is expanded with dynamic dimensions so
/// the name can be set at specified position.
///
/// \param node_name Tensor name to change its shape. Must not point to initializer.
/// \param shape_dim_index Index of dimension to change.
/// \param dim_name New name of the dimension. Must not be empty.
void set_name_for_dimension(const std::string& node_name, size_t shape_dim_index, const std::string& dim_name);
/// \brief Returns a serialized ONNX model, possibly modified by the editor. /// \brief Returns a serialized ONNX model, possibly modified by the editor.
std::string model_string() const; std::string model_string() const;
@ -126,7 +159,7 @@ public:
/// ///
std::string get_target_tensor_name(const OutputEdge& edge) const; std::string get_target_tensor_name(const OutputEdge& edge) const;
/// \brief Returns true if output edge is input of the model. Otherwise false. /// \brief Returns true if output edge is output of the model. Otherwise false.
bool is_output(const OutputEdge& edge) const; bool is_output(const OutputEdge& edge) const;
/// \brief Returns the path to the original model file /// \brief Returns the path to the original model file

View File

@ -119,9 +119,9 @@ struct EditorNode {
EditorNode(std::string node_name) : m_node_name{std::move(node_name)} {} EditorNode(std::string node_name) : m_node_name{std::move(node_name)} {}
EditorNode(EditorOutput output) : m_output_name{std::move(output.m_output_name)} {} EditorNode(EditorOutput output) : m_output_name{std::move(output.m_output_name)} {}
EditorNode(const int node_index) : m_node_index{node_index} {} EditorNode(const int node_index) : m_node_index{node_index} {}
const std::string m_node_name = ""; std::string m_node_name = "";
const std::string m_output_name = ""; std::string m_output_name = "";
const int m_node_index = -1; int m_node_index = -1;
}; };
} // namespace onnx_editor } // namespace onnx_editor
} // namespace ngraph } // namespace ngraph

View File

@ -60,7 +60,8 @@ Place::Ptr InputModelONNX::get_place_by_tensor_name(const std::string& tensor_na
Place::Ptr InputModelONNX::get_place_by_operation_name(const std::string& operation_name) const { Place::Ptr InputModelONNX::get_place_by_operation_name(const std::string& operation_name) const {
if (m_editor->is_correct_and_unambiguous_node(operation_name)) { if (m_editor->is_correct_and_unambiguous_node(operation_name)) {
return std::make_shared<PlaceOpONNX>(onnx_editor::EditorNode{operation_name}, m_editor); const auto node_index = m_editor->get_node_index(onnx_editor::EditorNode{operation_name});
return std::make_shared<PlaceOpONNX>(onnx_editor::EditorNode{node_index}, m_editor);
} }
return nullptr; return nullptr;
} }
@ -83,6 +84,36 @@ Place::Ptr InputModelONNX::get_place_by_operation_name_and_output_port(const std
return nullptr; return nullptr;
} }
// Downcasts the generic place to the ONNX tensor place and forwards the rename.
void InputModelONNX::set_name_for_tensor(Place::Ptr tensor, const std::string& new_name) {
    const auto tensor_place = std::dynamic_pointer_cast<PlaceTensorONNX>(tensor);
    FRONT_END_GENERAL_CHECK(tensor_place, __FUNCTION__, " expects a pointer to place of ONNX tensor type.");
    tensor_place->set_name(new_name);
}
// Downcasts the generic place to the ONNX operation place and forwards the rename.
void InputModelONNX::set_name_for_operation(Place::Ptr operation, const std::string& new_name) {
    const auto op_place = std::dynamic_pointer_cast<PlaceOpONNX>(operation);
    FRONT_END_GENERAL_CHECK(op_place, __FUNCTION__, " expects a pointer to place of ONNX operation type.");
    op_place->set_name(new_name);
}
// Clears the name of every node named 'name'; unknown names are silently ignored.
void InputModelONNX::free_name_for_operation(const std::string& name) {
    m_editor->clear_nodes_name(name);
}
// Downcasts the generic place to the ONNX tensor place and forwards the
// dimension-naming request.
void InputModelONNX::set_name_for_dimension(Place::Ptr tensor, size_t shape_dim_index, const std::string& dim_name) {
    const auto tensor_place = std::dynamic_pointer_cast<PlaceTensorONNX>(tensor);
    FRONT_END_GENERAL_CHECK(tensor_place, __FUNCTION__, " expects a pointer to place of ONNX tensor type.");
    tensor_place->set_name_for_dimension(shape_dim_index, dim_name);
}
// Intentionally unsupported: an ONNX tensor is identified by exactly one name,
// so additional aliases cannot be attached.
void InputModelONNX::add_name_for_tensor(Place::Ptr, const std::string&) {
    FRONT_END_THROW("Method add_name_for_tensor is not applicable for ONNX model. ONNX tensor has just one name.");
}
// Intentionally unsupported: the ONNX tensor name is the tensor's identifier
// and therefore cannot be removed.
void InputModelONNX::free_name_for_tensor(const std::string&) {
    FRONT_END_THROW("Method free_name_for_tensor is not applicable for ONNX model. ONNX tensor name is an identifier.");
}
void InputModelONNX::set_partial_shape(Place::Ptr place, const ngraph::PartialShape& shape) { void InputModelONNX::set_partial_shape(Place::Ptr place, const ngraph::PartialShape& shape) {
std::map<std::string, ngraph::PartialShape> m; std::map<std::string, ngraph::PartialShape> m;
m[place->get_names()[0]] = shape; m[place->get_names()[0]] = shape;

View File

@ -30,6 +30,17 @@ public:
int input_port_index) override; int input_port_index) override;
Place::Ptr get_place_by_operation_name_and_output_port(const std::string& operation_name, Place::Ptr get_place_by_operation_name_and_output_port(const std::string& operation_name,
int output_port_index) override; int output_port_index) override;
void set_name_for_tensor(Place::Ptr tensor, const std::string& new_name) override;
void set_name_for_operation(Place::Ptr operation, const std::string& new_name) override;
void free_name_for_operation(const std::string& name) override;
void set_name_for_dimension(Place::Ptr place, size_t shape_dim_index, const std::string& dim_name) override;
/// \brief Not applicable for ONNX model. Throws immediately
void add_name_for_tensor(Place::Ptr tensor, const std::string& new_name) override;
/// \brief Not applicable for ONNX model. Throws immediately
void free_name_for_tensor(const std::string& name) override;
void set_partial_shape(Place::Ptr place, const ngraph::PartialShape& shape) override; void set_partial_shape(Place::Ptr place, const ngraph::PartialShape& shape) override;
ngraph::PartialShape get_partial_shape(Place::Ptr place) const override; ngraph::PartialShape get_partial_shape(Place::Ptr place) const override;
void set_element_type(Place::Ptr place, const ngraph::element::Type& type) override; void set_element_type(Place::Ptr place, const ngraph::element::Type& type) override;
@ -45,7 +56,5 @@ public:
private: private:
std::shared_ptr<onnx_editor::ONNXModelEditor> m_editor; std::shared_ptr<onnx_editor::ONNXModelEditor> m_editor;
}; };
} // namespace frontend } // namespace frontend
} // namespace ngraph } // namespace ngraph

View File

@ -181,6 +181,17 @@ std::vector<Place::Ptr> PlaceTensorONNX::get_consuming_operations() const {
return consuming_ops; return consuming_ops;
} }
// Renames the underlying tensor in the model. The editor call happens first;
// it validates the new name (and may throw), so m_name is only updated on success.
void PlaceTensorONNX::set_name(const std::string& new_name) {
    // renaming a tensor to its own name is a deliberate no-op
    if (m_name == new_name)
        return;
    m_editor->set_tensor_name(m_name, new_name);
    m_name = new_name;
}
// Delegates dimension naming to the model editor using this place's tensor name.
void PlaceTensorONNX::set_name_for_dimension(size_t shape_dim_index, const std::string& dim_name) {
    m_editor->set_name_for_dimension(m_name, shape_dim_index, dim_name);
}
PlaceOpONNX::PlaceOpONNX(const onnx_editor::EditorNode& node, std::shared_ptr<onnx_editor::ONNXModelEditor> editor) PlaceOpONNX::PlaceOpONNX(const onnx_editor::EditorNode& node, std::shared_ptr<onnx_editor::ONNXModelEditor> editor)
: m_node{node}, : m_node{node},
m_editor{std::move(editor)} {} m_editor{std::move(editor)} {}
@ -193,7 +204,7 @@ std::vector<std::string> PlaceOpONNX::get_names() const {
return {m_node.m_node_name}; return {m_node.m_node_name};
} }
onnx_editor::EditorNode PlaceOpONNX::get_editor_node() const { const onnx_editor::EditorNode& PlaceOpONNX::get_editor_node() const {
return m_node; return m_node;
} }
@ -378,3 +389,8 @@ bool PlaceOpONNX::is_input() const {
bool PlaceOpONNX::is_output() const { bool PlaceOpONNX::is_output() const {
return false; return false;
} }
// Sets the node's name in the model and keeps the local EditorNode handle in sync.
void PlaceOpONNX::set_name(const std::string& new_name) {
    m_editor->set_node_name(m_node, new_name);
    m_node.m_node_name = new_name;
}

View File

@ -73,6 +73,9 @@ public:
bool is_equal_data(Place::Ptr another) const override; bool is_equal_data(Place::Ptr another) const override;
std::vector<Place::Ptr> get_consuming_operations() const override; std::vector<Place::Ptr> get_consuming_operations() const override;
void set_name(const std::string& new_name);
void set_name_for_dimension(size_t shape_dim_index, const std::string& dim_name);
private: private:
std::string m_name; std::string m_name;
std::shared_ptr<onnx_editor::ONNXModelEditor> m_editor; std::shared_ptr<onnx_editor::ONNXModelEditor> m_editor;
@ -85,7 +88,8 @@ public:
std::vector<std::string> get_names() const override; std::vector<std::string> get_names() const override;
// internal usage // internal usage
onnx_editor::EditorNode get_editor_node() const; const onnx_editor::EditorNode& get_editor_node() const;
void set_name(const std::string& new_name);
// external usage // external usage
Place::Ptr get_output_port() const override; Place::Ptr get_output_port() const override;
@ -122,5 +126,4 @@ private:
std::shared_ptr<onnx_editor::ONNXModelEditor> m_editor; std::shared_ptr<onnx_editor::ONNXModelEditor> m_editor;
}; };
} // namespace frontend } // namespace frontend
} // namespace ngraph } // namespace ngraph

View File

@ -181,7 +181,7 @@ void regclass_pyngraph_InputModel(py::module m) {
place : Place place : Place
Model's place. Model's place.
shapeDimIndex : int dimIndex : int
Dimension index. Dimension index.
dimName : str dimName : str

View File

@ -224,6 +224,33 @@ def create_test_onnx_models():
models["test_partial_shape.onnx"] = make_model(graph, producer_name="ONNX Importer", models["test_partial_shape.onnx"] = make_model(graph, producer_name="ONNX Importer",
opset_imports=[onnx.helper.make_opsetid("", 13)]) opset_imports=[onnx.helper.make_opsetid("", 13)])
# test place names model
add = onnx.helper.make_node("Add", inputs=["in1", "in2"], outputs=["add_out"])
sub = onnx.helper.make_node("Sub", inputs=["in1", "in2"], outputs=["sub_out"])
split = onnx.helper.make_node("Split", inputs=["add_out"], outputs=["out1", "out2"],
name="split1", axis=0)
mul = onnx.helper.make_node("Mul", inputs=["one_const", "sub_out"], outputs=["out3"])
input_tensors = [
make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)),
make_tensor_value_info("in2", onnx.TensorProto.FLOAT, (2, 2)),
]
output_tensors = [
make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (1, 2)),
make_tensor_value_info("out2", onnx.TensorProto.FLOAT, (1, 2)),
make_tensor_value_info("out3", onnx.TensorProto.FLOAT, (2, 2)),
]
value_infos = [
make_tensor_value_info("sub_out", onnx.TensorProto.FLOAT, (2, 2)),
]
initializers = [
onnx.helper.make_tensor("one_const", 1, [1], [1])
]
graph = make_graph([add, sub, split, mul], "test_graph", input_tensors, output_tensors,
value_info=value_infos, initializer=initializers)
models["test_place_names.onnx"] = make_model(graph, producer_name="ONNX Importer",
opset_imports=[onnx.helper.make_opsetid("", 13)])
return models return models
@ -1033,3 +1060,148 @@ def test_get_place_by_operation_name_and_output_port():
place2 = sp_out1_tensor.get_producing_operation().get_output_port(outputPortIndex=0) place2 = sp_out1_tensor.get_producing_operation().get_output_port(outputPortIndex=0)
assert place1.is_equal(place2) assert place1.is_equal(place2)
def test_not_supported_methods():
    """Tensor-name methods with no ONNX semantics must raise a descriptive error."""
    skip_if_onnx_frontend_is_disabled()
    fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
    model = fe.load("test_place_names.onnx")
    tensor = model.get_place_by_tensor_name(tensorName="add_out")
    # use pytest's `match=` instead of `in str(e)`: the message lives on
    # excinfo.value, and str(ExceptionInfo) is not guaranteed to contain it
    with pytest.raises(Exception, match="not applicable for ONNX model"):
        model.add_name_for_tensor(tensor=tensor, newName="new_name")
    with pytest.raises(Exception, match="not applicable for ONNX model"):
        model.free_name_for_tensor("add_out")
def test_set_name_for_tensor():
    """set_name_for_tensor: self-rename is a no-op; empty and colliding names raise."""
    skip_if_onnx_frontend_is_disabled()
    fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
    model = fe.load("test_place_names.onnx")
    old_name = "add_out"
    new_name = "add_out_new"
    tensor = model.get_place_by_tensor_name(tensorName=old_name)
    # ignore rename to own name (expect no exception)
    model.set_name_for_tensor(tensor=tensor, newName=old_name)
    # `match=` checks the exception message itself; `str(e)` on the ExceptionInfo
    # object is not guaranteed to include it (see pytest.raises docs)
    with pytest.raises(Exception, match="name must not be empty"):
        model.set_name_for_tensor(tensor=tensor, newName="")
    # ONNX model stores tensor info separately for inputs, outputs and between nodes tensors
    with pytest.raises(Exception, match="already used by another tensor"):
        model.set_name_for_tensor(tensor=tensor, newName="in1")
    with pytest.raises(Exception, match="already used by another tensor"):
        model.set_name_for_tensor(tensor=tensor, newName="out1")
    with pytest.raises(Exception, match="already used by another tensor"):
        model.set_name_for_tensor(tensor=tensor, newName="sub_out")
    # actual rename
    model.set_name_for_tensor(tensor=tensor, newName=new_name)
    new_tensor = model.get_place_by_tensor_name(tensorName=new_name)
    assert new_tensor
    assert new_tensor.is_equal(tensor)  # previous Place object holds the handle
    old_tensor = model.get_place_by_tensor_name(tensorName=old_name)
    assert old_tensor is None
def test_set_name_for_operation_with_name():
    """Renaming a named node succeeds and the original Place handle stays valid."""
    skip_if_onnx_frontend_is_disabled()
    fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
    model = fe.load("test_place_names.onnx")
    old_name = "split1"
    new_name = "split1_new"
    operation = model.get_place_by_operation_name(operationName=old_name)
    # ignore rename to own name (expect no exception)
    model.set_name_for_operation(operation=operation, newName=old_name)
    # actual rename
    model.set_name_for_operation(operation=operation, newName=new_name)
    new_operation = model.get_place_by_operation_name(operationName=new_name)
    assert new_operation
    assert new_operation.is_equal(operation)  # previous Place object holds the handle
    # The check below holds only for models with unique operation names, which the
    # ONNX standard does not require: with several "split1" nodes it would fail.
    old_operation = model.get_place_by_operation_name(operationName=old_name)
    assert old_operation is None
def test_set_name_for_operation_without_name():
    """Naming a node that has no name yet succeeds; the Place handle stays valid."""
    skip_if_onnx_frontend_is_disabled()
    fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
    model = fe.load("test_place_names.onnx")
    output_name = "add_out"
    new_name = "Add_new"
    operation = model.get_place_by_tensor_name(tensorName=output_name).get_producing_operation()
    # assure the test is performed on a node with an empty name; `not names`
    # already covers the empty list, so no separate len() == 0 check is needed
    names = operation.get_names()
    assert not names or not names[0]
    # actual rename
    model.set_name_for_operation(operation=operation, newName=new_name)
    new_operation = model.get_place_by_tensor_name(tensorName=output_name).get_producing_operation()
    assert new_operation
    assert new_operation.is_equal(operation)  # previous Place object holds the handle
def test_free_name_for_operation():
    """Clearing a node's name makes it unreachable by name; the node itself remains."""
    skip_if_onnx_frontend_is_disabled()
    fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
    model = fe.load("test_place_names.onnx")
    name = "split1"
    # assure non existent names are ignored (expect no exception)
    model.free_name_for_operation("non existent name")
    split1 = model.get_place_by_operation_name(operationName=name)
    assert split1
    model.free_name_for_operation(name)
    operation = model.get_place_by_operation_name(operationName=name)
    assert not operation
    # the node still exists - it is reachable through its output tensor
    new_split1 = model.get_place_by_tensor_name(tensorName="out1").get_producing_operation()
    assert split1.is_equal(new_split1)
def test_set_name_for_dimension():
    """Naming a dimension makes it dynamic; empty names and initializers raise."""
    skip_if_onnx_frontend_is_disabled()
    fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
    model = fe.load("test_place_names.onnx")
    dim_name = "batch_size"
    input1 = model.get_place_by_tensor_name(tensorName="in1")
    model.set_name_for_dimension(input1, 0, dim_name)
    assert model.get_partial_shape(input1) == PartialShape([-1, 2])
    output1 = model.get_place_by_tensor_name(tensorName="out1")
    model.set_name_for_dimension(output1, 1, dim_name)
    assert model.get_partial_shape(output1) == PartialShape([1, -1])
    # sub_output rank is 2 so setting dim_name at index 3 extends its rank to 4
    sub_output = model.get_place_by_tensor_name(tensorName="sub_out")
    model.set_name_for_dimension(sub_output, 3, dim_name)
    assert model.get_partial_shape(sub_output) == PartialShape([2, 2, -1, -1])
    # use pytest's `match=` instead of `in str(e)`: the message lives on
    # excinfo.value, and str(ExceptionInfo) is not guaranteed to contain it
    with pytest.raises(Exception, match="name must not be empty"):
        model.set_name_for_dimension(input1, 0, "")
    one_const = model.get_place_by_tensor_name(tensorName="one_const")
    with pytest.raises(Exception, match="ONNX initializer shape dimension cannot be dynamic."):
        model.set_name_for_dimension(one_const, 0, dim_name)