[PYTHON API] fix direct access to model input shape (#9562)
* Copy port shape to avoid direct access to shape buffer * Add __eq__ for shape * Add tests * Fix getters * add __setitem__ * Add a note about copy
This commit is contained in:
parent
6840d945c7
commit
4a6575b4b7
@ -78,23 +78,25 @@ void regclass_graph_Output(py::module m, std::string typestring)
|
|||||||
)");
|
)");
|
||||||
output.def("get_shape",
|
output.def("get_shape",
|
||||||
&ov::Output<VT>::get_shape,
|
&ov::Output<VT>::get_shape,
|
||||||
|
py::return_value_policy::copy,
|
||||||
R"(
|
R"(
|
||||||
The shape of the output referred to by this output handle.
|
The shape of the output referred to by this output handle.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
----------
|
----------
|
||||||
get_shape : Shape
|
get_shape : Shape
|
||||||
Shape of the output.
|
Copy of Shape of the output.
|
||||||
)");
|
)");
|
||||||
output.def("get_partial_shape",
|
output.def("get_partial_shape",
|
||||||
&ov::Output<VT>::get_partial_shape,
|
&ov::Output<VT>::get_partial_shape,
|
||||||
|
py::return_value_policy::copy,
|
||||||
R"(
|
R"(
|
||||||
The partial shape of the output referred to by this output handle.
|
The partial shape of the output referred to by this output handle.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
----------
|
----------
|
||||||
get_partial_shape : PartialShape
|
get_partial_shape : PartialShape
|
||||||
PartialShape of the output.
|
Copy of PartialShape of the output.
|
||||||
)");
|
)");
|
||||||
output.def("get_target_inputs",
|
output.def("get_target_inputs",
|
||||||
&ov::Output<VT>::get_target_inputs,
|
&ov::Output<VT>::get_target_inputs,
|
||||||
@ -137,8 +139,8 @@ void regclass_graph_Output(py::module m, std::string typestring)
|
|||||||
output.def_property_readonly("any_name", &ov::Output<VT>::get_any_name);
|
output.def_property_readonly("any_name", &ov::Output<VT>::get_any_name);
|
||||||
output.def_property_readonly("names", &ov::Output<VT>::get_names);
|
output.def_property_readonly("names", &ov::Output<VT>::get_names);
|
||||||
output.def_property_readonly("element_type", &ov::Output<VT>::get_element_type);
|
output.def_property_readonly("element_type", &ov::Output<VT>::get_element_type);
|
||||||
output.def_property_readonly("shape", &ov::Output<VT>::get_shape);
|
output.def_property_readonly("shape", &ov::Output<VT>::get_shape, py::return_value_policy::copy);
|
||||||
output.def_property_readonly("partial_shape", &ov::Output<VT>::get_partial_shape);
|
output.def_property_readonly("partial_shape", &ov::Output<VT>::get_partial_shape, py::return_value_policy::copy);
|
||||||
output.def_property_readonly("target_inputs", &ov::Output<VT>::get_target_inputs);
|
output.def_property_readonly("target_inputs", &ov::Output<VT>::get_target_inputs);
|
||||||
output.def_property_readonly("tensor", &ov::Output<VT>::get_tensor);
|
output.def_property_readonly("tensor", &ov::Output<VT>::get_tensor);
|
||||||
output.def_property_readonly("rt_info",
|
output.def_property_readonly("rt_info",
|
||||||
|
@ -196,6 +196,10 @@ void regclass_graph_PartialShape(py::module m) {
|
|||||||
return self.size();
|
return self.size();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
shape.def("__setitem__", [](ov::PartialShape& self, size_t key, ov::Dimension::value_type d) {
|
||||||
|
self[key] = d;
|
||||||
|
});
|
||||||
|
|
||||||
shape.def("__setitem__", [](ov::PartialShape& self, size_t key, ov::Dimension& d) {
|
shape.def("__setitem__", [](ov::PartialShape& self, size_t key, ov::Dimension& d) {
|
||||||
self[key] = d;
|
self[key] = d;
|
||||||
});
|
});
|
||||||
|
@ -22,10 +22,16 @@ void regclass_graph_Shape(py::module m) {
|
|||||||
shape.def(py::init<const std::initializer_list<size_t>&>(), py::arg("axis_lengths"));
|
shape.def(py::init<const std::initializer_list<size_t>&>(), py::arg("axis_lengths"));
|
||||||
shape.def(py::init<const std::vector<size_t>&>(), py::arg("axis_lengths"));
|
shape.def(py::init<const std::vector<size_t>&>(), py::arg("axis_lengths"));
|
||||||
shape.def(py::init<const ov::Shape&>(), py::arg("axis_lengths"));
|
shape.def(py::init<const ov::Shape&>(), py::arg("axis_lengths"));
|
||||||
|
shape.def(
|
||||||
|
"__eq__",
|
||||||
|
[](const ov::Shape& a, const ov::Shape& b) {
|
||||||
|
return a == b;
|
||||||
|
},
|
||||||
|
py::is_operator());
|
||||||
shape.def("__len__", [](const ov::Shape& v) {
|
shape.def("__len__", [](const ov::Shape& v) {
|
||||||
return v.size();
|
return v.size();
|
||||||
});
|
});
|
||||||
shape.def("__setitem__", [](ov::Shape& self, size_t key, size_t d) {
|
shape.def("__setitem__", [](ov::Shape& self, size_t key, ov::Dimension::value_type d) {
|
||||||
self[key] = d;
|
self[key] = d;
|
||||||
});
|
});
|
||||||
shape.def("__setitem__", [](ov::Shape& self, size_t key, ov::Dimension d) {
|
shape.def("__setitem__", [](ov::Shape& self, size_t key, ov::Dimension d) {
|
||||||
|
@ -5,7 +5,7 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import openvino.runtime.opset8 as ops
|
import openvino.runtime.opset8 as ops
|
||||||
from openvino.runtime import Model, Tensor, Output, Dimension,\
|
from openvino.runtime import Core, Model, Tensor, Output, Dimension,\
|
||||||
Layout, Type, PartialShape, Shape, set_batch, get_batch
|
Layout, Type, PartialShape, Shape, set_batch, get_batch
|
||||||
|
|
||||||
|
|
||||||
@ -349,3 +349,15 @@ def test_reshape_with_names():
|
|||||||
for input in model.inputs:
|
for input in model.inputs:
|
||||||
model.reshape({input.any_name: new_shape})
|
model.reshape({input.any_name: new_shape})
|
||||||
assert input.partial_shape == new_shape
|
assert input.partial_shape == new_shape
|
||||||
|
|
||||||
|
|
||||||
|
def test_reshape(device):
|
||||||
|
shape = Shape([1, 10])
|
||||||
|
param = ops.parameter(shape, dtype=np.float32)
|
||||||
|
model = Model(ops.relu(param), [param])
|
||||||
|
ref_shape = model.input().partial_shape
|
||||||
|
ref_shape[0] = 3
|
||||||
|
model.reshape(ref_shape)
|
||||||
|
core = Core()
|
||||||
|
compiled = core.compile_model(model, device)
|
||||||
|
assert compiled.input().partial_shape == ref_shape
|
||||||
|
@ -220,6 +220,16 @@ def test_partial_shape_equals():
|
|||||||
shape = Shape([1, 2, 3])
|
shape = Shape([1, 2, 3])
|
||||||
ps = PartialShape([1, 2, 3])
|
ps = PartialShape([1, 2, 3])
|
||||||
assert shape == ps
|
assert shape == ps
|
||||||
|
assert shape == ps.to_shape()
|
||||||
|
|
||||||
|
|
||||||
|
def test_input_shape_read_only():
|
||||||
|
shape = Shape([1, 10])
|
||||||
|
param = ov.parameter(shape, dtype=np.float32)
|
||||||
|
model = Model(ov.relu(param), [param])
|
||||||
|
ref_shape = model.input().shape
|
||||||
|
ref_shape[0] = Dimension(3)
|
||||||
|
assert model.input().shape == shape
|
||||||
|
|
||||||
|
|
||||||
def test_repr_dynamic_shape():
|
def test_repr_dynamic_shape():
|
||||||
|
Loading…
Reference in New Issue
Block a user