Remove limit for Loop inputs in bindings (#21055)

This commit is contained in:
Pawel Raasz
2023-11-14 11:18:40 +01:00
committed by GitHub
parent b43b9f9d50
commit c974c9a478
2 changed files with 54 additions and 9 deletions

View File

@@ -24,15 +24,7 @@ void regclass_graph_op_Loop(py::module m) {
cls.def(
py::init([](const std::shared_ptr<ov::Node>& trip_count, const std::shared_ptr<ov::Node>& execution_condition) {
if (MultiSubgraphHelpers::is_constant_or_parameter(trip_count) &&
MultiSubgraphHelpers::is_constant_or_parameter(execution_condition)) {
return std::make_shared<ov::op::v5::Loop>(trip_count->output(0), execution_condition->output(0));
} else {
OPENVINO_WARN
<< "Please specify execution_condition and trip_count as Constant or Parameter. Default Loop() "
"constructor was applied.";
return std::make_shared<ov::op::v5::Loop>();
}
return std::make_shared<ov::op::v5::Loop>(trip_count, execution_condition);
}),
py::arg("trip_count"),
py::arg("execution_condition"));

View File

@@ -66,6 +66,59 @@ def test_simple_loop():
assert list(loop.get_output_shape(2)) == out2_shape
def test_loop_inputs_are_nodes():
param_x = ov.parameter(Shape([32, 1, 10]), np.float32, "X")
param_y = ov.parameter(Shape([32, 1, 10]), np.float32, "Y")
param_m = ov.parameter(Shape([32, 1, 10]), np.float32, "M")
input_shape = Shape([])
current_iteration = ov.parameter(Shape([1]), np.int64)
x_i = ov.parameter(input_shape, np.float32)
y_i = ov.parameter(input_shape, np.float32)
m_body = ov.parameter(input_shape, np.float32)
bool_val = np.array([1], dtype=bool)
bool_val[0] = True
body_condition = ov.constant(bool_val)
trip_shape = ov.parameter([10], np.int64, "trip_shapeof")
trip_count = ov.shape_of(trip_shape)
exp_shape_size = ov.constant(10, np.int64)
exec_condition = ov.equal(exp_shape_size, trip_count)
add = ov.add(x_i, y_i)
zo = ov.multiply(add, m_body)
body = Model([body_condition, zo], [current_iteration, x_i, y_i, m_body], "body_function")
loop = ov.loop(trip_count, exec_condition)
loop.set_function(body)
loop.set_invariant_input(x_i, param_x.output(0))
loop.set_invariant_input(y_i, param_y.output(0))
loop.set_merged_input(m_body, param_m.output(0), zo.output(0))
loop.set_special_body_ports([-1, 0])
loop.validate()
out0 = loop.get_iter_value(body_condition.output(0), -1)
out1 = loop.get_iter_value(zo.output(0), -1)
out2 = loop.get_concatenated_slices(zo.output(0), 0, 1, 1, -1, 1)
result0 = ov.result(out0)
result1 = ov.result(out1)
result2 = ov.result(out2)
out0_shape = [1]
out1_shape = [32, 1, 10]
out2_shape = [32, 10, 10]
assert list(result0.get_output_shape(0)) == out0_shape
assert list(result1.get_output_shape(0)) == out1_shape
assert list(result2.get_output_shape(0)) == out2_shape
assert list(loop.get_output_shape(0)) == out0_shape
assert list(loop.get_output_shape(1)) == out1_shape
assert list(loop.get_output_shape(2)) == out2_shape
def test_loop_basic():
bool_val = np.array([1], dtype=bool)
bool_val[0] = True