[ nG: TI ] Fix for reshaping TensorIterator (#3247)

* [ nG: TI ] Fix for reshaping TensorIterator

* comments

* style
This commit is contained in:
Evgenya Stepyreva 2020-11-20 17:43:21 +03:00 committed by GitHub
parent fc1a3ce2f1
commit 4b22a99a69
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -114,7 +114,6 @@ void op::v0::TensorIterator::validate_and_infer_types()
{ {
auto body_parameter = auto body_parameter =
m_body->get_parameters().at(slice_input_description->m_body_parameter_index); m_body->get_parameters().at(slice_input_description->m_body_parameter_index);
auto body_param_partial_shape = body_parameter->get_partial_shape();
auto input_partial_shape = inputs().at(index).get_source_output().get_partial_shape(); auto input_partial_shape = inputs().at(index).get_source_output().get_partial_shape();
if (input_partial_shape.is_static()) if (input_partial_shape.is_static())
{ {
@ -128,28 +127,15 @@ void op::v0::TensorIterator::validate_and_infer_types()
// +1 because the left and right borders are included [start, end] // +1 because the left and right borders are included [start, end]
m_num_iterations = (abs(end - start) + 1) / part_size; m_num_iterations = (abs(end - start) + 1) / part_size;
if (body_param_partial_shape.is_static()) // infer type for m_body_parameter
{ Shape out_shape{input_shape};
// validate out_shape[axis] = part_size;
auto body_param_shape = body_param_partial_shape.to_shape(); body_parameter->set_partial_shape(out_shape);
for (auto i = 0; i < input_shape.size(); i++) }
{ else
if (i != axis) {
{ body_parameter->set_partial_shape(
NODE_VALIDATION_CHECK( PartialShape::dynamic(input_partial_shape.rank()));
this,
input_shape[i] == body_param_shape[i],
"Iterator input is not compatible with body param");
}
}
}
else
{
// infer type for m_body_parameter
Shape out_shape{input_shape};
out_shape[axis] = part_size;
body_parameter->set_partial_shape(out_shape);
}
} }
} }
else if (auto merged_input_description = else if (auto merged_input_description =
@ -171,16 +157,7 @@ void op::v0::TensorIterator::validate_and_infer_types()
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
input_partial_shape.compatible(body_param_partial_shape), input_partial_shape.compatible(body_param_partial_shape),
"Iterator initial value is not compatible with body param"); "Iterator initial value is not compatible with body param");
body_parameter->set_partial_shape(input_partial_shape);
if (input_partial_shape.is_static())
{
auto input_shape = input_partial_shape.to_shape();
// infer type for body_parameter
if (body_param_partial_shape.is_dynamic())
{
body_parameter->set_partial_shape(input_shape);
}
}
} }
else if (auto invariant_input_description = else if (auto invariant_input_description =
as_type_ptr<InvariantInputDescription>(input_description)) as_type_ptr<InvariantInputDescription>(input_description))
@ -193,16 +170,7 @@ void op::v0::TensorIterator::validate_and_infer_types()
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
input_partial_shape.compatible(body_param_partial_shape), input_partial_shape.compatible(body_param_partial_shape),
"Iterator initial value is not compatible with body param"); "Iterator initial value is not compatible with body param");
body_parameter->set_partial_shape(input_partial_shape);
if (input_partial_shape.is_static())
{
auto input_shape = input_partial_shape.to_shape();
// infer type for m_body_parameter
if (body_param_partial_shape.is_dynamic())
{
body_parameter->set_partial_shape(input_shape);
}
}
} }
} }
@ -251,6 +219,12 @@ void op::v0::TensorIterator::validate_and_infer_types()
set_output_type(index, body_value.get_element_type(), out_shape); set_output_type(index, body_value.get_element_type(), out_shape);
} }
} }
else
{
set_output_type(index,
body_value.get_element_type(),
PartialShape::dynamic(body_value.get_partial_shape().rank()));
}
} }
else if (auto body_output_description = else if (auto body_output_description =
as_type_ptr<BodyOutputDescription>(output_description)) as_type_ptr<BodyOutputDescription>(output_description))