[PYTHON API] fix direct access to model input shape (#9562)

* Copy port shape to avoid direct access to shape buffer

* Add __eq__ for shape

* Add tests

* Fix getters

* add __setitem__

* Add a note about copy
This commit is contained in:
Alexey Lebedev 2022-01-12 17:48:10 +03:00 committed by GitHub
parent 6840d945c7
commit 4a6575b4b7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 40 additions and 6 deletions

View File

@ -78,23 +78,25 @@ void regclass_graph_Output(py::module m, std::string typestring)
)");
output.def("get_shape",
&ov::Output<VT>::get_shape,
py::return_value_policy::copy,
R"(
The shape of the output referred to by this output handle.
Returns
----------
get_shape : Shape
Shape of the output.
Copy of Shape of the output.
)");
output.def("get_partial_shape",
&ov::Output<VT>::get_partial_shape,
py::return_value_policy::copy,
R"(
The partial shape of the output referred to by this output handle.
Returns
----------
get_partial_shape : PartialShape
PartialShape of the output.
Copy of PartialShape of the output.
)");
output.def("get_target_inputs",
&ov::Output<VT>::get_target_inputs,
@ -137,8 +139,8 @@ void regclass_graph_Output(py::module m, std::string typestring)
output.def_property_readonly("any_name", &ov::Output<VT>::get_any_name);
output.def_property_readonly("names", &ov::Output<VT>::get_names);
output.def_property_readonly("element_type", &ov::Output<VT>::get_element_type);
output.def_property_readonly("shape", &ov::Output<VT>::get_shape);
output.def_property_readonly("partial_shape", &ov::Output<VT>::get_partial_shape);
output.def_property_readonly("shape", &ov::Output<VT>::get_shape, py::return_value_policy::copy);
output.def_property_readonly("partial_shape", &ov::Output<VT>::get_partial_shape, py::return_value_policy::copy);
output.def_property_readonly("target_inputs", &ov::Output<VT>::get_target_inputs);
output.def_property_readonly("tensor", &ov::Output<VT>::get_tensor);
output.def_property_readonly("rt_info",

View File

@ -196,6 +196,10 @@ void regclass_graph_PartialShape(py::module m) {
return self.size();
});
shape.def("__setitem__", [](ov::PartialShape& self, size_t key, ov::Dimension::value_type d) {
self[key] = d;
});
shape.def("__setitem__", [](ov::PartialShape& self, size_t key, ov::Dimension& d) {
self[key] = d;
});

View File

@ -22,10 +22,16 @@ void regclass_graph_Shape(py::module m) {
shape.def(py::init<const std::initializer_list<size_t>&>(), py::arg("axis_lengths"));
shape.def(py::init<const std::vector<size_t>&>(), py::arg("axis_lengths"));
shape.def(py::init<const ov::Shape&>(), py::arg("axis_lengths"));
shape.def(
"__eq__",
[](const ov::Shape& a, const ov::Shape& b) {
return a == b;
},
py::is_operator());
shape.def("__len__", [](const ov::Shape& v) {
return v.size();
});
shape.def("__setitem__", [](ov::Shape& self, size_t key, size_t d) {
shape.def("__setitem__", [](ov::Shape& self, size_t key, ov::Dimension::value_type d) {
self[key] = d;
});
shape.def("__setitem__", [](ov::Shape& self, size_t key, ov::Dimension d) {

View File

@ -5,7 +5,7 @@ import numpy as np
import pytest
import openvino.runtime.opset8 as ops
from openvino.runtime import Model, Tensor, Output, Dimension,\
from openvino.runtime import Core, Model, Tensor, Output, Dimension,\
Layout, Type, PartialShape, Shape, set_batch, get_batch
@ -349,3 +349,15 @@ def test_reshape_with_names():
for input in model.inputs:
model.reshape({input.any_name: new_shape})
assert input.partial_shape == new_shape
def test_reshape(device):
shape = Shape([1, 10])
param = ops.parameter(shape, dtype=np.float32)
model = Model(ops.relu(param), [param])
ref_shape = model.input().partial_shape
ref_shape[0] = 3
model.reshape(ref_shape)
core = Core()
compiled = core.compile_model(model, device)
assert compiled.input().partial_shape == ref_shape

View File

@ -220,6 +220,16 @@ def test_partial_shape_equals():
shape = Shape([1, 2, 3])
ps = PartialShape([1, 2, 3])
assert shape == ps
assert shape == ps.to_shape()
def test_input_shape_read_only():
shape = Shape([1, 10])
param = ov.parameter(shape, dtype=np.float32)
model = Model(ov.relu(param), [param])
ref_shape = model.input().shape
ref_shape[0] = Dimension(3)
assert model.input().shape == shape
def test_repr_dynamic_shape():