[ nG: TI ] Fix for reshaping TensorIterator (#3247)
* [ nG: TI ] Fix for reshaping TensorIterator * comments * style
This commit is contained in:
parent
fc1a3ce2f1
commit
4b22a99a69
@ -114,7 +114,6 @@ void op::v0::TensorIterator::validate_and_infer_types()
|
||||
{
|
||||
auto body_parameter =
|
||||
m_body->get_parameters().at(slice_input_description->m_body_parameter_index);
|
||||
auto body_param_partial_shape = body_parameter->get_partial_shape();
|
||||
auto input_partial_shape = inputs().at(index).get_source_output().get_partial_shape();
|
||||
if (input_partial_shape.is_static())
|
||||
{
|
||||
@ -128,28 +127,15 @@ void op::v0::TensorIterator::validate_and_infer_types()
|
||||
|
||||
// +1 because the left and right borders are included [start, end]
|
||||
m_num_iterations = (abs(end - start) + 1) / part_size;
|
||||
if (body_param_partial_shape.is_static())
|
||||
{
|
||||
// validate
|
||||
auto body_param_shape = body_param_partial_shape.to_shape();
|
||||
for (auto i = 0; i < input_shape.size(); i++)
|
||||
{
|
||||
if (i != axis)
|
||||
{
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
input_shape[i] == body_param_shape[i],
|
||||
"Iterator input is not compatible with body param");
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// infer type for m_body_parameter
|
||||
Shape out_shape{input_shape};
|
||||
out_shape[axis] = part_size;
|
||||
body_parameter->set_partial_shape(out_shape);
|
||||
}
|
||||
// infer type for m_body_parameter
|
||||
Shape out_shape{input_shape};
|
||||
out_shape[axis] = part_size;
|
||||
body_parameter->set_partial_shape(out_shape);
|
||||
}
|
||||
else
|
||||
{
|
||||
body_parameter->set_partial_shape(
|
||||
PartialShape::dynamic(input_partial_shape.rank()));
|
||||
}
|
||||
}
|
||||
else if (auto merged_input_description =
|
||||
@ -171,16 +157,7 @@ void op::v0::TensorIterator::validate_and_infer_types()
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
input_partial_shape.compatible(body_param_partial_shape),
|
||||
"Iterator initial value is not compatible with body param");
|
||||
|
||||
if (input_partial_shape.is_static())
|
||||
{
|
||||
auto input_shape = input_partial_shape.to_shape();
|
||||
// infer type for body_parameter
|
||||
if (body_param_partial_shape.is_dynamic())
|
||||
{
|
||||
body_parameter->set_partial_shape(input_shape);
|
||||
}
|
||||
}
|
||||
body_parameter->set_partial_shape(input_partial_shape);
|
||||
}
|
||||
else if (auto invariant_input_description =
|
||||
as_type_ptr<InvariantInputDescription>(input_description))
|
||||
@ -193,16 +170,7 @@ void op::v0::TensorIterator::validate_and_infer_types()
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
input_partial_shape.compatible(body_param_partial_shape),
|
||||
"Iterator initial value is not compatible with body param");
|
||||
|
||||
if (input_partial_shape.is_static())
|
||||
{
|
||||
auto input_shape = input_partial_shape.to_shape();
|
||||
// infer type for m_body_parameter
|
||||
if (body_param_partial_shape.is_dynamic())
|
||||
{
|
||||
body_parameter->set_partial_shape(input_shape);
|
||||
}
|
||||
}
|
||||
body_parameter->set_partial_shape(input_partial_shape);
|
||||
}
|
||||
}
|
||||
|
||||
@ -251,6 +219,12 @@ void op::v0::TensorIterator::validate_and_infer_types()
|
||||
set_output_type(index, body_value.get_element_type(), out_shape);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
set_output_type(index,
|
||||
body_value.get_element_type(),
|
||||
PartialShape::dynamic(body_value.get_partial_shape().rank()));
|
||||
}
|
||||
}
|
||||
else if (auto body_output_description =
|
||||
as_type_ptr<BodyOutputDescription>(output_description))
|
||||
|
Loading…
Reference in New Issue
Block a user