[PYTHON API] fix direct access to model input shape (#9562)
* Copy port shape to avoid direct access to shape buffer * Add __eq__ for shape * Add tests * Fix getters * add __setitem__ * Add a note about copy
This commit is contained in:
parent
6840d945c7
commit
4a6575b4b7
@ -78,23 +78,25 @@ void regclass_graph_Output(py::module m, std::string typestring)
|
||||
)");
|
||||
output.def("get_shape",
|
||||
&ov::Output<VT>::get_shape,
|
||||
py::return_value_policy::copy,
|
||||
R"(
|
||||
The shape of the output referred to by this output handle.
|
||||
|
||||
Returns
|
||||
----------
|
||||
get_shape : Shape
|
||||
Shape of the output.
|
||||
Copy of Shape of the output.
|
||||
)");
|
||||
output.def("get_partial_shape",
|
||||
&ov::Output<VT>::get_partial_shape,
|
||||
py::return_value_policy::copy,
|
||||
R"(
|
||||
The partial shape of the output referred to by this output handle.
|
||||
|
||||
Returns
|
||||
----------
|
||||
get_partial_shape : PartialShape
|
||||
PartialShape of the output.
|
||||
Copy of PartialShape of the output.
|
||||
)");
|
||||
output.def("get_target_inputs",
|
||||
&ov::Output<VT>::get_target_inputs,
|
||||
@ -137,8 +139,8 @@ void regclass_graph_Output(py::module m, std::string typestring)
|
||||
output.def_property_readonly("any_name", &ov::Output<VT>::get_any_name);
|
||||
output.def_property_readonly("names", &ov::Output<VT>::get_names);
|
||||
output.def_property_readonly("element_type", &ov::Output<VT>::get_element_type);
|
||||
output.def_property_readonly("shape", &ov::Output<VT>::get_shape);
|
||||
output.def_property_readonly("partial_shape", &ov::Output<VT>::get_partial_shape);
|
||||
output.def_property_readonly("shape", &ov::Output<VT>::get_shape, py::return_value_policy::copy);
|
||||
output.def_property_readonly("partial_shape", &ov::Output<VT>::get_partial_shape, py::return_value_policy::copy);
|
||||
output.def_property_readonly("target_inputs", &ov::Output<VT>::get_target_inputs);
|
||||
output.def_property_readonly("tensor", &ov::Output<VT>::get_tensor);
|
||||
output.def_property_readonly("rt_info",
|
||||
|
@ -196,6 +196,10 @@ void regclass_graph_PartialShape(py::module m) {
|
||||
return self.size();
|
||||
});
|
||||
|
||||
shape.def("__setitem__", [](ov::PartialShape& self, size_t key, ov::Dimension::value_type d) {
|
||||
self[key] = d;
|
||||
});
|
||||
|
||||
shape.def("__setitem__", [](ov::PartialShape& self, size_t key, ov::Dimension& d) {
|
||||
self[key] = d;
|
||||
});
|
||||
|
@ -22,10 +22,16 @@ void regclass_graph_Shape(py::module m) {
|
||||
shape.def(py::init<const std::initializer_list<size_t>&>(), py::arg("axis_lengths"));
|
||||
shape.def(py::init<const std::vector<size_t>&>(), py::arg("axis_lengths"));
|
||||
shape.def(py::init<const ov::Shape&>(), py::arg("axis_lengths"));
|
||||
shape.def(
|
||||
"__eq__",
|
||||
[](const ov::Shape& a, const ov::Shape& b) {
|
||||
return a == b;
|
||||
},
|
||||
py::is_operator());
|
||||
shape.def("__len__", [](const ov::Shape& v) {
|
||||
return v.size();
|
||||
});
|
||||
shape.def("__setitem__", [](ov::Shape& self, size_t key, size_t d) {
|
||||
shape.def("__setitem__", [](ov::Shape& self, size_t key, ov::Dimension::value_type d) {
|
||||
self[key] = d;
|
||||
});
|
||||
shape.def("__setitem__", [](ov::Shape& self, size_t key, ov::Dimension d) {
|
||||
|
@ -5,7 +5,7 @@ import numpy as np
|
||||
import pytest
|
||||
|
||||
import openvino.runtime.opset8 as ops
|
||||
from openvino.runtime import Model, Tensor, Output, Dimension,\
|
||||
from openvino.runtime import Core, Model, Tensor, Output, Dimension,\
|
||||
Layout, Type, PartialShape, Shape, set_batch, get_batch
|
||||
|
||||
|
||||
@ -349,3 +349,15 @@ def test_reshape_with_names():
|
||||
for input in model.inputs:
|
||||
model.reshape({input.any_name: new_shape})
|
||||
assert input.partial_shape == new_shape
|
||||
|
||||
|
||||
def test_reshape(device):
|
||||
shape = Shape([1, 10])
|
||||
param = ops.parameter(shape, dtype=np.float32)
|
||||
model = Model(ops.relu(param), [param])
|
||||
ref_shape = model.input().partial_shape
|
||||
ref_shape[0] = 3
|
||||
model.reshape(ref_shape)
|
||||
core = Core()
|
||||
compiled = core.compile_model(model, device)
|
||||
assert compiled.input().partial_shape == ref_shape
|
||||
|
@ -220,6 +220,16 @@ def test_partial_shape_equals():
|
||||
shape = Shape([1, 2, 3])
|
||||
ps = PartialShape([1, 2, 3])
|
||||
assert shape == ps
|
||||
assert shape == ps.to_shape()
|
||||
|
||||
|
||||
def test_input_shape_read_only():
|
||||
shape = Shape([1, 10])
|
||||
param = ov.parameter(shape, dtype=np.float32)
|
||||
model = Model(ov.relu(param), [param])
|
||||
ref_shape = model.input().shape
|
||||
ref_shape[0] = Dimension(3)
|
||||
assert model.input().shape == shape
|
||||
|
||||
|
||||
def test_repr_dynamic_shape():
|
||||
|
Loading…
Reference in New Issue
Block a user