Fix TensorIterator dynamic rank output (#10247)

* Fix TensorIterator dynamic rank output

* style
This commit is contained in:
Evgenya Stepyreva
2022-02-10 13:03:16 +03:00
committed by GitHub
parent 3f0e532dce
commit 89c3a18f83
2 changed files with 12 additions and 17 deletions

View File

@@ -134,34 +134,26 @@ void op::v0::TensorIterator::validate_and_infer_types() {
auto body_value = m_bodies[0]->get_results().at(output_description->m_body_value_index)->input_value(0);
if (auto concat_output_description = ov::as_type_ptr<ConcatOutputDescription>(output_description)) {
const auto& body_value_partial_shape = body_value.get_partial_shape();
auto body_value_partial_shape = body_value.get_partial_shape();
const auto& body_value_partial_rank = body_value_partial_shape.rank();
set_output_type(index, body_value.get_element_type(), ov::PartialShape::dynamic());
if (body_value_partial_shape.is_static()) {
auto body_value_shape = body_value_partial_shape.to_shape();
if (body_value_partial_rank.is_static()) {
auto part_size = concat_output_description->m_part_size;
auto axis = concat_output_description->m_axis;
ov::Shape out_shape{body_value_shape};
if (body_value_shape.empty()) {
if (body_value_partial_rank == 0) { // after scalars concatenation we must have 1D output
NODE_VALIDATION_CHECK(this,
axis == 0,
"Axis must be equal to 0 if concatenated output "
"tensor slices are scalars. "
"TensorIterator output index: ",
index);
out_shape = ov::Shape(1);
body_value_partial_shape = ov::PartialShape::dynamic(1);
}
if (m_num_iterations != -1) {
// for simple RNN case where stride is the same as part_size
out_shape[axis] = m_num_iterations * part_size;
set_output_type(index, body_value.get_element_type(), out_shape);
}
} else {
set_output_type(index,
body_value.get_element_type(),
ov::PartialShape::dynamic(body_value.get_partial_shape().rank()));
body_value_partial_shape[axis] =
m_num_iterations != -1 ? m_num_iterations * part_size : ov::Dimension::dynamic();
set_output_type(index, body_value.get_element_type(), body_value_partial_shape);
}
} else if (auto body_output_description = ov::as_type_ptr<BodyOutputDescription>(output_description)) {
set_output_type(index, body_value.get_element_type(), body_value.get_partial_shape());

View File

@@ -57,7 +57,10 @@ void op::v1::Transpose::validate_and_infer_types() {
arg_shape);
set_output_type(0, get_input_element_type(0), ngraph::apply_permutation(arg_shape, permutation));
} else {
set_output_type(0, get_input_element_type(0), ov::PartialShape::dynamic(arg_shape.rank()));
Rank output_rank = arg_shape.rank();
if (output_rank.is_dynamic() && input_order_shape.is_static() && input_order_shape[0].get_length())
output_rank = input_order_shape[0];
set_output_type(0, get_input_element_type(0), ov::PartialShape::dynamic(output_rank));
}
NGRAPH_SUPPRESS_DEPRECATED_END
}