Remove limit for Loop inputs in bindings (#21055)
This commit is contained in:
@@ -24,15 +24,7 @@ void regclass_graph_op_Loop(py::module m) {
|
||||
|
||||
cls.def(
|
||||
py::init([](const std::shared_ptr<ov::Node>& trip_count, const std::shared_ptr<ov::Node>& execution_condition) {
|
||||
if (MultiSubgraphHelpers::is_constant_or_parameter(trip_count) &&
|
||||
MultiSubgraphHelpers::is_constant_or_parameter(execution_condition)) {
|
||||
return std::make_shared<ov::op::v5::Loop>(trip_count->output(0), execution_condition->output(0));
|
||||
} else {
|
||||
OPENVINO_WARN
|
||||
<< "Please specify execution_condition and trip_count as Constant or Parameter. Default Loop() "
|
||||
"constructor was applied.";
|
||||
return std::make_shared<ov::op::v5::Loop>();
|
||||
}
|
||||
return std::make_shared<ov::op::v5::Loop>(trip_count, execution_condition);
|
||||
}),
|
||||
py::arg("trip_count"),
|
||||
py::arg("execution_condition"));
|
||||
|
||||
@@ -66,6 +66,59 @@ def test_simple_loop():
|
||||
assert list(loop.get_output_shape(2)) == out2_shape
|
||||
|
||||
|
||||
def test_loop_inputs_are_nodes():
|
||||
param_x = ov.parameter(Shape([32, 1, 10]), np.float32, "X")
|
||||
param_y = ov.parameter(Shape([32, 1, 10]), np.float32, "Y")
|
||||
param_m = ov.parameter(Shape([32, 1, 10]), np.float32, "M")
|
||||
|
||||
input_shape = Shape([])
|
||||
|
||||
current_iteration = ov.parameter(Shape([1]), np.int64)
|
||||
x_i = ov.parameter(input_shape, np.float32)
|
||||
y_i = ov.parameter(input_shape, np.float32)
|
||||
m_body = ov.parameter(input_shape, np.float32)
|
||||
bool_val = np.array([1], dtype=bool)
|
||||
bool_val[0] = True
|
||||
body_condition = ov.constant(bool_val)
|
||||
trip_shape = ov.parameter([10], np.int64, "trip_shapeof")
|
||||
trip_count = ov.shape_of(trip_shape)
|
||||
exp_shape_size = ov.constant(10, np.int64)
|
||||
exec_condition = ov.equal(exp_shape_size, trip_count)
|
||||
|
||||
add = ov.add(x_i, y_i)
|
||||
zo = ov.multiply(add, m_body)
|
||||
|
||||
body = Model([body_condition, zo], [current_iteration, x_i, y_i, m_body], "body_function")
|
||||
|
||||
loop = ov.loop(trip_count, exec_condition)
|
||||
loop.set_function(body)
|
||||
loop.set_invariant_input(x_i, param_x.output(0))
|
||||
loop.set_invariant_input(y_i, param_y.output(0))
|
||||
loop.set_merged_input(m_body, param_m.output(0), zo.output(0))
|
||||
loop.set_special_body_ports([-1, 0])
|
||||
loop.validate()
|
||||
|
||||
out0 = loop.get_iter_value(body_condition.output(0), -1)
|
||||
out1 = loop.get_iter_value(zo.output(0), -1)
|
||||
out2 = loop.get_concatenated_slices(zo.output(0), 0, 1, 1, -1, 1)
|
||||
|
||||
result0 = ov.result(out0)
|
||||
result1 = ov.result(out1)
|
||||
result2 = ov.result(out2)
|
||||
|
||||
out0_shape = [1]
|
||||
out1_shape = [32, 1, 10]
|
||||
out2_shape = [32, 10, 10]
|
||||
|
||||
assert list(result0.get_output_shape(0)) == out0_shape
|
||||
assert list(result1.get_output_shape(0)) == out1_shape
|
||||
assert list(result2.get_output_shape(0)) == out2_shape
|
||||
|
||||
assert list(loop.get_output_shape(0)) == out0_shape
|
||||
assert list(loop.get_output_shape(1)) == out1_shape
|
||||
assert list(loop.get_output_shape(2)) == out2_shape
|
||||
|
||||
|
||||
def test_loop_basic():
|
||||
bool_val = np.array([1], dtype=bool)
|
||||
bool_val[0] = True
|
||||
|
||||
Reference in New Issue
Block a user