Fix the __repr__ function for dynamic shapes (#1210)
This commit is contained in:
@@ -51,9 +51,16 @@ void regclass_pyngraph_Function(py::module m)
|
||||
function.def("is_dynamic", &ngraph::Function::is_dynamic);
|
||||
function.def("__repr__", [](const ngraph::Function& self) {
|
||||
std::string class_name = py::cast(self).get_type().attr("__name__").cast<std::string>();
|
||||
std::string shape =
|
||||
py::cast(self.get_output_shape(0)).attr("__str__")().cast<std::string>();
|
||||
return "<" + class_name + ": '" + self.get_friendly_name() + "' (" + shape + ")>";
|
||||
std::stringstream shapes_ss;
|
||||
for (size_t i = 0; i < self.get_output_size(); ++i)
|
||||
{
|
||||
if (i > 0)
|
||||
{
|
||||
shapes_ss << ", ";
|
||||
}
|
||||
shapes_ss << self.get_output_partial_shape(i);
|
||||
}
|
||||
return "<" + class_name + ": '" + self.get_friendly_name() + "' (" + shapes_ss.str() + ")>";
|
||||
});
|
||||
function.def_static("from_capsule", [](py::object* capsule) {
|
||||
// get the underlying PyObject* which is a PyCapsule pointer
|
||||
|
||||
@@ -65,7 +65,7 @@ void regclass_pyngraph_Node(py::module m)
|
||||
{
|
||||
shapes_ss << ", ";
|
||||
}
|
||||
shapes_ss << py::cast(self.get_output_shape(i)).attr("__str__")().cast<std::string>();
|
||||
shapes_ss << self.get_output_partial_shape(i);
|
||||
}
|
||||
return "<" + type_name + ": '" + self.get_friendly_name() + "' (" + shapes_ss.str() + ")>";
|
||||
});
|
||||
|
||||
@@ -31,7 +31,8 @@ void regclass_pyngraph_op_Parameter(py::module m)
|
||||
parameter.doc() = "ngraph.impl.op.Parameter wraps ngraph::op::Parameter";
|
||||
parameter.def("__repr__", [](const ngraph::Node& self) {
|
||||
std::string class_name = py::cast(self).get_type().attr("__name__").cast<std::string>();
|
||||
std::string shape = py::cast(self.get_shape()).attr("__str__")().cast<std::string>();
|
||||
std::string shape =
|
||||
py::cast(self.get_output_partial_shape(0)).attr("__str__")().cast<std::string>();
|
||||
std::string type = self.get_element_type().c_type_string();
|
||||
return "<" + class_name + ": '" + self.get_friendly_name() + "' (" + shape + ", " + type +
|
||||
")>";
|
||||
|
||||
@@ -0,0 +1,88 @@
|
||||
ir_version: 3
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
model_version: 1
|
||||
graph {
|
||||
node {
|
||||
name: "multiplication"
|
||||
input: "A"
|
||||
input: "B"
|
||||
output: "mul_out"
|
||||
op_type: "Mul"
|
||||
}
|
||||
node {
|
||||
name: "addition"
|
||||
input: "mul_out"
|
||||
input: "C"
|
||||
output: "add_out"
|
||||
op_type: "Add"
|
||||
}
|
||||
input {
|
||||
name: "A"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 7
|
||||
shape {
|
||||
dim {
|
||||
dim_param: "batch"
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "B"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 7
|
||||
shape {
|
||||
dim {
|
||||
dim_param: "batch"
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "C"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 7
|
||||
shape {
|
||||
dim {
|
||||
dim_param: "batch"
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "add_out"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_param: "batch"
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
name: "simple_dyn_shapes_graph"
|
||||
}
|
||||
opset_import {
|
||||
domain: ""
|
||||
version: 7
|
||||
}
|
||||
@@ -15,6 +15,7 @@
|
||||
# ******************************************************************************
|
||||
|
||||
from ngraph.impl import Dimension, PartialShape, Shape
|
||||
import os
|
||||
|
||||
|
||||
def test_dimension():
|
||||
@@ -222,3 +223,16 @@ def test_partial_shape_equals():
|
||||
shape = Shape([1, 2, 3])
|
||||
ps = PartialShape([1, 2, 3])
|
||||
assert shape == ps
|
||||
|
||||
|
||||
def test_repr_dynamic_shape():
|
||||
from ngraph.impl.onnx_import import import_onnx_model_file
|
||||
model_path = os.path.join(os.path.dirname(__file__), "models/ab_plus_c_dynamic.prototxt")
|
||||
function = import_onnx_model_file(model_path)
|
||||
|
||||
assert repr(function) == "<Function: 'simple_dyn_shapes_graph' ({?,2})>"
|
||||
|
||||
ops = function.get_ordered_ops()
|
||||
|
||||
for op in ops:
|
||||
assert "{?,2}" in repr(op)
|
||||
|
||||
Reference in New Issue
Block a user