Fix shape propagation in TI in case of dyn slice axis (#12926)
This commit is contained in:
@@ -88,19 +88,20 @@ void op::v0::TensorIterator::validate_and_infer_types() {
|
||||
auto body_parameter = body->get_parameters().at(slice_input_description->m_body_parameter_index);
|
||||
auto input_partial_shape = inputs().at(index).get_source_output().get_partial_shape();
|
||||
auto axis = slice_input_description->m_axis;
|
||||
if (input_partial_shape.rank().is_static() && input_partial_shape[axis].is_static()) {
|
||||
if (input_partial_shape.rank().is_static()) {
|
||||
auto part_size = slice_input_description->m_part_size;
|
||||
|
||||
auto dim_size = input_partial_shape[axis].get_length();
|
||||
auto start = make_positive(slice_input_description->m_start, dim_size);
|
||||
auto end = make_positive(slice_input_description->m_end, dim_size);
|
||||
|
||||
// +1 because the left and right borders are included [start, end]
|
||||
m_num_iterations = (abs(end - start) + 1) / part_size;
|
||||
// infer type for m_body_parameter
|
||||
ov::PartialShape out_shape{input_partial_shape};
|
||||
out_shape[axis] = part_size;
|
||||
body_parameter->set_partial_shape(out_shape);
|
||||
if (input_partial_shape[axis].is_static()) {
|
||||
auto dim_size = input_partial_shape[axis].get_length();
|
||||
auto start = make_positive(slice_input_description->m_start, dim_size);
|
||||
auto end = make_positive(slice_input_description->m_end, dim_size);
|
||||
|
||||
// +1 because the left and right borders are included [start, end]
|
||||
m_num_iterations = (abs(end - start) + 1) / part_size;
|
||||
}
|
||||
} else {
|
||||
body_parameter->set_partial_shape(ov::PartialShape::dynamic(input_partial_shape.rank()));
|
||||
}
|
||||
|
||||
@@ -232,3 +232,32 @@ TEST(type_prop, tensor_iterator_with_dynamic_reshape) {
|
||||
|
||||
ASSERT_EQ(tensor_iterator->get_num_iterations(), -1);
|
||||
}
|
||||
|
||||
TEST(type_prop, tensor_iterator_dyn_slice) {
|
||||
const size_t N = 32; // Batch size
|
||||
const size_t I = 8; // Input size
|
||||
|
||||
ov::PartialShape ps = {N, ov::Dimension::dynamic(), I};
|
||||
auto SENT = make_shared<op::Parameter>(element::f32, ps);
|
||||
|
||||
// Body
|
||||
auto X = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto Res = make_shared<op::Result>(X);
|
||||
auto body = make_shared<ov::Model>(Res, ParameterVector{X});
|
||||
auto tensor_iterator = make_shared<op::TensorIterator>();
|
||||
tensor_iterator->set_body(body);
|
||||
|
||||
// start=0, stride=1, part_size=1, end=39, axis=1
|
||||
const size_t part_size = 1;
|
||||
tensor_iterator->set_sliced_input(X, SENT, 0, 1, part_size, -1, 1);
|
||||
|
||||
// Output 0 is last Ho, result 0 of body
|
||||
auto out0 = tensor_iterator->get_iter_value(Res, -1);
|
||||
|
||||
auto results = ResultVector{make_shared<op::Result>(out0)};
|
||||
auto model = make_shared<ov::Model>(results, ParameterVector{SENT});
|
||||
|
||||
EXPECT_EQ(tensor_iterator->get_num_iterations(), -1);
|
||||
PartialShape ref_ps = {N, part_size, I};
|
||||
EXPECT_EQ(X->get_partial_shape(), ref_ps);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user