Fix TensorIterator dynamic rank output (#10247)
* Fix TensorIterator dynamic rank output * style
This commit is contained in:
committed by
GitHub
parent
3f0e532dce
commit
89c3a18f83
@@ -134,34 +134,26 @@ void op::v0::TensorIterator::validate_and_infer_types() {
|
||||
auto body_value = m_bodies[0]->get_results().at(output_description->m_body_value_index)->input_value(0);
|
||||
|
||||
if (auto concat_output_description = ov::as_type_ptr<ConcatOutputDescription>(output_description)) {
|
||||
const auto& body_value_partial_shape = body_value.get_partial_shape();
|
||||
auto body_value_partial_shape = body_value.get_partial_shape();
|
||||
const auto& body_value_partial_rank = body_value_partial_shape.rank();
|
||||
set_output_type(index, body_value.get_element_type(), ov::PartialShape::dynamic());
|
||||
if (body_value_partial_shape.is_static()) {
|
||||
auto body_value_shape = body_value_partial_shape.to_shape();
|
||||
if (body_value_partial_rank.is_static()) {
|
||||
auto part_size = concat_output_description->m_part_size;
|
||||
auto axis = concat_output_description->m_axis;
|
||||
|
||||
ov::Shape out_shape{body_value_shape};
|
||||
|
||||
if (body_value_shape.empty()) {
|
||||
if (body_value_partial_rank == 0) { // after scalars concatenation we must have 1D output
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
axis == 0,
|
||||
"Axis must be equal to 0 if concatenated output "
|
||||
"tensor slices are scalars. "
|
||||
"TensorIterator output index: ",
|
||||
index);
|
||||
out_shape = ov::Shape(1);
|
||||
body_value_partial_shape = ov::PartialShape::dynamic(1);
|
||||
}
|
||||
|
||||
if (m_num_iterations != -1) {
|
||||
// for simple RNN case where stride is the same as part_size
|
||||
out_shape[axis] = m_num_iterations * part_size;
|
||||
set_output_type(index, body_value.get_element_type(), out_shape);
|
||||
}
|
||||
} else {
|
||||
set_output_type(index,
|
||||
body_value.get_element_type(),
|
||||
ov::PartialShape::dynamic(body_value.get_partial_shape().rank()));
|
||||
body_value_partial_shape[axis] =
|
||||
m_num_iterations != -1 ? m_num_iterations * part_size : ov::Dimension::dynamic();
|
||||
set_output_type(index, body_value.get_element_type(), body_value_partial_shape);
|
||||
}
|
||||
} else if (auto body_output_description = ov::as_type_ptr<BodyOutputDescription>(output_description)) {
|
||||
set_output_type(index, body_value.get_element_type(), body_value.get_partial_shape());
|
||||
|
||||
@@ -57,7 +57,10 @@ void op::v1::Transpose::validate_and_infer_types() {
|
||||
arg_shape);
|
||||
set_output_type(0, get_input_element_type(0), ngraph::apply_permutation(arg_shape, permutation));
|
||||
} else {
|
||||
set_output_type(0, get_input_element_type(0), ov::PartialShape::dynamic(arg_shape.rank()));
|
||||
Rank output_rank = arg_shape.rank();
|
||||
if (output_rank.is_dynamic() && input_order_shape.is_static() && input_order_shape[0].get_length())
|
||||
output_rank = input_order_shape[0];
|
||||
set_output_type(0, get_input_element_type(0), ov::PartialShape::dynamic(output_rank));
|
||||
}
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user