[PYTHON API] reshape helper (#10402)

* Add reshape helper

* add dimension(range)

* Add partial_shape helper

* Fix code style

* fix comments

* Split reshape on several overloads

* Fix code style

* correct exception

* remove range support

* fix code style

* Add exception

* Dimension from str, PartialShape from str, reshape(str) support

* Apply review comments

* Add default init for shape

* Add PS syntax examples

* Remove pshape parsing from benchmark_app

* Update src/bindings/python/src/pyopenvino/graph/model.cpp

Co-authored-by: Sergey Lyalin <sergey.lyalin@intel.com>

* Update src/bindings/python/src/pyopenvino/graph/model.cpp

Co-authored-by: Sergey Lyalin <sergey.lyalin@intel.com>

* Apply suggestions from code review

Co-authored-by: Sergey Lyalin <sergey.lyalin@intel.com>

Co-authored-by: Sergey Lyalin <sergey.lyalin@intel.com>
This commit is contained in:
Alexey Lebedev 2022-02-22 14:48:55 +03:00 committed by GitHub
parent 991c9db1c1
commit a3004e7d80
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 354 additions and 74 deletions

View File

@ -6,6 +6,8 @@
#include <unordered_map> #include <unordered_map>
#include "openvino/util/common_util.hpp"
#define C_CONTIGUOUS py::detail::npy_api::constants::NPY_ARRAY_C_CONTIGUOUS_ #define C_CONTIGUOUS py::detail::npy_api::constants::NPY_ARRAY_C_CONTIGUOUS_
namespace Common { namespace Common {
@ -88,6 +90,108 @@ ov::Tensor tensor_from_numpy(py::array& array, bool shared_memory) {
return tensor; return tensor;
} }
ov::PartialShape partial_shape_from_list(const py::list& shape) {
using value_type = ov::Dimension::value_type;
ov::PartialShape pshape;
for (py::handle dim : shape) {
if (py::isinstance<py::int_>(dim)) {
pshape.insert(pshape.end(), ov::Dimension(dim.cast<value_type>()));
} else if (py::isinstance<py::str>(dim)) {
pshape.insert(pshape.end(), Common::dimension_from_str(dim.cast<std::string>()));
} else if (py::isinstance<ov::Dimension>(dim)) {
pshape.insert(pshape.end(), dim.cast<ov::Dimension>());
} else if (py::isinstance<py::list>(dim) || py::isinstance<py::tuple>(dim)) {
py::list bounded_dim = dim.cast<py::list>();
if (bounded_dim.size() != 2) {
throw py::type_error("Two elements are expected in tuple(lower, upper) for dynamic dimension, but " +
std::to_string(bounded_dim.size()) + " elements were given.");
}
if (!(py::isinstance<py::int_>(bounded_dim[0]) && py::isinstance<py::int_>(bounded_dim[1]))) {
throw py::type_error("Incorrect pair of types (" + std::string(bounded_dim[0].get_type().str()) + ", " +
std::string(bounded_dim[1].get_type().str()) +
") for dynamic dimension, ints are expected.");
}
pshape.insert(pshape.end(),
ov::Dimension(bounded_dim[0].cast<value_type>(), bounded_dim[1].cast<value_type>()));
} else {
throw py::type_error("Incorrect type " + std::string(dim.get_type().str()) +
" for dimension. Expected types are: "
"int, str, openvino.runtime.Dimension, list/tuple with lower and upper values for "
"dynamic dimension.");
}
}
return pshape;
}
bool check_all_digits(const std::string& value) {
auto val = ov::util::trim(value);
for (const auto& c : val) {
if (!std::isdigit(c) || c == '-') {
return false;
}
}
return true;
}
template <class T>
T stringToType(const std::string& valStr) {
T ret{0};
std::istringstream ss(valStr);
if (!ss.eof()) {
ss >> ret;
}
return ret;
}
ov::Dimension dimension_from_str(const std::string& value) {
using value_type = ov::Dimension::value_type;
auto val = ov::util::trim(value);
if (val == "?" || val == "-1") {
return {-1};
}
if (val.find("..") == std::string::npos) {
OPENVINO_ASSERT(Common::check_all_digits(val), "Cannot parse dimension: \"", val, "\"");
return {Common::stringToType<value_type>(val)};
}
std::string min_value_str = val.substr(0, val.find(".."));
OPENVINO_ASSERT(Common::check_all_digits(min_value_str), "Cannot parse min bound: \"", min_value_str, "\"");
value_type min_value;
if (min_value_str.empty()) {
min_value = 0;
} else {
min_value = Common::stringToType<value_type>(min_value_str);
}
std::string max_value_str = val.substr(val.find("..") + 2);
value_type max_value;
if (max_value_str.empty()) {
max_value = -1;
} else {
max_value = Common::stringToType<value_type>(max_value_str);
}
OPENVINO_ASSERT(Common::check_all_digits(max_value_str), "Cannot parse max bound: \"", max_value_str, "\"");
return {min_value, max_value};
}
ov::PartialShape partial_shape_from_str(const std::string& value) {
auto val = ov::util::trim(value);
if (val == "...") {
return ov::PartialShape::dynamic();
}
ov::PartialShape res;
std::stringstream ss(val);
std::string field;
while (getline(ss, field, ',')) {
OPENVINO_ASSERT(!field.empty(), "Cannot get vector of dimensions! \"", val, "\" is incorrect");
res.insert(res.end(), Common::dimension_from_str(field));
}
return res;
}
py::array as_contiguous(py::array& array, ov::element::Type type) { py::array as_contiguous(py::array& array, ov::element::Type type) {
switch (type) { switch (type) {
// floating // floating

View File

@ -33,6 +33,12 @@ ov::Tensor tensor_from_pointer(py::array& array, const ov::Shape& shape);
ov::Tensor tensor_from_numpy(py::array& array, bool shared_memory); ov::Tensor tensor_from_numpy(py::array& array, bool shared_memory);
ov::PartialShape partial_shape_from_list(const py::list& shape);
ov::PartialShape partial_shape_from_str(const std::string& value);
ov::Dimension dimension_from_str(const std::string& value);
py::array as_contiguous(py::array& array, ov::element::Type type); py::array as_contiguous(py::array& array, ov::element::Type type);
const ov::Tensor& cast_to_tensor(const py::handle& tensor); const ov::Tensor& cast_to_tensor(const py::handle& tensor);

View File

@ -11,6 +11,7 @@
#include <sstream> #include <sstream>
#include <string> #include <string>
#include "pyopenvino/core/common.hpp"
#include "pyopenvino/graph/dimension.hpp" #include "pyopenvino/graph/dimension.hpp"
namespace py = pybind11; namespace py = pybind11;
@ -41,6 +42,10 @@ void regclass_graph_Dimension(py::module m) {
:type max_dimension: int :type max_dimension: int
)"); )");
dim.def(py::init([](const std::string& value) {
return Common::dimension_from_str(value);
}));
dim.def_static("dynamic", &ov::Dimension::dynamic); dim.def_static("dynamic", &ov::Dimension::dynamic);
dim.def_property_readonly("is_dynamic", dim.def_property_readonly("is_dynamic",

View File

@ -264,48 +264,113 @@ void regclass_graph_Model(py::module m) {
[](ov::Model& self, const ov::PartialShape& partial_shape) { [](ov::Model& self, const ov::PartialShape& partial_shape) {
self.reshape(partial_shape); self.reshape(partial_shape);
}, },
py::arg("partial_shapes"), py::arg("partial_shape"),
R"( R"(
:param partial_shapes: Index of Output. :param partial_shape: New shape.
:type partial_shapes: PartialShape :type partial_shape: PartialShape
:return : void :return : void
)"); )");
function.def( function.def(
"reshape", "reshape",
[](ov::Model& self, const std::map<size_t, ov::PartialShape>& partial_shapes) { [](ov::Model& self, const py::list& partial_shape) {
self.reshape(partial_shapes); self.reshape(Common::partial_shape_from_list(partial_shape));
}, },
py::arg("partial_shapes"), py::arg("partial_shape"),
R"( R"(
:param partial_shape: New shape.
:param partial_shapes: Index of Output. :type partial_shape: list
:type partial_shapes: Dict[int, PartialShape] :return : void
:return: void
)"); )");
function.def( function.def(
"reshape", "reshape",
[](ov::Model& self, const std::map<std::string, ov::PartialShape>& partial_shapes) { [](ov::Model& self, const py::tuple& partial_shape) {
self.reshape(partial_shapes); self.reshape(Common::partial_shape_from_list(partial_shape.cast<py::list>()));
}, },
py::arg("partial_shapes"), py::arg("partial_shape"),
R"( R"(
:param partial_shapes: Index of Output. :param partial_shape: New shape.
:type partial_shapes: Dict[string, PartialShape] :type partial_shape: tuple
:return: void :return : void
)"); )");
function.def( function.def(
"reshape", "reshape",
[](ov::Model& self, const std::map<ov::Output<ov::Node>, ov::PartialShape>& partial_shapes) { [](ov::Model& self, const std::string& partial_shape) {
self.reshape(partial_shapes); self.reshape(Common::partial_shape_from_str(partial_shape));
},
py::arg("partial_shape"),
R"(
:param partial_shape: New shape.
:type partial_shape: str
:return : void
)");
function.def(
"reshape",
[](ov::Model& self, const py::dict& partial_shapes) {
std::map<ov::Output<ov::Node>, ov::PartialShape> new_shapes;
for (const auto& item : partial_shapes) {
std::pair<ov::Output<ov::Node>, ov::PartialShape> new_shape;
// check keys
if (py::isinstance<py::int_>(item.first)) {
new_shape.first = self.input(item.first.cast<size_t>());
} else if (py::isinstance<py::str>(item.first)) {
new_shape.first = self.input(item.first.cast<std::string>());
} else if (py::isinstance<ov::Output<ov::Node>>(item.first)) {
new_shape.first = item.first.cast<ov::Output<ov::Node>>();
} else {
throw py::type_error("Incorrect key type " + std::string(item.first.get_type().str()) +
" to reshape a model, expected keys as openvino.runtime.Output, int or str.");
}
// check values
if (py::isinstance<ov::PartialShape>(item.second)) {
new_shape.second = item.second.cast<ov::PartialShape>();
} else if (py::isinstance<py::list>(item.second) || py::isinstance<py::tuple>(item.second)) {
new_shape.second = Common::partial_shape_from_list(item.second.cast<py::list>());
} else if (py::isinstance<py::str>(item.second)) {
new_shape.second = Common::partial_shape_from_str(item.second.cast<std::string>());
} else {
throw py::type_error(
"Incorrect value type " + std::string(item.second.get_type().str()) +
" to reshape a model, expected values as openvino.runtime.PartialShape, str, list or tuple.");
}
new_shapes.insert(new_shape);
}
self.reshape(new_shapes);
}, },
py::arg("partial_shapes"), py::arg("partial_shapes"),
R"( R"( Reshape model inputs.
:param partial_shapes: Index of Output.
:type partial_shapes: Dict[Output, PartialShape] The allowed types of keys in the `partial_shapes` dictionary are:
:return: void
(1) `int`, input index
(2) `str`, input tensor name
(3) `openvino.runtime.Output`
The allowed types of values in the `partial_shapes` are:
(1) `openvino.runtime.PartialShape`
(2) `list` consisting of dimensions
(3) `tuple` consisting of dimensions
(4) `str`, string representation of `openvino.runtime.PartialShape`
When list or tuple are used to describe dimensions, each dimension can be written in form:
(1) non-negative `int` which means static value for the dimension
(2) `[min, max]`, dynamic dimension where `min` specifies lower bound and `max` specifies upper bound; the range includes both `min` and `max`; using `-1` for `min` or `max` means no known bound
(3) `(min, max)`, the same as above
(4) `-1` is a dynamic dimension without known bounds
(4) `openvino.runtime.Dimension`
(5) `str` using next syntax:
'?' - to define fully dinamic dimension
'1' - to define dimension which length is 1
'1..10' - to define bounded dimension
'..10' or '1..' to define dimension with only lower or only upper limit
:param partial_shapes: New shapes.
:type partial_shapes: Dict[keys, values]
)"); )");
function.def("get_output_size", function.def("get_output_size",

View File

@ -13,6 +13,7 @@
#include "openvino/core/dimension.hpp" // ov::Dimension #include "openvino/core/dimension.hpp" // ov::Dimension
#include "openvino/core/shape.hpp" // ov::Shape #include "openvino/core/shape.hpp" // ov::Shape
#include "pyopenvino/core/common.hpp"
#include "pyopenvino/graph/partial_shape.hpp" #include "pyopenvino/graph/partial_shape.hpp"
namespace py = pybind11; namespace py = pybind11;
@ -23,15 +24,17 @@ void regclass_graph_PartialShape(py::module m) {
py::class_<ov::PartialShape, std::shared_ptr<ov::PartialShape>> shape(m, "PartialShape"); py::class_<ov::PartialShape, std::shared_ptr<ov::PartialShape>> shape(m, "PartialShape");
shape.doc() = "openvino.runtime.PartialShape wraps ov::PartialShape"; shape.doc() = "openvino.runtime.PartialShape wraps ov::PartialShape";
shape.def(py::init([](const std::vector<int64_t>& dimensions) {
return ov::PartialShape(std::vector<ov::Dimension>(dimensions.begin(), dimensions.end()));
}));
shape.def(py::init<const std::initializer_list<size_t>&>());
shape.def(py::init<const std::vector<size_t>&>());
shape.def(py::init<const std::initializer_list<ov::Dimension>&>());
shape.def(py::init<const std::vector<ov::Dimension>&>());
shape.def(py::init<const ov::Shape&>()); shape.def(py::init<const ov::Shape&>());
shape.def(py::init<const ov::PartialShape&>()); shape.def(py::init<const ov::PartialShape&>());
shape.def(py::init([](py::list& shape) {
return Common::partial_shape_from_list(shape);
}));
shape.def(py::init([](py::tuple& shape) {
return Common::partial_shape_from_list(shape.cast<py::list>());
}));
shape.def(py::init([](const std::string& shape) {
return Common::partial_shape_from_str(shape);
}));
shape.def_static("dynamic", &ov::PartialShape::dynamic, py::arg("rank") = ov::Dimension()); shape.def_static("dynamic", &ov::PartialShape::dynamic, py::arg("rank") = ov::Dimension());

View File

@ -19,6 +19,7 @@ namespace py = pybind11;
void regclass_graph_Shape(py::module m) { void regclass_graph_Shape(py::module m) {
py::class_<ov::Shape, std::shared_ptr<ov::Shape>> shape(m, "Shape"); py::class_<ov::Shape, std::shared_ptr<ov::Shape>> shape(m, "Shape");
shape.doc() = "openvino.runtime.Shape wraps ov::Shape"; shape.doc() = "openvino.runtime.Shape wraps ov::Shape";
shape.def(py::init<>());
shape.def(py::init<const std::initializer_list<size_t>&>(), py::arg("axis_lengths")); shape.def(py::init<const std::initializer_list<size_t>&>(), py::arg("axis_lengths"));
shape.def(py::init<const std::vector<size_t>&>(), py::arg("axis_lengths")); shape.def(py::init<const std::vector<size_t>&>(), py::arg("axis_lengths"));
shape.def(py::init<const ov::Shape&>(), py::arg("axis_lengths")); shape.def(py::init<const ov::Shape&>(), py::arg("axis_lengths"));

View File

@ -361,3 +361,76 @@ def test_reshape(device):
core = Core() core = Core()
compiled = core.compile_model(model, device) compiled = core.compile_model(model, device)
assert compiled.input().partial_shape == ref_shape assert compiled.input().partial_shape == ref_shape
def test_reshape_with_python_types(device):
model = create_test_model()
def check_shape(new_shape):
for input in model.inputs:
assert input.partial_shape == new_shape
shape1 = [1, 4]
new_shapes = {input: shape1 for input in model.inputs}
model.reshape(new_shapes)
check_shape(PartialShape(shape1))
shape2 = [1, 6]
new_shapes = {input.any_name: shape2 for input in model.inputs}
model.reshape(new_shapes)
check_shape(PartialShape(shape2))
shape3 = [1, 8]
new_shapes = {i: shape3 for i, input in enumerate(model.inputs)}
model.reshape(new_shapes)
check_shape(PartialShape(shape3))
shape4 = [1, -1]
new_shapes = {input: shape4 for input in model.inputs}
model.reshape(new_shapes)
check_shape(PartialShape([Dimension(1), Dimension(-1)]))
shape5 = [1, (1, 10)]
new_shapes = {input: shape5 for input in model.inputs}
model.reshape(new_shapes)
check_shape(PartialShape([Dimension(1), Dimension(1, 10)]))
shape6 = [Dimension(3), Dimension(3, 10)]
new_shapes = {input: shape6 for input in model.inputs}
model.reshape(new_shapes)
check_shape(PartialShape(shape6))
shape7 = "1..10, ?"
new_shapes = {input: shape7 for input in model.inputs}
model.reshape(new_shapes)
check_shape(PartialShape(shape7))
# reshape mixed keys
shape8 = [(1, 20), -1]
new_shapes = {"data1": shape8, 1: shape8}
model.reshape(new_shapes)
check_shape(PartialShape([Dimension(1, 20), Dimension(-1)]))
# reshape with one input
param = ops.parameter([1, 3, 28, 28])
model = Model(ops.relu(param), [param])
shape9 = [-1, 3, (28, 56), (28, 56)]
model.reshape(shape9)
check_shape(PartialShape([Dimension(-1), Dimension(3), Dimension(28, 56), Dimension(28, 56)]))
shape10 = "?,3,..224,..224"
model.reshape(shape10)
check_shape(PartialShape([Dimension(-1), Dimension(3), Dimension(-1, 224), Dimension(-1, 224)]))
# check exceptions
shape10 = [1, 1, 1, 1]
with pytest.raises(TypeError) as e:
model.reshape({model.input().node: shape10})
assert "Incorrect key type <class 'openvino.pyopenvino.op.Parameter'> to reshape a model, " \
"expected keys as openvino.runtime.Output, int or str." in str(e.value)
with pytest.raises(TypeError) as e:
model.reshape({0: range(1, 9)})
assert "Incorrect value type <class 'range'> to reshape a model, " \
"expected values as openvino.runtime.PartialShape, str, list or tuple." in str(e.value)

View File

@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import numpy as np import numpy as np
import pytest
import openvino.runtime.opset8 as ov import openvino.runtime.opset8 as ov
from openvino.runtime import Dimension, Model, PartialShape, Shape from openvino.runtime import Dimension, Model, PartialShape, Shape
@ -76,6 +77,33 @@ def test_dimension_comparisons():
assert not d2.compatible(d1) assert not d2.compatible(d1)
assert not d2.same_scheme(d1) assert not d2.same_scheme(d1)
d = Dimension("?")
assert d == Dimension()
d = Dimension("1")
assert d == Dimension(1)
d = Dimension("..10")
assert d == Dimension(-1, 10)
d = Dimension("10..")
assert d == Dimension(10, -1)
d = Dimension("5..10")
assert d == Dimension(5, 10)
with pytest.raises(RuntimeError) as e:
d = Dimension("C")
assert 'Cannot parse dimension: "C"' in str(e.value)
with pytest.raises(RuntimeError) as e:
d = Dimension("?..5")
assert 'Cannot parse min bound: "?"' in str(e.value)
with pytest.raises(RuntimeError) as e:
d = Dimension("5..?")
assert 'Cannot parse max bound: "?"' in str(e.value)
def test_partial_shape(): def test_partial_shape():
ps = PartialShape([1, 2, 3, 4]) ps = PartialShape([1, 2, 3, 4])
@ -140,6 +168,40 @@ def test_partial_shape():
assert list(ps.get_max_shape())[0] > 1000000000 assert list(ps.get_max_shape())[0] > 1000000000
assert repr(ps) == "<PartialShape: {?,?}>" assert repr(ps) == "<PartialShape: {?,?}>"
shape_list = [(1, 10), [2, 5], 4, Dimension(2), "..10"]
ref_ps = PartialShape([Dimension(1, 10), Dimension(2, 5), Dimension(4), Dimension(2), Dimension(-1, 10)])
assert PartialShape(shape_list) == ref_ps
assert PartialShape(tuple(shape_list)) == ref_ps
with pytest.raises(TypeError) as e:
PartialShape([(1, 2, 3)])
assert "Two elements are expected in tuple(lower, upper) " \
"for dynamic dimension, but 3 elements were given." in str(e.value)
with pytest.raises(TypeError) as e:
PartialShape([("?", "?")])
assert "Incorrect pair of types (<class 'str'>, <class 'str'>) " \
"for dynamic dimension, ints are expected." in str(e.value)
with pytest.raises(TypeError) as e:
PartialShape([range(10)])
assert "Incorrect type <class 'range'> for dimension. Expected types are: " \
"int, str, openvino.runtime.Dimension, list/tuple with lower " \
"and upper values for dynamic dimension." in str(e.value)
ps = PartialShape("...")
assert ps == PartialShape.dynamic()
ps = PartialShape("?, 3, ..224, 28..224")
assert ps == PartialShape([Dimension(-1), Dimension(3), Dimension(-1, 224), Dimension(28, 224)])
with pytest.raises(RuntimeError) as e:
ps = PartialShape("?,,3")
assert 'Cannot get vector of dimensions! "?,,3" is incorrect' in str(e.value)
shape = Shape()
assert len(shape) == 0
def test_partial_shape_compatible(): def test_partial_shape_compatible():
ps1 = PartialShape.dynamic() ps1 = PartialShape.dynamic()

View File

@ -375,9 +375,9 @@ def get_data_shapes_map(data_shape_string, input_names):
input_name = match[:match.find('[')] input_name = match[:match.find('[')]
shapes = re.findall(r'\[(.*?)\]', match[len(input_name):]) shapes = re.findall(r'\[(.*?)\]', match[len(input_name):])
if input_name: if input_name:
return_value[input_name] = list(parse_partial_shape(shape_str) for shape_str in shapes) return_value[input_name] = list(PartialShape(shape_str) for shape_str in shapes)
else: else:
data_shapes = list(parse_partial_shape(shape_str) for shape_str in shapes) data_shapes = list(PartialShape(shape_str) for shape_str in shapes)
num_inputs, num_shapes = len(input_names), len(data_shapes) num_inputs, num_shapes = len(input_names), len(data_shapes)
if num_shapes != 1 and num_shapes % num_inputs != 0: if num_shapes != 1 and num_shapes % num_inputs != 0:
raise Exception(f"Number of provided data_shapes is not a multiple of the number of model inputs!") raise Exception(f"Number of provided data_shapes is not a multiple of the number of model inputs!")
@ -505,52 +505,13 @@ class AppInputInfo:
return self.partial_shape.is_dynamic return self.partial_shape.is_dynamic
def parse_partial_shape(shape_str):
dims = []
for dim in shape_str.split(','):
if '.. ' in dim:
range = list(int(d) for d in dim.split('..'))
assert len(range) == 2
dims.append(Dimension(range))
elif dim == '?':
dims.append(Dimension())
else:
dims.append(Dimension(int(dim)))
return PartialShape(dims)
def parse_batch_size(batch_size_str):
if batch_size_str:
error_message = f"Can't parse batch size '{batch_size_str}'"
dims = batch_size_str.split("..")
if len(dims) > 2:
raise Exception(error_message)
elif len(dims) == 2:
range = []
for d in dims:
if d.isnumeric():
range.append(int(d))
else:
raise Exception(error_message)
return Dimension(*range)
else:
if dims[0].lstrip("-").isnumeric():
return Dimension(int(dims[0]))
elif dims[0] == "?":
return Dimension()
else:
raise Exception(error_message)
else:
return Dimension(0)
def get_inputs_info(shape_string, data_shape_string, layout_string, batch_size, scale_string, mean_string, inputs): def get_inputs_info(shape_string, data_shape_string, layout_string, batch_size, scale_string, mean_string, inputs):
input_names = get_input_output_names(inputs) input_names = get_input_output_names(inputs)
input_node_names = get_node_names(inputs) input_node_names = get_node_names(inputs)
shape_map = parse_input_parameters(shape_string, input_names) shape_map = parse_input_parameters(shape_string, input_names)
data_shape_map = get_data_shapes_map(data_shape_string, input_names) data_shape_map = get_data_shapes_map(data_shape_string, input_names)
layout_map = parse_input_parameters(layout_string, input_names) layout_map = parse_input_parameters(layout_string, input_names)
batch_size = parse_batch_size(batch_size) batch_size = Dimension(batch_size)
reshape = False reshape = False
batch_found = False batch_found = False
input_info = [] input_info = []
@ -565,10 +526,10 @@ def get_inputs_info(shape_string, data_shape_string, layout_string, batch_size,
# Shape # Shape
info.original_shape = inputs[i].partial_shape info.original_shape = inputs[i].partial_shape
if info.name in shape_map: if info.name in shape_map:
info.partial_shape = parse_partial_shape(shape_map[info.name]) info.partial_shape = PartialShape(shape_map[info.name])
reshape = True reshape = True
elif info.node_name in shape_map: elif info.node_name in shape_map:
info.partial_shape = parse_partial_shape(shape_map[info.node_name]) info.partial_shape = PartialShape(shape_map[info.node_name])
reshape = True reshape = True
else: else:
info.partial_shape = inputs[i].partial_shape info.partial_shape = inputs[i].partial_shape