From c974c9a4780b08cf2c53a478fc39d5910258ff34 Mon Sep 17 00:00:00 2001 From: Pawel Raasz Date: Tue, 14 Nov 2023 11:18:40 +0100 Subject: [PATCH] Remove limit for Loop inputs in bindings (#21055) --- .../python/src/pyopenvino/graph/ops/loop.cpp | 10 +--- .../python/tests/test_graph/test_loop.py | 53 +++++++++++++++++++ 2 files changed, 54 insertions(+), 9 deletions(-) diff --git a/src/bindings/python/src/pyopenvino/graph/ops/loop.cpp b/src/bindings/python/src/pyopenvino/graph/ops/loop.cpp index ebf317d3c46..8cf9e61e7cd 100644 --- a/src/bindings/python/src/pyopenvino/graph/ops/loop.cpp +++ b/src/bindings/python/src/pyopenvino/graph/ops/loop.cpp @@ -24,15 +24,7 @@ void regclass_graph_op_Loop(py::module m) { cls.def( py::init([](const std::shared_ptr& trip_count, const std::shared_ptr& execution_condition) { - if (MultiSubgraphHelpers::is_constant_or_parameter(trip_count) && - MultiSubgraphHelpers::is_constant_or_parameter(execution_condition)) { - return std::make_shared(trip_count->output(0), execution_condition->output(0)); - } else { - OPENVINO_WARN - << "Please specify execution_condition and trip_count as Constant or Parameter. Default Loop() " - "constructor was applied."; - return std::make_shared(); - } + return std::make_shared(trip_count, execution_condition); }), py::arg("trip_count"), py::arg("execution_condition")); diff --git a/src/bindings/python/tests/test_graph/test_loop.py b/src/bindings/python/tests/test_graph/test_loop.py index 235ea917ba5..e931ba12f8f 100644 --- a/src/bindings/python/tests/test_graph/test_loop.py +++ b/src/bindings/python/tests/test_graph/test_loop.py @@ -66,6 +66,59 @@ def test_simple_loop(): assert list(loop.get_output_shape(2)) == out2_shape +def test_loop_inputs_are_nodes(): + param_x = ov.parameter(Shape([32, 1, 10]), np.float32, "X") + param_y = ov.parameter(Shape([32, 1, 10]), np.float32, "Y") + param_m = ov.parameter(Shape([32, 1, 10]), np.float32, "M") + + input_shape = Shape([]) + + current_iteration = ov.parameter(Shape([1]), np.int64) + x_i = ov.parameter(input_shape, np.float32) + y_i = ov.parameter(input_shape, np.float32) + m_body = ov.parameter(input_shape, np.float32) + bool_val = np.array([1], dtype=bool) + bool_val[0] = True + body_condition = ov.constant(bool_val) + trip_shape = ov.parameter([10], np.int64, "trip_shapeof") + trip_count = ov.shape_of(trip_shape) + exp_shape_size = ov.constant(10, np.int64) + exec_condition = ov.equal(exp_shape_size, trip_count) + + add = ov.add(x_i, y_i) + zo = ov.multiply(add, m_body) + + body = Model([body_condition, zo], [current_iteration, x_i, y_i, m_body], "body_function") + + loop = ov.loop(trip_count, exec_condition) + loop.set_function(body) + loop.set_invariant_input(x_i, param_x.output(0)) + loop.set_invariant_input(y_i, param_y.output(0)) + loop.set_merged_input(m_body, param_m.output(0), zo.output(0)) + loop.set_special_body_ports([-1, 0]) + loop.validate() + + out0 = loop.get_iter_value(body_condition.output(0), -1) + out1 = loop.get_iter_value(zo.output(0), -1) + out2 = loop.get_concatenated_slices(zo.output(0), 0, 1, 1, -1, 1) + + result0 = ov.result(out0) + result1 = ov.result(out1) + result2 = ov.result(out2) + + out0_shape = [1] + out1_shape = [32, 1, 10] + out2_shape = [32, 10, 10] + + assert list(result0.get_output_shape(0)) == out0_shape + assert list(result1.get_output_shape(0)) == out1_shape + assert list(result2.get_output_shape(0)) == out2_shape + + assert list(loop.get_output_shape(0)) == out0_shape + assert list(loop.get_output_shape(1)) == out1_shape + assert list(loop.get_output_shape(2)) == out2_shape + + def test_loop_basic(): bool_val = np.array([1], dtype=bool) bool_val[0] = True