[ nG: TI ] Fix for reshaping TensorIterator (#3237)

* [ nG: TI ] Fix for reshaping TensorIterator

* Trigger CI

* misprint

* style

* style
This commit is contained in:
Evgenya Stepyreva
2020-11-20 15:22:19 +03:00
committed by GitHub
parent 3a9c731c2b
commit a05b7c76b2

View File

@@ -128,28 +128,15 @@ void op::v0::TensorIterator::validate_and_infer_types()
// +1 because the left and right borders are included [start, end]
m_num_iterations = (abs(end - start) + 1) / part_size;
if (body_param_partial_shape.is_static())
{
// validate
auto body_param_shape = body_param_partial_shape.to_shape();
for (auto i = 0; i < input_shape.size(); i++)
{
if (i != axis)
{
NODE_VALIDATION_CHECK(
this,
input_shape[i] == body_param_shape[i],
"Iterator input is not compatible with body param");
}
}
}
else
{
// infer type for m_body_parameter
Shape out_shape{input_shape};
out_shape[axis] = part_size;
body_parameter->set_partial_shape(out_shape);
}
// infer type for m_body_parameter
Shape out_shape{input_shape};
out_shape[axis] = part_size;
body_parameter->set_partial_shape(out_shape);
}
else
{
body_parameter->set_partial_shape(
PartialShape::dynamic(input_partial_shape.rank()));
}
}
else if (auto merged_input_description =
@@ -171,16 +158,7 @@ void op::v0::TensorIterator::validate_and_infer_types()
NODE_VALIDATION_CHECK(this,
input_partial_shape.compatible(body_param_partial_shape),
"Iterator initial value is not compatible with body param");
if (input_partial_shape.is_static())
{
auto input_shape = input_partial_shape.to_shape();
// infer type for body_parameter
if (body_param_partial_shape.is_dynamic())
{
body_parameter->set_partial_shape(input_shape);
}
}
body_parameter->set_partial_shape(input_partial_shape);
}
else if (auto invariant_input_description =
as_type_ptr<InvariantInputDescription>(input_description))
@@ -193,16 +171,7 @@ void op::v0::TensorIterator::validate_and_infer_types()
NODE_VALIDATION_CHECK(this,
input_partial_shape.compatible(body_param_partial_shape),
"Iterator initial value is not compatible with body param");
if (input_partial_shape.is_static())
{
auto input_shape = input_partial_shape.to_shape();
// infer type for m_body_parameter
if (body_param_partial_shape.is_dynamic())
{
body_parameter->set_partial_shape(input_shape);
}
}
body_parameter->set_partial_shape(input_partial_shape);
}
}
@@ -251,6 +220,12 @@ void op::v0::TensorIterator::validate_and_infer_types()
set_output_type(index, body_value.get_element_type(), out_shape);
}
}
else
{
set_output_type(index,
body_value.get_element_type(),
PartialShape::dynamic(body_value.get_partial_shape().rank()));
}
}
else if (auto body_output_description =
as_type_ptr<BodyOutputDescription>(output_description))