Fix the __repr__ function for dynamic shapes (#1210)

This commit is contained in:
Tomasz Dołbniak
2020-07-09 15:52:25 +02:00
committed by GitHub
parent e3f9cf16cd
commit 65657ea5c5
5 changed files with 115 additions and 5 deletions

View File

@@ -51,9 +51,16 @@ void regclass_pyngraph_Function(py::module m)
function.def("is_dynamic", &ngraph::Function::is_dynamic);
function.def("__repr__", [](const ngraph::Function& self) {
std::string class_name = py::cast(self).get_type().attr("__name__").cast<std::string>();
std::string shape =
py::cast(self.get_output_shape(0)).attr("__str__")().cast<std::string>();
return "<" + class_name + ": '" + self.get_friendly_name() + "' (" + shape + ")>";
std::stringstream shapes_ss;
for (size_t i = 0; i < self.get_output_size(); ++i)
{
if (i > 0)
{
shapes_ss << ", ";
}
shapes_ss << self.get_output_partial_shape(i);
}
return "<" + class_name + ": '" + self.get_friendly_name() + "' (" + shapes_ss.str() + ")>";
});
function.def_static("from_capsule", [](py::object* capsule) {
// get the underlying PyObject* which is a PyCapsule pointer

View File

@@ -65,7 +65,7 @@ void regclass_pyngraph_Node(py::module m)
{
shapes_ss << ", ";
}
shapes_ss << py::cast(self.get_output_shape(i)).attr("__str__")().cast<std::string>();
shapes_ss << self.get_output_partial_shape(i);
}
return "<" + type_name + ": '" + self.get_friendly_name() + "' (" + shapes_ss.str() + ")>";
});

View File

@@ -31,7 +31,8 @@ void regclass_pyngraph_op_Parameter(py::module m)
parameter.doc() = "ngraph.impl.op.Parameter wraps ngraph::op::Parameter";
parameter.def("__repr__", [](const ngraph::Node& self) {
std::string class_name = py::cast(self).get_type().attr("__name__").cast<std::string>();
std::string shape = py::cast(self.get_shape()).attr("__str__")().cast<std::string>();
std::string shape =
py::cast(self.get_output_partial_shape(0)).attr("__str__")().cast<std::string>();
std::string type = self.get_element_type().c_type_string();
return "<" + class_name + ": '" + self.get_friendly_name() + "' (" + shape + ", " + type +
")>";

View File

@@ -0,0 +1,88 @@
ir_version: 3
producer_name: "nGraph ONNX Importer"
model_version: 1
graph {
node {
name: "multiplication"
input: "A"
input: "B"
output: "mul_out"
op_type: "Mul"
}
node {
name: "addition"
input: "mul_out"
input: "C"
output: "add_out"
op_type: "Add"
}
input {
name: "A"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_param: "batch"
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "B"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_param: "batch"
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "C"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_param: "batch"
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "add_out"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_param: "batch"
}
dim {
dim_value: 2
}
}
}
}
}
name: "simple_dyn_shapes_graph"
}
opset_import {
domain: ""
version: 7
}

View File

@@ -15,6 +15,7 @@
# ******************************************************************************
from ngraph.impl import Dimension, PartialShape, Shape
import os
def test_dimension():
@@ -222,3 +223,16 @@ def test_partial_shape_equals():
shape = Shape([1, 2, 3])
ps = PartialShape([1, 2, 3])
assert shape == ps
def test_repr_dynamic_shape():
from ngraph.impl.onnx_import import import_onnx_model_file
model_path = os.path.join(os.path.dirname(__file__), "models/ab_plus_c_dynamic.prototxt")
function = import_onnx_model_file(model_path)
assert repr(function) == "<Function: 'simple_dyn_shapes_graph' ({?,2})>"
ops = function.get_ordered_ops()
for op in ops:
assert "{?,2}" in repr(op)