[PYTHON] Rename arg names python api (#9209)

This commit is contained in:
Piotr Szmelczynski 2021-12-20 14:31:48 +01:00 committed by GitHub
parent 249e1266fb
commit 513867e168
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 262 additions and 257 deletions

View File

@ -38,7 +38,7 @@ void regclass_CompiledModel(py::module m) {
},
py::arg("inputs"));
cls.def("export_model", &ov::runtime::CompiledModel::export_model, py::arg("network_model"));
cls.def("export_model", &ov::runtime::CompiledModel::export_model, py::arg("model_stream"));
cls.def(
"get_config",
@ -64,7 +64,7 @@ void regclass_CompiledModel(py::module m) {
cls.def(
"input",
(ov::Output<const ov::Node>(ov::runtime::CompiledModel::*)(size_t) const) & ov::runtime::CompiledModel::input,
py::arg("i"));
py::arg("index"));
cls.def("input",
(ov::Output<const ov::Node>(ov::runtime::CompiledModel::*)(const std::string&) const) &
@ -79,7 +79,7 @@ void regclass_CompiledModel(py::module m) {
cls.def(
"output",
(ov::Output<const ov::Node>(ov::runtime::CompiledModel::*)(size_t) const) & ov::runtime::CompiledModel::output,
py::arg("i"));
py::arg("index"));
cls.def("output",
(ov::Output<const ov::Node>(ov::runtime::CompiledModel::*)(const std::string&) const) &

View File

@ -48,7 +48,7 @@ void regclass_Core(py::module m) {
py::arg("device_name"),
py::arg("config") = py::dict());
cls.def("get_versions", &ov::runtime::Core::get_versions);
cls.def("get_versions", &ov::runtime::Core::get_versions, py::arg("device_name"));
cls.def(
"read_model",

View File

@ -153,7 +153,7 @@ void regclass_InferRequest(py::module m) {
[](InferRequestWrapper& self, size_t idx) {
return self._request.get_input_tensor(idx);
},
py::arg("idx"));
py::arg("index"));
cls.def("get_input_tensor", [](InferRequestWrapper& self) {
return self._request.get_input_tensor();
@ -164,7 +164,7 @@ void regclass_InferRequest(py::module m) {
[](InferRequestWrapper& self, size_t idx) {
return self._request.get_output_tensor(idx);
},
py::arg("idx"));
py::arg("index"));
cls.def("get_output_tensor", [](InferRequestWrapper& self) {
return self._request.get_output_tensor();
@ -199,7 +199,7 @@ void regclass_InferRequest(py::module m) {
[](InferRequestWrapper& self, size_t idx, const ov::runtime::Tensor& tensor) {
self._request.set_input_tensor(idx, tensor);
},
py::arg("idx"),
py::arg("index"),
py::arg("tensor"));
cls.def(
@ -214,7 +214,7 @@ void regclass_InferRequest(py::module m) {
[](InferRequestWrapper& self, size_t idx, const ov::runtime::Tensor& tensor) {
self._request.set_output_tensor(idx, tensor);
},
py::arg("idx"),
py::arg("index"),
py::arg("tensor"));
cls.def(

View File

@ -19,14 +19,14 @@ void regclass_frontend_InputModel(py::module m) {
im.def("get_place_by_tensor_name",
&ov::frontend::InputModel::get_place_by_tensor_name,
py::arg("tensorName"),
py::arg("tensor_name"),
R"(
Returns a tensor place by a tensor name following framework conventions, or
nullptr if a tensor with this name doesn't exist.
Parameters
----------
tensorName : str
tensor_name : str
Name of tensor.
Returns
@ -37,14 +37,14 @@ void regclass_frontend_InputModel(py::module m) {
im.def("get_place_by_operation_name",
&ov::frontend::InputModel::get_place_by_operation_name,
py::arg("operationName"),
py::arg("operation_name"),
R"(
Returns an operation place by an operation name following framework conventions, or
nullptr if an operation with this name doesn't exist.
Parameters
----------
operationName : str
operation_name : str
Name of operation.
Returns
@ -55,17 +55,17 @@ void regclass_frontend_InputModel(py::module m) {
im.def("get_place_by_operation_name_and_input_port",
&ov::frontend::InputModel::get_place_by_operation_name_and_input_port,
py::arg("operationName"),
py::arg("inputPortIndex"),
py::arg("operation_name"),
py::arg("input_port_index"),
R"(
Returns an input port place by operation name and appropriate port index.
Parameters
----------
operationName : str
operation_name : str
Name of operation.
inputPortIndex : int
input_port_index : int
Index of input port for this operation.
Returns
@ -76,17 +76,17 @@ void regclass_frontend_InputModel(py::module m) {
im.def("get_place_by_operation_name_and_output_port",
&ov::frontend::InputModel::get_place_by_operation_name_and_output_port,
py::arg("operationName"),
py::arg("outputPortIndex"),
py::arg("operation_name"),
py::arg("output_port_index"),
R"(
Returns an output port place by operation name and appropriate port index.
Parameters
----------
operationName : str
operation_name : str
Name of operation.
outputPortIndex : int
output_port_index : int
Index of output port for this operation.
Returns
@ -98,7 +98,7 @@ void regclass_frontend_InputModel(py::module m) {
im.def("set_name_for_tensor",
&ov::frontend::InputModel::set_name_for_tensor,
py::arg("tensor"),
py::arg("newName"),
py::arg("new_name"),
R"(
Sets name for tensor. Overwrites existing names of this place.
@ -107,14 +107,14 @@ void regclass_frontend_InputModel(py::module m) {
tensor : Place
Tensor place.
newName : str
new_name : str
New name for this tensor.
)");
im.def("add_name_for_tensor",
&ov::frontend::InputModel::add_name_for_tensor,
py::arg("tensor"),
py::arg("newName"),
py::arg("new_name"),
R"(
Adds new name for tensor
@ -123,14 +123,14 @@ void regclass_frontend_InputModel(py::module m) {
tensor : Place
Tensor place.
newName : str
new_name : str
New name to be added to this place.
)");
im.def("set_name_for_operation",
&ov::frontend::InputModel::set_name_for_operation,
py::arg("operation"),
py::arg("newName"),
py::arg("new_name"),
R"(
Adds new name for tensor.
@ -139,7 +139,7 @@ void regclass_frontend_InputModel(py::module m) {
operation : Place
Operation place.
newName : str
new_name : str
New name for this operation.
)");
@ -170,8 +170,8 @@ void regclass_frontend_InputModel(py::module m) {
im.def("set_name_for_dimension",
&ov::frontend::InputModel::set_name_for_dimension,
py::arg("place"),
py::arg("dimIndex"),
py::arg("dimName"),
py::arg("dim_index"),
py::arg("dim_name"),
R"(
Set name for a particular dimension of a place (e.g. batch dimension).
@ -180,17 +180,17 @@ void regclass_frontend_InputModel(py::module m) {
place : Place
Model's place.
dimIndex : int
dim_index : int
Dimension index.
dimName : str
dim_name : str
Name to assign on this dimension.
)");
im.def("cut_and_add_new_input",
&ov::frontend::InputModel::cut_and_add_new_input,
py::arg("place"),
py::arg("newName") = std::string(),
py::arg("new_name") = std::string(),
R"(
Cut immediately before this place and assign this place as new input; prune
all nodes that don't contribute to any output.
@ -200,14 +200,14 @@ void regclass_frontend_InputModel(py::module m) {
place : Place
New place to be assigned as input.
newNameOptional : str
new_name_optional : str
Optional new name assigned to this input place.
)");
im.def("cut_and_add_new_output",
&ov::frontend::InputModel::cut_and_add_new_output,
py::arg("place"),
py::arg("newName") = std::string(),
py::arg("new_name") = std::string(),
R"(
Cut immediately before this place and assign this place as new output; prune
all nodes that don't contribute to any output.
@ -217,7 +217,7 @@ void regclass_frontend_InputModel(py::module m) {
place : Place
New place to be assigned as output.
newNameOptional : str
new_name_optional : str
Optional new name assigned to this output place.
)");

View File

@ -104,17 +104,17 @@ void regclass_frontend_Place(py::module m) {
}
}
},
py::arg("outputName") = py::none(),
py::arg("outputPortIndex") = py::none(),
py::arg("output_name") = py::none(),
py::arg("output_port_index") = py::none(),
R"(
Returns references to all operation nodes that consume data from this place for specified output port.
Note: It can be called for any kind of graph place searching for the first consuming operations.
Parameters
----------
outputName : str
output_name : str
Name of output port group. May not be set if node has one output port group.
outputPortIndex : int
output_port_index : int
If place is an operational node it specifies which output port should be considered
May not be set if node has only one output port.
@ -141,17 +141,17 @@ void regclass_frontend_Place(py::module m) {
}
}
},
py::arg("outputName") = py::none(),
py::arg("outputPortIndex") = py::none(),
py::arg("output_name") = py::none(),
py::arg("output_port_index") = py::none(),
R"(
Returns a tensor place that gets data from this place; applicable for operations,
output ports and output edges.
Parameters
----------
outputName : str
output_name : str
Name of output port group. May not be set if node has one output port group.
outputPortIndex : int
output_port_index : int
Output port index if the current place is an operation node and has multiple output ports.
May not be set if place has only one output port.
@ -179,16 +179,16 @@ void regclass_frontend_Place(py::module m) {
}
}
},
py::arg("inputName") = py::none(),
py::arg("inputPortIndex") = py::none(),
py::arg("input_name") = py::none(),
py::arg("input_port_index") = py::none(),
R"(
Get an operation node place that immediately produces data for this place.
Parameters
----------
inputName : str
input_name : str
Name of port group. May not be set if node has one input port group.
inputPortIndex : int
input_port_index : int
If a given place is itself an operation node, this specifies a port index.
May not be set if place has only one input port.
@ -226,17 +226,17 @@ void regclass_frontend_Place(py::module m) {
}
}
},
py::arg("inputName") = py::none(),
py::arg("inputPortIndex") = py::none(),
py::arg("input_name") = py::none(),
py::arg("input_port_index") = py::none(),
R"(
For operation node returns reference to an input port with specified name and index.
Parameters
----------
inputName : str
input_name : str
Name of port group. May not be set if node has one input port group.
inputPortIndex : int
input_port_index : int
Input port index in a group. May not be set if node has one input port in a group.
Returns
@ -262,17 +262,17 @@ void regclass_frontend_Place(py::module m) {
}
}
},
py::arg("outputName") = py::none(),
py::arg("outputPortIndex") = py::none(),
py::arg("output_name") = py::none(),
py::arg("output_port_index") = py::none(),
R"(
For operation node returns reference to an output port with specified name and index.
Parameters
----------
outputName : str
output_name : str
Name of output port group. May not be set if node has one output port group.
outputPortIndex : int
output_port_index : int
Output port index. May not be set if node has one output port in a group.
Returns
@ -309,17 +309,17 @@ void regclass_frontend_Place(py::module m) {
}
}
},
py::arg("inputName") = py::none(),
py::arg("inputPortIndex") = py::none(),
py::arg("input_name") = py::none(),
py::arg("input_port_index") = py::none(),
R"(
Returns a tensor place that supplies data for this place; applicable for operations,
input ports and input edges.
Parameters
----------
inputName : str
input_name : str
Name of port group. May not be set if node has one input port group.
inputPortIndex : int
input_port_index : int
Input port index for operational node. May not be specified if place has only one input port.
Returns

View File

@ -137,14 +137,14 @@ void regclass_graph_Dimension(py::module m) {
)");
dim.def("compatible",
&ov::Dimension::compatible,
py::arg("d"),
py::arg("dim"),
R"(
Check whether this dimension is capable of being merged
with the argument dimension.
Parameters
----------
d : Dimension
dim : Dimension
The dimension to compare this dimension with.
Returns
@ -154,7 +154,7 @@ void regclass_graph_Dimension(py::module m) {
)");
dim.def("relaxes",
&ov::Dimension::relaxes,
py::arg("d"),
py::arg("dim"),
R"(
Check whether this dimension is a relaxation of the argument.
This dimension relaxes (or is a relaxation of) d if:
@ -166,7 +166,7 @@ void regclass_graph_Dimension(py::module m) {
Parameters
----------
d : Dimension
dim : Dimension
The dimension to compare this dimension with.
Returns
@ -176,7 +176,7 @@ void regclass_graph_Dimension(py::module m) {
)");
dim.def("refines",
&ov::Dimension::refines,
py::arg("d"),
py::arg("dim"),
R"(
Check whether this dimension is a refinement of the argument.
This dimension refines (or is a refinement of) d if:
@ -188,7 +188,7 @@ void regclass_graph_Dimension(py::module m) {
Parameters
----------
d : Dimension
dim : Dimension
The dimension to compare this dimension with.
Returns

View File

@ -374,13 +374,13 @@ void regclass_graph_Model(py::module m) {
)");
function.def("get_output_op",
&ov::Model::get_output_op,
py::arg("i"),
py::arg("index"),
R"(
Return the op that generates output i
Parameters
----------
i : int
index : int
output index
Returns
@ -390,13 +390,13 @@ void regclass_graph_Model(py::module m) {
)");
function.def("get_output_element_type",
&ov::Model::get_output_element_type,
py::arg("i"),
py::arg("index"),
R"(
Return the element type of output i
Parameters
----------
i : int
index : int
output index
Returns
@ -406,13 +406,13 @@ void regclass_graph_Model(py::module m) {
)");
function.def("get_output_shape",
&ov::Model::get_output_shape,
py::arg("i"),
py::arg("index"),
R"(
Return the shape of element i
Parameters
----------
i : int
index : int
element index
Returns
@ -422,13 +422,13 @@ void regclass_graph_Model(py::module m) {
)");
function.def("get_output_partial_shape",
&ov::Model::get_output_partial_shape,
py::arg("i"),
py::arg("index"),
R"(
Return the partial shape of element i
Parameters
----------
i : int
index : int
element index
Returns
@ -551,7 +551,7 @@ void regclass_graph_Model(py::module m) {
)");
function.def("input", (ov::Output<ov::Node>(ov::Model::*)()) & ov::Model::input);
function.def("input", (ov::Output<ov::Node>(ov::Model::*)(size_t)) & ov::Model::input, py::arg("i"));
function.def("input", (ov::Output<ov::Node>(ov::Model::*)(size_t)) & ov::Model::input, py::arg("index"));
function.def("input",
(ov::Output<ov::Node>(ov::Model::*)(const std::string&)) & ov::Model::input,
@ -559,7 +559,9 @@ void regclass_graph_Model(py::module m) {
function.def("input", (ov::Output<const ov::Node>(ov::Model::*)() const) & ov::Model::input);
function.def("input", (ov::Output<const ov::Node>(ov::Model::*)(size_t) const) & ov::Model::input, py::arg("i"));
function.def("input",
(ov::Output<const ov::Node>(ov::Model::*)(size_t) const) & ov::Model::input,
py::arg("index"));
function.def("input",
(ov::Output<const ov::Node>(ov::Model::*)(const std::string&) const) & ov::Model::input,
@ -567,7 +569,7 @@ void regclass_graph_Model(py::module m) {
function.def("output", (ov::Output<ov::Node>(ov::Model::*)()) & ov::Model::output);
function.def("output", (ov::Output<ov::Node>(ov::Model::*)(size_t)) & ov::Model::output, py::arg("i"));
function.def("output", (ov::Output<ov::Node>(ov::Model::*)(size_t)) & ov::Model::output, py::arg("index"));
function.def("output",
(ov::Output<ov::Node>(ov::Model::*)(const std::string&)) & ov::Model::output,
@ -575,7 +577,9 @@ void regclass_graph_Model(py::module m) {
function.def("output", (ov::Output<const ov::Node>(ov::Model::*)() const) & ov::Model::output);
function.def("output", (ov::Output<const ov::Node>(ov::Model::*)(size_t) const) & ov::Model::output, py::arg("i"));
function.def("output",
(ov::Output<const ov::Node>(ov::Model::*)(size_t) const) & ov::Model::output,
py::arg("index"));
function.def("output",
(ov::Output<const ov::Node>(ov::Model::*)(const std::string&) const) & ov::Model::output,

View File

@ -129,14 +129,14 @@ void regclass_graph_Node(py::module m) {
)");
node.def("get_input_tensor",
&ov::Node::get_input_tensor,
py::arg("i"),
py::arg("index"),
py::return_value_policy::reference_internal,
R"(
Returns the tensor for the node's input with index i
Parameters
----------
i : int
index : int
Index of Input.
Returns
@ -166,13 +166,13 @@ void regclass_graph_Node(py::module m) {
)");
node.def("input_value",
&ov::Node::input_value,
py::arg("i"),
py::arg("index"),
R"(
Returns input of the node with index i
Parameters
----------
i : int
index : int
Index of Input.
Returns
@ -202,13 +202,13 @@ void regclass_graph_Node(py::module m) {
)");
node.def("get_output_element_type",
&ov::Node::get_output_element_type,
py::arg("i"),
py::arg("index"),
R"(
Returns the element type for output i
Parameters
----------
i : int
index : int
Index of the output.
Returns
@ -218,13 +218,13 @@ void regclass_graph_Node(py::module m) {
)");
node.def("get_output_shape",
&ov::Node::get_output_shape,
py::arg("i"),
py::arg("index"),
R"(
Returns the shape for output i
Parameters
----------
i : int
index : int
Index of the output.
Returns
@ -234,13 +234,13 @@ void regclass_graph_Node(py::module m) {
)");
node.def("get_output_partial_shape",
&ov::Node::get_output_partial_shape,
py::arg("i"),
py::arg("index"),
R"(
Returns the partial shape for output i
Parameters
----------
i : int
index : int
Index of the output.
Returns
@ -250,14 +250,14 @@ void regclass_graph_Node(py::module m) {
)");
node.def("get_output_tensor",
&ov::Node::get_output_tensor,
py::arg("i"),
py::arg("index"),
py::return_value_policy::reference_internal,
R"(
Returns the tensor for output i
Parameters
----------
i : int
index : int
Index of the output.
Returns

View File

@ -33,7 +33,7 @@ void regclass_graph_PartialShape(py::module m) {
shape.def(py::init<const ov::Shape&>());
shape.def(py::init<const ov::PartialShape&>());
shape.def_static("dynamic", &ov::PartialShape::dynamic, py::arg("r") = ov::Dimension());
shape.def_static("dynamic", &ov::PartialShape::dynamic, py::arg("rank") = ov::Dimension());
shape.def_property_readonly("is_dynamic",
&ov::PartialShape::is_dynamic,
@ -63,14 +63,14 @@ void regclass_graph_PartialShape(py::module m) {
shape.def("compatible",
&ov::PartialShape::compatible,
py::arg("s"),
py::arg("shape"),
R"(
Check whether this shape is compatible with the argument, i.e.,
whether it is possible to merge them.
Parameters
----------
s : PartialShape
shape : PartialShape
The shape to be checked for compatibility with this shape.
@ -81,13 +81,13 @@ void regclass_graph_PartialShape(py::module m) {
)");
shape.def("refines",
&ov::PartialShape::refines,
py::arg("s"),
py::arg("shape"),
R"(
Check whether this shape is a refinement of the argument.
Parameters
----------
s : PartialShape
shape : PartialShape
The shape which is being compared against this shape.
Returns
@ -97,13 +97,13 @@ void regclass_graph_PartialShape(py::module m) {
)");
shape.def("relaxes",
&ov::PartialShape::relaxes,
py::arg("s"),
py::arg("shape"),
R"(
Check whether this shape is a relaxation of the argument.
Parameters
----------
s : PartialShape
shape : PartialShape
The shape which is being compared against this shape.
Returns
@ -113,13 +113,13 @@ void regclass_graph_PartialShape(py::module m) {
)");
shape.def("same_scheme",
&ov::PartialShape::same_scheme,
py::arg("s"),
py::arg("shape"),
R"(
Check whether this shape represents the same scheme as the argument.
Parameters
----------
s : PartialShape
shape : PartialShape
The shape which is being compared against this shape.
Returns

View File

@ -331,9 +331,9 @@ def test_extract_subgraph():
model = fe.load("input_model.onnx")
assert model
place1 = model.get_place_by_tensor_name(tensorName="add_out").get_input_port(inputPortIndex=0) # in1
place2 = model.get_place_by_tensor_name(tensorName="add_out").get_input_port(inputPortIndex=1) # in2
place3 = model.get_place_by_tensor_name(tensorName="add_out")
place1 = model.get_place_by_tensor_name(tensor_name="add_out").get_input_port(input_port_index=0) # in1
place2 = model.get_place_by_tensor_name(tensor_name="add_out").get_input_port(input_port_index=1) # in2
place3 = model.get_place_by_tensor_name(tensor_name="add_out")
model.extract_subgraph(inputs=[place1, place2], outputs=[place3])
result_func = fe.convert(model)
@ -352,8 +352,8 @@ def test_extract_subgraph_2():
model = fe.load("input_model.onnx")
assert model
place1 = model.get_place_by_tensor_name(tensorName="add_out")
place2 = model.get_place_by_tensor_name(tensorName="out3")
place1 = model.get_place_by_tensor_name(tensor_name="add_out")
place2 = model.get_place_by_tensor_name(tensor_name="out3")
model.extract_subgraph(inputs=[], outputs=[place1, place2])
result_func = fe.convert(model)
@ -372,9 +372,9 @@ def test_extract_subgraph_3():
model = fe.load("input_model.onnx")
assert model
place1 = model.get_place_by_operation_name_and_input_port(operationName="split1", inputPortIndex=0)
place2 = model.get_place_by_tensor_name(tensorName="out1")
place3 = model.get_place_by_tensor_name(tensorName="out2")
place1 = model.get_place_by_operation_name_and_input_port(operation_name="split1", input_port_index=0)
place2 = model.get_place_by_tensor_name(tensor_name="out1")
place3 = model.get_place_by_tensor_name(tensor_name="out2")
model.extract_subgraph(inputs=[place1], outputs=[place2, place3])
result_func = fe.convert(model)
@ -393,13 +393,13 @@ def test_extract_subgraph_4():
model = fe.load("input_model.onnx")
assert model
out4_tensor = model.get_place_by_tensor_name(tensorName="out4")
place1 = model.get_place_by_operation_name_and_input_port(operationName="split1", inputPortIndex=0)
place2 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=0)
place3 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=1)
place4 = model.get_place_by_tensor_name(tensorName="out1")
place5 = model.get_place_by_tensor_name(tensorName="out2")
place6 = model.get_place_by_tensor_name(tensorName="out4")
out4_tensor = model.get_place_by_tensor_name(tensor_name="out4")
place1 = model.get_place_by_operation_name_and_input_port(operation_name="split1", input_port_index=0)
place2 = out4_tensor.get_producing_operation().get_input_port(input_port_index=0)
place3 = out4_tensor.get_producing_operation().get_input_port(input_port_index=1)
place4 = model.get_place_by_tensor_name(tensor_name="out1")
place5 = model.get_place_by_tensor_name(tensor_name="out2")
place6 = model.get_place_by_tensor_name(tensor_name="out4")
model.extract_subgraph(inputs=[place1, place2, place3], outputs=[place4, place5, place6])
result_func = fe.convert(model)
@ -418,11 +418,11 @@ def test_extract_subgraph_by_op_place_as_input():
model = fe.load("input_model.onnx")
assert model
split_op = model.get_place_by_operation_name(operationName="split1")
out4 = model.get_place_by_tensor_name(tensorName="out4")
split_op = model.get_place_by_operation_name(operation_name="split1")
out4 = model.get_place_by_tensor_name(tensor_name="out4")
mul_op = out4.get_producing_operation()
out1 = model.get_place_by_tensor_name(tensorName="out1")
out2 = model.get_place_by_tensor_name(tensorName="out2")
out1 = model.get_place_by_tensor_name(tensor_name="out1")
out2 = model.get_place_by_tensor_name(tensor_name="out2")
model.extract_subgraph(inputs=[split_op, mul_op], outputs=[out1, out2, out4])
result_func = fe.convert(model)
@ -442,9 +442,9 @@ def test_extract_subgraph_by_op_place_as_output():
model = fe.load("input_model.onnx")
assert model
in1_tensor = model.get_place_by_tensor_name(tensorName="in1")
in2_tensor = model.get_place_by_tensor_name(tensorName="in2")
add_out_tensor = model.get_place_by_tensor_name(tensorName="add_out")
in1_tensor = model.get_place_by_tensor_name(tensor_name="in1")
in2_tensor = model.get_place_by_tensor_name(tensor_name="in2")
add_out_tensor = model.get_place_by_tensor_name(tensor_name="add_out")
add_op = add_out_tensor.get_producing_operation()
model.extract_subgraph(inputs=[in1_tensor, in2_tensor], outputs=[add_op])
@ -465,8 +465,8 @@ def test_extract_subgraph_by_op_place_as_output_2():
model = fe.load("input_model.onnx")
assert model
split_op = model.get_place_by_operation_name(operationName="split1")
out4 = model.get_place_by_tensor_name(tensorName="out4")
split_op = model.get_place_by_operation_name(operation_name="split1")
out4 = model.get_place_by_tensor_name(tensor_name="out4")
mul_op = out4.get_producing_operation()
model.extract_subgraph(inputs=[split_op, mul_op], outputs=[])
@ -487,11 +487,11 @@ def test_extract_subgraph_by_port_place_as_output():
model = fe.load("input_model.onnx")
assert model
add_out_tensor = model.get_place_by_tensor_name(tensorName="add_out")
add_out_tensor = model.get_place_by_tensor_name(tensor_name="add_out")
add_op = add_out_tensor.get_producing_operation()
add_op_out_port = add_op.get_output_port(outputPortIndex=0)
in1_tensor = model.get_place_by_tensor_name(tensorName="in1")
in2_tensor = model.get_place_by_tensor_name(tensorName="in2")
add_op_out_port = add_op.get_output_port(output_port_index=0)
in1_tensor = model.get_place_by_tensor_name(tensor_name="in1")
in2_tensor = model.get_place_by_tensor_name(tensor_name="in2")
model.extract_subgraph(inputs=[in1_tensor, in2_tensor], outputs=[add_op_out_port])
result_func = fe.convert(model)
@ -511,8 +511,8 @@ def test_override_all_outputs():
model = fe.load("input_model.onnx")
assert model
place1 = model.get_place_by_tensor_name(tensorName="add_out")
place2 = model.get_place_by_tensor_name(tensorName="out3")
place1 = model.get_place_by_tensor_name(tensor_name="add_out")
place2 = model.get_place_by_tensor_name(tensor_name="out3")
model.override_all_outputs(outputs=[place1, place2])
result_func = fe.convert(model)
@ -531,7 +531,7 @@ def test_override_all_outputs_2():
model = fe.load("input_model.onnx")
assert model
place1 = model.get_place_by_tensor_name(tensorName="out4")
place1 = model.get_place_by_tensor_name(tensor_name="out4")
model.override_all_outputs(outputs=[place1])
result_func = fe.convert(model)
@ -551,11 +551,11 @@ def test_override_all_inputs():
assert model
place1 = model.get_place_by_operation_name_and_input_port(
operationName="split1", inputPortIndex=0)
out4_tensor = model.get_place_by_tensor_name(tensorName="out4")
place2 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=0)
place3 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=1)
place4 = model.get_place_by_tensor_name(tensorName="in3")
operation_name="split1", input_port_index=0)
out4_tensor = model.get_place_by_tensor_name(tensor_name="out4")
place2 = out4_tensor.get_producing_operation().get_input_port(input_port_index=0)
place3 = out4_tensor.get_producing_operation().get_input_port(input_port_index=1)
place4 = model.get_place_by_tensor_name(tensor_name="in3")
model.override_all_inputs(inputs=[place1, place2, place3, place4])
result_func = fe.convert(model)
@ -574,10 +574,10 @@ def test_override_all_inputs_exceptions():
model = fe.load("input_model.onnx")
assert model
place1 = model.get_place_by_tensor_name(tensorName="in1")
place2 = model.get_place_by_tensor_name(tensorName="in2")
place3 = model.get_place_by_operation_name_and_input_port(operationName="split1", inputPortIndex=0)
place4 = model.get_place_by_tensor_name(tensorName="in3")
place1 = model.get_place_by_tensor_name(tensor_name="in1")
place2 = model.get_place_by_tensor_name(tensor_name="in2")
place3 = model.get_place_by_operation_name_and_input_port(operation_name="split1", input_port_index=0)
place4 = model.get_place_by_tensor_name(tensor_name="in3")
with pytest.raises(Exception) as e:
model.override_all_inputs(inputs=[place1, place2])
@ -596,24 +596,24 @@ def test_is_input_output():
model = fe.load("input_model.onnx")
assert model
place1 = model.get_place_by_tensor_name(tensorName="in2")
place1 = model.get_place_by_tensor_name(tensor_name="in2")
assert place1.is_input()
assert not place1.is_output()
place2 = model.get_place_by_tensor_name(tensorName="out2")
place2 = model.get_place_by_tensor_name(tensor_name="out2")
assert not place2.is_input()
assert place2.is_output()
place3 = model.get_place_by_tensor_name(tensorName="add_out")
place3 = model.get_place_by_tensor_name(tensor_name="add_out")
assert not place3.is_input()
assert not place3.is_output()
place4 = model.get_place_by_operation_name_and_input_port(
operationName="split1", inputPortIndex=0)
operation_name="split1", input_port_index=0)
assert not place4.is_input()
assert not place4.is_output()
place5 = model.get_place_by_operation_name(operationName="split1")
place5 = model.get_place_by_operation_name(operation_name="split1")
assert not place5.is_input()
assert not place5.is_output()
@ -626,11 +626,11 @@ def test_set_partial_shape():
model = fe.load("input_model.onnx")
assert model
place1 = model.get_place_by_tensor_name(tensorName="in1")
place1 = model.get_place_by_tensor_name(tensor_name="in1")
model.set_partial_shape(place1, PartialShape([8, 16]))
place2 = model.get_place_by_tensor_name(tensorName="in2")
place2 = model.get_place_by_tensor_name(tensor_name="in2")
model.set_partial_shape(place2, PartialShape([8, 16]))
place3 = model.get_place_by_tensor_name(tensorName="in3")
place3 = model.get_place_by_tensor_name(tensor_name="in3")
model.set_partial_shape(place3, PartialShape([4, 6]))
result_func = fe.convert(model)
@ -649,16 +649,16 @@ def test_get_partial_shape():
model = fe.load("input_model.onnx")
assert model
place1 = model.get_place_by_tensor_name(tensorName="in1")
place1 = model.get_place_by_tensor_name(tensor_name="in1")
assert model.get_partial_shape(place1) == PartialShape([2, 2])
place2 = model.get_place_by_tensor_name(tensorName="out1")
place2 = model.get_place_by_tensor_name(tensor_name="out1")
assert model.get_partial_shape(place2) == PartialShape([1, 2])
place3 = model.get_place_by_tensor_name(tensorName="add_out")
place3 = model.get_place_by_tensor_name(tensor_name="add_out")
assert model.get_partial_shape(place3) == PartialShape([2, 2])
place4 = model.get_place_by_tensor_name(tensorName="in3")
place4 = model.get_place_by_tensor_name(tensor_name="in3")
model.set_partial_shape(place4, PartialShape([4, 6]))
assert model.get_partial_shape(place4) == PartialShape([4, 6])
assert model.get_partial_shape(place2) == PartialShape([1, 2])
@ -694,33 +694,33 @@ def test_is_equal():
model = fe.load("input_model.onnx")
assert model
place1 = model.get_place_by_tensor_name(tensorName="in1")
place1 = model.get_place_by_tensor_name(tensor_name="in1")
assert place1.is_equal(place1)
place2 = model.get_place_by_tensor_name(tensorName="out2")
place2 = model.get_place_by_tensor_name(tensor_name="out2")
assert place2.is_equal(place2)
out4_tensor = model.get_place_by_tensor_name(tensorName="out4")
place3 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=0)
place4 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=0)
out4_tensor = model.get_place_by_tensor_name(tensor_name="out4")
place3 = out4_tensor.get_producing_operation().get_input_port(input_port_index=0)
place4 = out4_tensor.get_producing_operation().get_input_port(input_port_index=0)
assert place3.is_equal(place4)
out1_tensor = model.get_place_by_tensor_name(tensorName="out1")
place5 = model.get_place_by_operation_name_and_input_port(operationName="split1", inputPortIndex=0)
place6 = out1_tensor.get_producing_operation().get_input_port(inputPortIndex=0)
out1_tensor = model.get_place_by_tensor_name(tensor_name="out1")
place5 = model.get_place_by_operation_name_and_input_port(operation_name="split1", input_port_index=0)
place6 = out1_tensor.get_producing_operation().get_input_port(input_port_index=0)
assert place5.is_equal(place6)
place7 = model.get_place_by_tensor_name(tensorName="out4").get_producing_port()
place7 = model.get_place_by_tensor_name(tensor_name="out4").get_producing_port()
assert place7.is_equal(place7)
place8 = model.get_place_by_tensor_name(tensorName="add_out")
place8 = model.get_place_by_tensor_name(tensor_name="add_out")
assert place8.is_equal(place8)
assert not place1.is_equal(place2)
assert not place6.is_equal(place7)
assert not place8.is_equal(place2)
place9 = model.get_place_by_operation_name(operationName="split1")
place9 = model.get_place_by_operation_name(operation_name="split1")
assert place2.get_producing_operation().is_equal(place9)
assert not place9.is_equal(place2)
@ -733,32 +733,32 @@ def test_is_equal_data():
model = fe.load("input_model.onnx")
assert model
place1 = model.get_place_by_tensor_name(tensorName="in1")
place1 = model.get_place_by_tensor_name(tensor_name="in1")
assert place1.is_equal_data(place1)
place2 = model.get_place_by_tensor_name(tensorName="add_out")
place2 = model.get_place_by_tensor_name(tensor_name="add_out")
assert place2.is_equal_data(place2)
place3 = model.get_place_by_tensor_name(tensorName="in2")
place3 = model.get_place_by_tensor_name(tensor_name="in2")
assert not place1.is_equal_data(place3)
assert not place2.is_equal_data(place1)
place4 = place2.get_producing_port()
assert place2.is_equal_data(place4)
out4_tensor = model.get_place_by_tensor_name(tensorName="out4")
place5 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=0)
out4_tensor = model.get_place_by_tensor_name(tensor_name="out4")
place5 = out4_tensor.get_producing_operation().get_input_port(input_port_index=0)
assert place2.is_equal_data(place5)
assert place4.is_equal_data(place5)
place6 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=1)
place6 = out4_tensor.get_producing_operation().get_input_port(input_port_index=1)
assert place6.is_equal_data(place5)
place7 = model.get_place_by_operation_name_and_input_port(operationName="split1", inputPortIndex=0)
place7 = model.get_place_by_operation_name_and_input_port(operation_name="split1", input_port_index=0)
assert place7.is_equal_data(place7)
place8 = model.get_place_by_tensor_name(tensorName="out1")
place9 = model.get_place_by_tensor_name(tensorName="out2")
place8 = model.get_place_by_tensor_name(tensor_name="out1")
place9 = model.get_place_by_tensor_name(tensor_name="out2")
place10 = place8.get_producing_port()
assert not place8.is_equal_data(place9)
assert not place9.is_equal_data(place10)
@ -773,16 +773,16 @@ def test_get_place_by_tensor_name():
model = fe.load("input_model.onnx")
assert model
place1 = model.get_place_by_tensor_name(tensorName="out2")
place1 = model.get_place_by_tensor_name(tensor_name="out2")
assert place1
place2 = model.get_place_by_tensor_name(tensorName="add_out")
place2 = model.get_place_by_tensor_name(tensor_name="add_out")
assert place2
place3 = model.get_place_by_tensor_name(tensorName="in1")
place3 = model.get_place_by_tensor_name(tensor_name="in1")
assert place3
assert not model.get_place_by_tensor_name(tensorName="0:add_out")
assert not model.get_place_by_tensor_name(tensor_name="0:add_out")
def test_get_place_by_operation_name():
@ -793,10 +793,10 @@ def test_get_place_by_operation_name():
model = fe.load("input_model.onnx")
assert model
place1 = model.get_place_by_operation_name(operationName="split1")
place1 = model.get_place_by_operation_name(operation_name="split1")
assert place1
place2 = model.get_place_by_operation_name(operationName="not_existed")
place2 = model.get_place_by_operation_name(operation_name="not_existed")
assert not place2
@ -807,16 +807,16 @@ def test_get_output_port():
model = fe.load("input_model.onnx")
assert model
split_op = model.get_place_by_operation_name(operationName="split1")
place1 = split_op.get_output_port(outputPortIndex=0)
place2 = split_op.get_output_port(outputName="out2")
split_op = model.get_place_by_operation_name(operation_name="split1")
place1 = split_op.get_output_port(output_port_index=0)
place2 = split_op.get_output_port(output_name="out2")
assert place1.get_target_tensor().get_names()[0] == "out1"
assert place2.get_target_tensor().get_names()[0] == "out2"
assert not split_op.get_output_port()
assert not split_op.get_output_port(outputPortIndex=3)
assert not split_op.get_output_port(outputName="not_existed")
assert not split_op.get_output_port(output_port_index=3)
assert not split_op.get_output_port(output_name="not_existed")
def test_get_input_port():
@ -826,15 +826,15 @@ def test_get_input_port():
model = fe.load("input_model.onnx")
assert model
split_op = model.get_place_by_operation_name(operationName="split1")
place1 = split_op.get_input_port(inputPortIndex=0)
split_op = model.get_place_by_operation_name(operation_name="split1")
place1 = split_op.get_input_port(input_port_index=0)
assert place1.get_source_tensor().get_names()[0] == "add_out"
place2 = split_op.get_input_port()
assert place1.is_equal(place2)
assert not split_op.get_input_port(inputPortIndex=1)
assert not split_op.get_input_port(inputName="not_existed")
assert not split_op.get_input_port(input_port_index=1)
assert not split_op.get_input_port(input_name="not_existed")
def test_get_consuming_ports():
@ -844,15 +844,15 @@ def test_get_consuming_ports():
model = fe.load("input_model.onnx")
assert model
place1 = model.get_place_by_tensor_name(tensorName="add_out")
place1 = model.get_place_by_tensor_name(tensor_name="add_out")
add_tensor_consuming_ports = place1.get_consuming_ports()
assert len(add_tensor_consuming_ports) == 3
place2 = model.get_place_by_operation_name_and_input_port(operationName="split1", inputPortIndex=0)
place2 = model.get_place_by_operation_name_and_input_port(operation_name="split1", input_port_index=0)
assert add_tensor_consuming_ports[0].is_equal(place2)
out4_tensor = model.get_place_by_tensor_name(tensorName="out4")
place3 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=0)
out4_tensor = model.get_place_by_tensor_name(tensor_name="out4")
place3 = out4_tensor.get_producing_operation().get_input_port(input_port_index=0)
assert add_tensor_consuming_ports[1].is_equal(place3)
place4 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=1)
place4 = out4_tensor.get_producing_operation().get_input_port(input_port_index=1)
assert add_tensor_consuming_ports[2].is_equal(place4)
add_op_consuming_ports = place1.get_producing_operation().get_consuming_ports()
@ -868,16 +868,17 @@ def test_get_consuming_ports_2():
model = fe.load("input_model_2.onnx")
assert model
split_op = model.get_place_by_operation_name(operationName="split2")
split_op = model.get_place_by_operation_name(operation_name="split2")
split_op_consuming_ports = split_op.get_consuming_ports()
assert len(split_op_consuming_ports) == 2
abs_input_port = model.get_place_by_operation_name(operationName="abs1").get_input_port(inputPortIndex=0)
abs_input_port = model.get_place_by_operation_name(operation_name="abs1") \
.get_input_port(input_port_index=0)
assert split_op_consuming_ports[0].is_equal(abs_input_port)
out2_tensor = model.get_place_by_tensor_name(tensorName="out2")
sin_input_port = out2_tensor.get_producing_operation().get_input_port(inputPortIndex=0)
out2_tensor = model.get_place_by_tensor_name(tensor_name="out2")
sin_input_port = out2_tensor.get_producing_operation().get_input_port(input_port_index=0)
assert split_op_consuming_ports[1].is_equal(sin_input_port)
split_out_port_0 = split_op.get_output_port(outputPortIndex=0)
split_out_port_0 = split_op.get_output_port(output_port_index=0)
split_out_port_0_consuming_ports = split_out_port_0.get_consuming_ports()
assert len(split_out_port_0_consuming_ports) == 1
assert split_out_port_0_consuming_ports[0].is_equal(abs_input_port)
@ -890,12 +891,12 @@ def test_get_producing_operation():
model = fe.load("input_model_2.onnx")
assert model
split_tensor_out_2 = model.get_place_by_tensor_name(tensorName="sp_out2")
split_op = model.get_place_by_operation_name(operationName="split2")
split_tensor_out_2 = model.get_place_by_tensor_name(tensor_name="sp_out2")
split_op = model.get_place_by_operation_name(operation_name="split2")
assert split_tensor_out_2.get_producing_operation().is_equal(split_op)
split_op = model.get_place_by_operation_name(operationName="split2")
split_out_port_2 = split_op.get_output_port(outputPortIndex=1)
split_op = model.get_place_by_operation_name(operation_name="split2")
split_out_port_2 = split_op.get_output_port(output_port_index=1)
assert split_out_port_2.get_producing_operation().is_equal(split_op)
@ -906,22 +907,22 @@ def test_get_producing_operation_2():
model = fe.load("input_model_2.onnx")
assert model
abs_op = model.get_place_by_operation_name(operationName="abs1")
abs_op = model.get_place_by_operation_name(operation_name="abs1")
abs_port_0 = abs_op.get_input_port()
split_op = model.get_place_by_operation_name(operationName="split2")
split_op = model.get_place_by_operation_name(operation_name="split2")
assert abs_port_0.get_producing_operation().is_equal(split_op)
assert abs_op.get_producing_operation().is_equal(split_op)
add_out_tensor = model.get_place_by_tensor_name(tensorName="add_out")
add_out_tensor = model.get_place_by_tensor_name(tensor_name="add_out")
add_op = add_out_tensor.get_producing_operation()
assert not add_op.get_producing_operation()
split_op_producing_op = split_op.get_producing_operation(inputName="add_out")
split_op_producing_op = split_op.get_producing_operation(input_name="add_out")
assert split_op_producing_op.is_equal(add_op)
out2_tensor = model.get_place_by_tensor_name(tensorName="out2")
out2_tensor = model.get_place_by_tensor_name(tensor_name="out2")
sin_op = out2_tensor.get_producing_operation()
assert sin_op.get_producing_operation(inputPortIndex=0).is_equal(split_op)
assert sin_op.get_producing_operation(input_port_index=0).is_equal(split_op)
def test_get_consuming_operations():
@ -931,40 +932,40 @@ def test_get_consuming_operations():
model = fe.load("input_model_2.onnx")
assert model
split_op = model.get_place_by_operation_name(operationName="split2")
split_op = model.get_place_by_operation_name(operation_name="split2")
split_op_consuming_ops = split_op.get_consuming_operations()
abs_op = model.get_place_by_operation_name(operationName="abs1")
sin_op = model.get_place_by_tensor_name(tensorName="out2").get_producing_operation()
abs_op = model.get_place_by_operation_name(operation_name="abs1")
sin_op = model.get_place_by_tensor_name(tensor_name="out2").get_producing_operation()
assert len(split_op_consuming_ops) == 2
assert split_op_consuming_ops[0].is_equal(abs_op)
assert split_op_consuming_ops[1].is_equal(sin_op)
split_op_port = split_op.get_input_port(inputPortIndex=0)
split_op_port = split_op.get_input_port(input_port_index=0)
split_op_port_consuming_ops = split_op_port.get_consuming_operations()
assert len(split_op_port_consuming_ops) == 1
assert split_op_port_consuming_ops[0].is_equal(split_op)
add_out_port = model.get_place_by_tensor_name(tensorName="add_out").get_producing_port()
add_out_port = model.get_place_by_tensor_name(tensor_name="add_out").get_producing_port()
add_out_port_consuming_ops = add_out_port.get_consuming_operations()
assert len(add_out_port_consuming_ops) == 1
assert add_out_port_consuming_ops[0].is_equal(split_op)
sp_out2_tensor = model.get_place_by_tensor_name(tensorName="sp_out2")
sp_out2_tensor = model.get_place_by_tensor_name(tensor_name="sp_out2")
sp_out2_tensor_consuming_ops = sp_out2_tensor.get_consuming_operations()
assert len(sp_out2_tensor_consuming_ops) == 1
assert sp_out2_tensor_consuming_ops[0].is_equal(sin_op)
out2_tensor = model.get_place_by_tensor_name(tensorName="out2")
out2_tensor = model.get_place_by_tensor_name(tensor_name="out2")
out2_tensor_consuming_ops = out2_tensor.get_consuming_operations()
assert len(out2_tensor_consuming_ops) == 0
out2_port_consuming_ops = out2_tensor.get_producing_port().get_consuming_operations()
assert len(out2_port_consuming_ops) == 0
split_out_1_consuming_ops = split_op.get_consuming_operations(outputPortIndex=1)
split_out_1_consuming_ops = split_op.get_consuming_operations(output_port_index=1)
assert len(split_out_1_consuming_ops) == 1
split_out_sp_out_2_consuming_ops = split_op.get_consuming_operations(outputName="sp_out2")
split_out_sp_out_2_consuming_ops = split_op.get_consuming_operations(output_name="sp_out2")
assert len(split_out_sp_out_2_consuming_ops) == 1
assert split_out_1_consuming_ops[0].is_equal(split_out_sp_out_2_consuming_ops[0])
assert split_out_1_consuming_ops[0].is_equal(sin_op)
@ -977,18 +978,18 @@ def test_get_target_tensor():
model = fe.load("input_model_2.onnx")
assert model
split_op = model.get_place_by_operation_name(operationName="split2")
split_op = model.get_place_by_operation_name(operation_name="split2")
assert not split_op.get_target_tensor()
split_op_tensor_1 = split_op.get_target_tensor(outputPortIndex=1)
sp_out2_tensor = model.get_place_by_tensor_name(tensorName="sp_out2")
split_op_tensor_1 = split_op.get_target_tensor(output_port_index=1)
sp_out2_tensor = model.get_place_by_tensor_name(tensor_name="sp_out2")
assert split_op_tensor_1.is_equal(sp_out2_tensor)
split_tensor_sp_out2 = split_op.get_target_tensor(outputName="sp_out2")
split_tensor_sp_out2 = split_op.get_target_tensor(output_name="sp_out2")
assert split_tensor_sp_out2.is_equal(split_op_tensor_1)
abs_op = model.get_place_by_operation_name(operationName="abs1")
out1_tensor = model.get_place_by_tensor_name(tensorName="out1")
abs_op = model.get_place_by_operation_name(operation_name="abs1")
out1_tensor = model.get_place_by_tensor_name(tensor_name="out1")
assert abs_op.get_target_tensor().is_equal(out1_tensor)
@ -999,18 +1000,18 @@ def test_get_source_tensor():
model = fe.load("input_model_2.onnx")
assert model
add_out_tensor = model.get_place_by_tensor_name(tensorName="add_out")
add_out_tensor = model.get_place_by_tensor_name(tensor_name="add_out")
add_op = add_out_tensor.get_producing_operation()
assert not add_op.get_source_tensor()
add_op_in_tensor_1 = add_op.get_source_tensor(inputPortIndex=1)
in2_tensor = model.get_place_by_tensor_name(tensorName="in2")
add_op_in_tensor_1 = add_op.get_source_tensor(input_port_index=1)
in2_tensor = model.get_place_by_tensor_name(tensor_name="in2")
assert add_op_in_tensor_1.is_equal(in2_tensor)
add_op_in_tensor_in2 = add_op.get_source_tensor(inputName="in2")
add_op_in_tensor_in2 = add_op.get_source_tensor(input_name="in2")
assert add_op_in_tensor_in2.is_equal(in2_tensor)
split_op = model.get_place_by_operation_name(operationName="split2")
split_op = model.get_place_by_operation_name(operation_name="split2")
assert split_op.get_source_tensor().is_equal(add_out_tensor)
@ -1021,11 +1022,11 @@ def test_get_producing_port():
model = fe.load("input_model_2.onnx")
assert model
split_op = model.get_place_by_operation_name(operationName="split2")
split_op = model.get_place_by_operation_name(operation_name="split2")
split_op_in_port = split_op.get_input_port()
split_op_in_port_prod_port = split_op_in_port.get_producing_port()
add_out_tensor = model.get_place_by_tensor_name(tensorName="add_out")
add_out_tensor = model.get_place_by_tensor_name(tensor_name="add_out")
add_op = add_out_tensor.get_producing_operation()
add_op_out_port = add_op.get_output_port()
@ -1040,9 +1041,9 @@ def test_get_place_by_operation_name_and_input_port():
model = fe.load("input_model.onnx")
assert model
place1 = model.get_place_by_operation_name_and_input_port(operationName="split1", inputPortIndex=0)
place1 = model.get_place_by_operation_name_and_input_port(operation_name="split1", input_port_index=0)
sp_out1_tensor = model.get_place_by_tensor_name("out2")
place2 = sp_out1_tensor.get_producing_operation().get_input_port(inputPortIndex=0)
place2 = sp_out1_tensor.get_producing_operation().get_input_port(input_port_index=0)
assert place1.is_equal(place2)
@ -1055,9 +1056,9 @@ def test_get_place_by_operation_name_and_output_port():
model = fe.load("input_model_2.onnx")
assert model
place1 = model.get_place_by_operation_name_and_output_port(operationName="split2", outputPortIndex=0)
place1 = model.get_place_by_operation_name_and_output_port(operation_name="split2", output_port_index=0)
sp_out1_tensor = model.get_place_by_tensor_name("sp_out1")
place2 = sp_out1_tensor.get_producing_operation().get_output_port(outputPortIndex=0)
place2 = sp_out1_tensor.get_producing_operation().get_output_port(output_port_index=0)
assert place1.is_equal(place2)
@ -1066,10 +1067,10 @@ def test_not_supported_methods():
skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
model = fe.load("test_place_names.onnx")
tensor = model.get_place_by_tensor_name(tensorName="add_out")
tensor = model.get_place_by_tensor_name(tensor_name="add_out")
with pytest.raises(Exception) as e:
model.add_name_for_tensor(tensor=tensor, newName="new_name")
model.add_name_for_tensor(tensor=tensor, new_name="new_name")
assert "not applicable for ONNX model" in str(e)
with pytest.raises(Exception) as e:
@ -1084,34 +1085,34 @@ def test_set_name_for_tensor():
old_name = "add_out"
new_name = "add_out_new"
tensor = model.get_place_by_tensor_name(tensorName=old_name)
tensor = model.get_place_by_tensor_name(tensor_name=old_name)
# ignore rename to own name (expect no exception)
model.set_name_for_tensor(tensor=tensor, newName=old_name)
model.set_name_for_tensor(tensor=tensor, new_name=old_name)
with pytest.raises(Exception) as e:
model.set_name_for_tensor(tensor=tensor, newName="")
model.set_name_for_tensor(tensor=tensor, new_name="")
assert "name must not be empty" in str(e)
# ONNX model stores tensor info separately for inputs, outputs and between nodes tensors
with pytest.raises(Exception) as e:
model.set_name_for_tensor(tensor=tensor, newName="in1")
model.set_name_for_tensor(tensor=tensor, new_name="in1")
assert "already used by another tensor" in str(e)
with pytest.raises(Exception) as e:
model.set_name_for_tensor(tensor=tensor, newName="out1")
model.set_name_for_tensor(tensor=tensor, new_name="out1")
assert "already used by another tensor" in str(e)
with pytest.raises(Exception) as e:
model.set_name_for_tensor(tensor=tensor, newName="sub_out")
model.set_name_for_tensor(tensor=tensor, new_name="sub_out")
assert "already used by another tensor" in str(e)
# actual rename
model.set_name_for_tensor(tensor=tensor, newName=new_name)
model.set_name_for_tensor(tensor=tensor, new_name=new_name)
new_tensor = model.get_place_by_tensor_name(tensorName=new_name)
new_tensor = model.get_place_by_tensor_name(tensor_name=new_name)
assert new_tensor
assert new_tensor.is_equal(tensor) # previous Place object holds the handle
old_tensor = model.get_place_by_tensor_name(tensorName=old_name)
old_tensor = model.get_place_by_tensor_name(tensor_name=old_name)
assert old_tensor is None
@ -1122,21 +1123,21 @@ def test_set_name_for_operation_with_name():
old_name = "split1"
new_name = "split1_new"
operation = model.get_place_by_operation_name(operationName=old_name)
operation = model.get_place_by_operation_name(operation_name=old_name)
# ignore rename to own name (expect no exception)
model.set_name_for_operation(operation=operation, newName=old_name)
model.set_name_for_operation(operation=operation, new_name=old_name)
# actual rename
model.set_name_for_operation(operation=operation, newName=new_name)
model.set_name_for_operation(operation=operation, new_name=new_name)
new_operation = model.get_place_by_operation_name(operationName=new_name)
new_operation = model.get_place_by_operation_name(operation_name=new_name)
assert new_operation
assert new_operation.is_equal(operation) # previous Place object holds the handle
# Below test passes for models with unique operation names, what is not required by ONNX standard
# If there were more that one nodes with "split1" name, this test would fail.
old_operation = model.get_place_by_operation_name(operationName=old_name)
old_operation = model.get_place_by_operation_name(operation_name=old_name)
assert old_operation is None
@ -1147,14 +1148,14 @@ def test_set_name_for_operation_without_name():
output_name = "add_out"
new_name = "Add_new"
operation = model.get_place_by_tensor_name(tensorName=output_name).get_producing_operation()
operation = model.get_place_by_tensor_name(tensor_name=output_name).get_producing_operation()
# assure the test is performed on node with empty name
assert not operation.get_names() or len(operation.get_names()) == 0 or not operation.get_names()[0]
# actual rename
model.set_name_for_operation(operation=operation, newName=new_name)
model.set_name_for_operation(operation=operation, new_name=new_name)
new_operation = model.get_place_by_tensor_name(tensorName=output_name).get_producing_operation()
new_operation = model.get_place_by_tensor_name(tensor_name=output_name).get_producing_operation()
assert new_operation
assert new_operation.is_equal(operation) # previous Place object holds the handle
@ -1168,13 +1169,13 @@ def test_free_name_for_operation():
# assure non existent names are ignored (expect no exception)
model.free_name_for_operation("non existent name")
split1 = model.get_place_by_operation_name(operationName=name)
split1 = model.get_place_by_operation_name(operation_name=name)
assert split1
model.free_name_for_operation(name)
operation = model.get_place_by_operation_name(operationName=name)
operation = model.get_place_by_operation_name(operation_name=name)
assert not operation
new_split1 = model.get_place_by_tensor_name(tensorName="out1").get_producing_operation()
new_split1 = model.get_place_by_tensor_name(tensor_name="out1").get_producing_operation()
assert split1.is_equal(new_split1)
@ -1184,16 +1185,16 @@ def test_set_name_for_dimension():
model = fe.load("test_place_names.onnx")
dim_name = "batch_size"
input1 = model.get_place_by_tensor_name(tensorName="in1")
input1 = model.get_place_by_tensor_name(tensor_name="in1")
model.set_name_for_dimension(input1, 0, dim_name)
assert model.get_partial_shape(input1) == PartialShape([-1, 2])
output1 = model.get_place_by_tensor_name(tensorName="out1")
output1 = model.get_place_by_tensor_name(tensor_name="out1")
model.set_name_for_dimension(output1, 1, dim_name)
assert model.get_partial_shape(output1) == PartialShape([1, -1])
# sub_output rank is 2 so setting dim_name at index 3 extends its rank to 4
sub_output = model.get_place_by_tensor_name(tensorName="sub_out")
sub_output = model.get_place_by_tensor_name(tensor_name="sub_out")
model.set_name_for_dimension(sub_output, 3, dim_name)
assert model.get_partial_shape(sub_output) == PartialShape([2, 2, -1, -1])
@ -1201,7 +1202,7 @@ def test_set_name_for_dimension():
model.set_name_for_dimension(input1, 0, "")
assert "name must not be empty" in str(e)
one_const = model.get_place_by_tensor_name(tensorName="one_const")
one_const = model.get_place_by_tensor_name(tensor_name="one_const")
with pytest.raises(Exception) as e:
model.set_name_for_dimension(one_const, 0, dim_name)
assert "ONNX initializer shape dimension cannot be dynamic." in str(e)

View File

@ -131,7 +131,7 @@ def test_partial_shape():
assert list(ps.get_max_shape()) == []
assert repr(ps) == "<PartialShape: ...>"
ps = PartialShape.dynamic(r=Dimension(2))
ps = PartialShape.dynamic(rank=Dimension(2))
assert not ps.is_static
assert ps.is_dynamic
assert ps.rank == 2