Fix shape propagation in TI in case of dyn slice axis (#12926)

This commit is contained in:
Ivan Tikhonov
2022-09-20 23:12:32 +03:00
committed by GitHub
parent e109bd1a32
commit d60a8e89df
2 changed files with 38 additions and 8 deletions

View File

@@ -88,19 +88,20 @@ void op::v0::TensorIterator::validate_and_infer_types() {
auto body_parameter = body->get_parameters().at(slice_input_description->m_body_parameter_index);
auto input_partial_shape = inputs().at(index).get_source_output().get_partial_shape();
auto axis = slice_input_description->m_axis;
if (input_partial_shape.rank().is_static() && input_partial_shape[axis].is_static()) {
if (input_partial_shape.rank().is_static()) {
auto part_size = slice_input_description->m_part_size;
auto dim_size = input_partial_shape[axis].get_length();
auto start = make_positive(slice_input_description->m_start, dim_size);
auto end = make_positive(slice_input_description->m_end, dim_size);
// +1 because the left and right borders are included [start, end]
m_num_iterations = (abs(end - start) + 1) / part_size;
// infer type for m_body_parameter
ov::PartialShape out_shape{input_partial_shape};
out_shape[axis] = part_size;
body_parameter->set_partial_shape(out_shape);
if (input_partial_shape[axis].is_static()) {
auto dim_size = input_partial_shape[axis].get_length();
auto start = make_positive(slice_input_description->m_start, dim_size);
auto end = make_positive(slice_input_description->m_end, dim_size);
// +1 because the left and right borders are included [start, end]
m_num_iterations = (abs(end - start) + 1) / part_size;
}
} else {
body_parameter->set_partial_shape(ov::PartialShape::dynamic(input_partial_shape.rank()));
}

View File

@@ -232,3 +232,32 @@ TEST(type_prop, tensor_iterator_with_dynamic_reshape) {
ASSERT_EQ(tensor_iterator->get_num_iterations(), -1);
}
TEST(type_prop, tensor_iterator_dyn_slice) {
const size_t N = 32; // Batch size
const size_t I = 8; // Input size
ov::PartialShape ps = {N, ov::Dimension::dynamic(), I};
auto SENT = make_shared<op::Parameter>(element::f32, ps);
// Body
auto X = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto Res = make_shared<op::Result>(X);
auto body = make_shared<ov::Model>(Res, ParameterVector{X});
auto tensor_iterator = make_shared<op::TensorIterator>();
tensor_iterator->set_body(body);
// start=0, stride=1, part_size=1, end=39, axis=1
const size_t part_size = 1;
tensor_iterator->set_sliced_input(X, SENT, 0, 1, part_size, -1, 1);
// Output 0 is last Ho, result 0 of body
auto out0 = tensor_iterator->get_iter_value(Res, -1);
auto results = ResultVector{make_shared<op::Result>(out0)};
auto model = make_shared<ov::Model>(results, ParameterVector{SENT});
EXPECT_EQ(tensor_iterator->get_num_iterations(), -1);
PartialShape ref_ps = {N, part_size, I};
EXPECT_EQ(X->get_partial_shape(), ref_ps);
}