From 6ad627d60c70faa1450d481237b04f206a77eaa5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dawid=20Ko=C5=BCykowski?= Date: Mon, 13 Dec 2021 22:46:37 +0100 Subject: [PATCH] [PYTHON] Add missing index operator bindings (#9088) --- .../src/pyopenvino/graph/axis_vector.cpp | 18 ++++++ .../src/pyopenvino/graph/coordinate.cpp | 18 ++++++ .../src/pyopenvino/graph/coordinate_diff.cpp | 19 ++++++ .../python/src/pyopenvino/graph/strides.cpp | 19 ++++++ .../python/tests/test_ngraph/test_basic.py | 61 +++++++++++++++++++ 5 files changed, 135 insertions(+) diff --git a/src/bindings/python/src/pyopenvino/graph/axis_vector.cpp b/src/bindings/python/src/pyopenvino/graph/axis_vector.cpp index 750c5a40eba..d13ddd42606 100644 --- a/src/bindings/python/src/pyopenvino/graph/axis_vector.cpp +++ b/src/bindings/python/src/pyopenvino/graph/axis_vector.cpp @@ -17,4 +17,22 @@ void regclass_graph_AxisVector(py::module m) { axis_vector.def(py::init&>(), py::arg("axes")); axis_vector.def(py::init&>(), py::arg("axes")); axis_vector.def(py::init(), py::arg("axes")); + axis_vector.def("__setitem__", [](ov::AxisVector& self, size_t key, size_t value) { + self[key] = value; + }); + + axis_vector.def("__getitem__", [](const ov::AxisVector& self, size_t key) { + return self[key]; + }); + + axis_vector.def("__len__", [](const ov::AxisVector& self) { + return self.size(); + }); + + axis_vector.def( + "__iter__", + [](const ov::AxisVector& self) { + return py::make_iterator(self.begin(), self.end()); + }, + py::keep_alive<0, 1>()); /* Keep vector alive while iterator is used */ } diff --git a/src/bindings/python/src/pyopenvino/graph/coordinate.cpp b/src/bindings/python/src/pyopenvino/graph/coordinate.cpp index fc5ac13ee0f..78264ea0d31 100644 --- a/src/bindings/python/src/pyopenvino/graph/coordinate.cpp +++ b/src/bindings/python/src/pyopenvino/graph/coordinate.cpp @@ -18,4 +18,22 @@ void regclass_graph_Coordinate(py::module m) { coordinate.def(py::init()); coordinate.def(py::init&>()); coordinate.def(py::init()); + coordinate.def("__setitem__", [](ov::Coordinate& self, size_t key, size_t value) { + self[key] = value; + }); + + coordinate.def("__getitem__", [](const ov::Coordinate& self, size_t key) { + return self[key]; + }); + + coordinate.def("__len__", [](const ov::Coordinate& self) { + return self.size(); + }); + + coordinate.def( + "__iter__", + [](const ov::Coordinate& self) { + return py::make_iterator(self.begin(), self.end()); + }, + py::keep_alive<0, 1>()); /* Keep vector alive while iterator is used */ } diff --git a/src/bindings/python/src/pyopenvino/graph/coordinate_diff.cpp b/src/bindings/python/src/pyopenvino/graph/coordinate_diff.cpp index feeb37e1573..ffccead7085 100644 --- a/src/bindings/python/src/pyopenvino/graph/coordinate_diff.cpp +++ b/src/bindings/python/src/pyopenvino/graph/coordinate_diff.cpp @@ -34,4 +34,23 @@ void regclass_graph_CoordinateDiff(py::module m) { std::string shape_str = py::cast(self).attr("__str__")().cast(); return "<" + class_name + ": (" + shape_str + ")>"; }); + + coordinate_diff.def("__setitem__", [](ov::CoordinateDiff& self, size_t key, std::ptrdiff_t& value) { + self[key] = value; + }); + + coordinate_diff.def("__getitem__", [](const ov::CoordinateDiff& self, size_t key) { + return self[key]; + }); + + coordinate_diff.def("__len__", [](const ov::CoordinateDiff& self) { + return self.size(); + }); + + coordinate_diff.def( + "__iter__", + [](const ov::CoordinateDiff& self) { + return py::make_iterator(self.begin(), self.end()); + }, + py::keep_alive<0, 1>()); /* Keep vector alive while iterator is used */ } diff --git a/src/bindings/python/src/pyopenvino/graph/strides.cpp b/src/bindings/python/src/pyopenvino/graph/strides.cpp index 7a97e49c2f1..9f5d50e2684 100644 --- a/src/bindings/python/src/pyopenvino/graph/strides.cpp +++ b/src/bindings/python/src/pyopenvino/graph/strides.cpp @@ -34,4 +34,23 @@ void regclass_graph_Strides(py::module m) { std::string shape_str = py::cast(self).attr("__str__")().cast(); return "<" + class_name + ": (" + shape_str + ")>"; }); + + strides.def("__setitem__", [](ov::Strides& self, size_t key, size_t value) { + self[key] = value; + }); + + strides.def("__getitem__", [](const ov::Strides& self, size_t key) { + return self[key]; + }); + + strides.def("__len__", [](const ov::Strides& self) { + return self.size(); + }); + + strides.def( + "__iter__", + [](const ov::Strides& self) { + return py::make_iterator(self.begin(), self.end()); + }, + py::keep_alive<0, 1>()); /* Keep vector alive while iterator is used */ } diff --git a/src/bindings/python/tests/test_ngraph/test_basic.py b/src/bindings/python/tests/test_ngraph/test_basic.py index 6a455dbb270..cd887ed4a19 100644 --- a/src/bindings/python/tests/test_ngraph/test_basic.py +++ b/src/bindings/python/tests/test_ngraph/test_basic.py @@ -13,6 +13,7 @@ from openvino.pyopenvino import OVAny from openvino.runtime.exceptions import UserInputError from openvino.runtime import Model, PartialShape, Shape, Type, layout_helpers +from openvino.runtime import Strides, AxisVector, Coordinate, CoordinateDiff from openvino.runtime import Tensor from openvino.pyopenvino import DescriptorTensor from openvino.runtime.op import Parameter @@ -563,6 +564,66 @@ def test_node_version(): assert node.version == 1 +def test_strides_iteration_methods(): + data = np.array([1, 2, 3]) + strides = Strides(data) + + assert len(strides) == data.size + assert np.equal(strides, data).all() + assert np.equal([strides[i] for i in range(data.size)], data).all() + + data2 = np.array([5, 6, 7]) + for i in range(data2.size): + strides[i] = data2[i] + + assert np.equal(strides, data2).all() + + +def test_axis_vector_iteration_methods(): + data = np.array([1, 2, 3]) + axisVector = AxisVector(data) + + assert len(axisVector) == data.size + assert np.equal(axisVector, data).all() + assert np.equal([axisVector[i] for i in range(data.size)], data).all() + + data2 = np.array([5, 6, 7]) + for i in range(data2.size): + axisVector[i] = data2[i] + + assert np.equal(axisVector, data2).all() + + +def test_coordinate_iteration_methods(): + data = np.array([1, 2, 3]) + coordinate = Coordinate(data) + + assert len(coordinate) == data.size + assert np.equal(coordinate, data).all() + assert np.equal([coordinate[i] for i in range(data.size)], data).all() + + data2 = np.array([5, 6, 7]) + for i in range(data2.size): + coordinate[i] = data2[i] + + assert np.equal(coordinate, data2).all() + + +def test_coordinate_diff_iteration_methods(): + data = np.array([1, 2, 3]) + coordinateDiff = CoordinateDiff(data) + + assert len(coordinateDiff) == data.size + assert np.equal(coordinateDiff, data).all() + assert np.equal([coordinateDiff[i] for i in range(data.size)], data).all() + + data2 = np.array([5, 6, 7]) + for i in range(data2.size): + coordinateDiff[i] = data2[i] + + assert np.equal(coordinateDiff, data2).all() + + def test_layout(): layout = ov.Layout("NCWH") layout2 = ov.Layout("NCWH")