From d60a8e89df649ada0e4b7f5dec08112ef0a862db Mon Sep 17 00:00:00 2001 From: Ivan Tikhonov Date: Tue, 20 Sep 2022 23:12:32 +0300 Subject: [PATCH] Fix shape propagation in TI in case of dyn slice axis (#12926) --- src/core/src/op/tensor_iterator.cpp | 17 ++++++------ src/core/tests/type_prop/tensor_iterator.cpp | 29 ++++++++++++++++++++ 2 files changed, 38 insertions(+), 8 deletions(-) diff --git a/src/core/src/op/tensor_iterator.cpp b/src/core/src/op/tensor_iterator.cpp index 118f5bf67f2..95f17e13284 100644 --- a/src/core/src/op/tensor_iterator.cpp +++ b/src/core/src/op/tensor_iterator.cpp @@ -88,19 +88,20 @@ void op::v0::TensorIterator::validate_and_infer_types() { auto body_parameter = body->get_parameters().at(slice_input_description->m_body_parameter_index); auto input_partial_shape = inputs().at(index).get_source_output().get_partial_shape(); auto axis = slice_input_description->m_axis; - if (input_partial_shape.rank().is_static() && input_partial_shape[axis].is_static()) { + if (input_partial_shape.rank().is_static()) { auto part_size = slice_input_description->m_part_size; - - auto dim_size = input_partial_shape[axis].get_length(); - auto start = make_positive(slice_input_description->m_start, dim_size); - auto end = make_positive(slice_input_description->m_end, dim_size); - - // +1 because the left and right borders are included [start, end] - m_num_iterations = (abs(end - start) + 1) / part_size; // infer type for m_body_parameter ov::PartialShape out_shape{input_partial_shape}; out_shape[axis] = part_size; body_parameter->set_partial_shape(out_shape); + if (input_partial_shape[axis].is_static()) { + auto dim_size = input_partial_shape[axis].get_length(); + auto start = make_positive(slice_input_description->m_start, dim_size); + auto end = make_positive(slice_input_description->m_end, dim_size); + + // +1 because the left and right borders are included [start, end] + m_num_iterations = (abs(end - start) + 1) / part_size; + } } else { body_parameter->set_partial_shape(ov::PartialShape::dynamic(input_partial_shape.rank())); } diff --git a/src/core/tests/type_prop/tensor_iterator.cpp b/src/core/tests/type_prop/tensor_iterator.cpp index da687725547..dca1e1b63cb 100644 --- a/src/core/tests/type_prop/tensor_iterator.cpp +++ b/src/core/tests/type_prop/tensor_iterator.cpp @@ -232,3 +232,32 @@ TEST(type_prop, tensor_iterator_with_dynamic_reshape) { ASSERT_EQ(tensor_iterator->get_num_iterations(), -1); } + +TEST(type_prop, tensor_iterator_dyn_slice) { + const size_t N = 32; // Batch size + const size_t I = 8; // Input size + + ov::PartialShape ps = {N, ov::Dimension::dynamic(), I}; + auto SENT = make_shared(element::f32, ps); + + // Body + auto X = make_shared(element::f32, PartialShape::dynamic()); + auto Res = make_shared(X); + auto body = make_shared(Res, ParameterVector{X}); + auto tensor_iterator = make_shared(); + tensor_iterator->set_body(body); + + // start=0, stride=1, part_size=1, end=39, axis=1 + const size_t part_size = 1; + tensor_iterator->set_sliced_input(X, SENT, 0, 1, part_size, -1, 1); + + // Output 0 is last Ho, result 0 of body + auto out0 = tensor_iterator->get_iter_value(Res, -1); + + auto results = ResultVector{make_shared(out0)}; + auto model = make_shared(results, ParameterVector{SENT}); + + EXPECT_EQ(tensor_iterator->get_num_iterations(), -1); + PartialShape ref_ps = {N, part_size, I}; + EXPECT_EQ(X->get_partial_shape(), ref_ps); +}