[PYTHON API] fix direct access to model input shape (#9562)

* Copy port shape to avoid direct access to shape buffer

* Add __eq__ for shape

* Add tests

* Fix getters

* add __setitem__

* Add a note about copy
This commit is contained in:
Alexey Lebedev 2022-01-12 17:48:10 +03:00 committed by GitHub
parent 6840d945c7
commit 4a6575b4b7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 40 additions and 6 deletions

View File

@ -78,23 +78,25 @@ void regclass_graph_Output(py::module m, std::string typestring)
)"); )");
output.def("get_shape", output.def("get_shape",
&ov::Output<VT>::get_shape, &ov::Output<VT>::get_shape,
py::return_value_policy::copy,
R"( R"(
The shape of the output referred to by this output handle. The shape of the output referred to by this output handle.
Returns Returns
---------- ----------
get_shape : Shape get_shape : Shape
Shape of the output. Copy of Shape of the output.
)"); )");
output.def("get_partial_shape", output.def("get_partial_shape",
&ov::Output<VT>::get_partial_shape, &ov::Output<VT>::get_partial_shape,
py::return_value_policy::copy,
R"( R"(
The partial shape of the output referred to by this output handle. The partial shape of the output referred to by this output handle.
Returns Returns
---------- ----------
get_partial_shape : PartialShape get_partial_shape : PartialShape
PartialShape of the output. Copy of PartialShape of the output.
)"); )");
output.def("get_target_inputs", output.def("get_target_inputs",
&ov::Output<VT>::get_target_inputs, &ov::Output<VT>::get_target_inputs,
@ -137,8 +139,8 @@ void regclass_graph_Output(py::module m, std::string typestring)
output.def_property_readonly("any_name", &ov::Output<VT>::get_any_name); output.def_property_readonly("any_name", &ov::Output<VT>::get_any_name);
output.def_property_readonly("names", &ov::Output<VT>::get_names); output.def_property_readonly("names", &ov::Output<VT>::get_names);
output.def_property_readonly("element_type", &ov::Output<VT>::get_element_type); output.def_property_readonly("element_type", &ov::Output<VT>::get_element_type);
output.def_property_readonly("shape", &ov::Output<VT>::get_shape); output.def_property_readonly("shape", &ov::Output<VT>::get_shape, py::return_value_policy::copy);
output.def_property_readonly("partial_shape", &ov::Output<VT>::get_partial_shape); output.def_property_readonly("partial_shape", &ov::Output<VT>::get_partial_shape, py::return_value_policy::copy);
output.def_property_readonly("target_inputs", &ov::Output<VT>::get_target_inputs); output.def_property_readonly("target_inputs", &ov::Output<VT>::get_target_inputs);
output.def_property_readonly("tensor", &ov::Output<VT>::get_tensor); output.def_property_readonly("tensor", &ov::Output<VT>::get_tensor);
output.def_property_readonly("rt_info", output.def_property_readonly("rt_info",

View File

@ -196,6 +196,10 @@ void regclass_graph_PartialShape(py::module m) {
return self.size(); return self.size();
}); });
shape.def("__setitem__", [](ov::PartialShape& self, size_t key, ov::Dimension::value_type d) {
self[key] = d;
});
shape.def("__setitem__", [](ov::PartialShape& self, size_t key, ov::Dimension& d) { shape.def("__setitem__", [](ov::PartialShape& self, size_t key, ov::Dimension& d) {
self[key] = d; self[key] = d;
}); });

View File

@ -22,10 +22,16 @@ void regclass_graph_Shape(py::module m) {
shape.def(py::init<const std::initializer_list<size_t>&>(), py::arg("axis_lengths")); shape.def(py::init<const std::initializer_list<size_t>&>(), py::arg("axis_lengths"));
shape.def(py::init<const std::vector<size_t>&>(), py::arg("axis_lengths")); shape.def(py::init<const std::vector<size_t>&>(), py::arg("axis_lengths"));
shape.def(py::init<const ov::Shape&>(), py::arg("axis_lengths")); shape.def(py::init<const ov::Shape&>(), py::arg("axis_lengths"));
shape.def(
"__eq__",
[](const ov::Shape& a, const ov::Shape& b) {
return a == b;
},
py::is_operator());
shape.def("__len__", [](const ov::Shape& v) { shape.def("__len__", [](const ov::Shape& v) {
return v.size(); return v.size();
}); });
shape.def("__setitem__", [](ov::Shape& self, size_t key, size_t d) { shape.def("__setitem__", [](ov::Shape& self, size_t key, ov::Dimension::value_type d) {
self[key] = d; self[key] = d;
}); });
shape.def("__setitem__", [](ov::Shape& self, size_t key, ov::Dimension d) { shape.def("__setitem__", [](ov::Shape& self, size_t key, ov::Dimension d) {

View File

@ -5,7 +5,7 @@ import numpy as np
import pytest import pytest
import openvino.runtime.opset8 as ops import openvino.runtime.opset8 as ops
from openvino.runtime import Model, Tensor, Output, Dimension,\ from openvino.runtime import Core, Model, Tensor, Output, Dimension,\
Layout, Type, PartialShape, Shape, set_batch, get_batch Layout, Type, PartialShape, Shape, set_batch, get_batch
@ -349,3 +349,15 @@ def test_reshape_with_names():
for input in model.inputs: for input in model.inputs:
model.reshape({input.any_name: new_shape}) model.reshape({input.any_name: new_shape})
assert input.partial_shape == new_shape assert input.partial_shape == new_shape
def test_reshape(device):
shape = Shape([1, 10])
param = ops.parameter(shape, dtype=np.float32)
model = Model(ops.relu(param), [param])
ref_shape = model.input().partial_shape
ref_shape[0] = 3
model.reshape(ref_shape)
core = Core()
compiled = core.compile_model(model, device)
assert compiled.input().partial_shape == ref_shape

View File

@ -220,6 +220,16 @@ def test_partial_shape_equals():
shape = Shape([1, 2, 3]) shape = Shape([1, 2, 3])
ps = PartialShape([1, 2, 3]) ps = PartialShape([1, 2, 3])
assert shape == ps assert shape == ps
assert shape == ps.to_shape()
def test_input_shape_read_only():
shape = Shape([1, 10])
param = ov.parameter(shape, dtype=np.float32)
model = Model(ov.relu(param), [param])
ref_shape = model.input().shape
ref_shape[0] = Dimension(3)
assert model.input().shape == shape
def test_repr_dynamic_shape(): def test_repr_dynamic_shape():