[ nG: TI ] Fix for reshaping TensorIterator (#3247)
* [ nG: TI ] Fix for reshaping TensorIterator * comments * style
This commit is contained in:
parent
fc1a3ce2f1
commit
4b22a99a69
@ -114,7 +114,6 @@ void op::v0::TensorIterator::validate_and_infer_types()
|
|||||||
{
|
{
|
||||||
auto body_parameter =
|
auto body_parameter =
|
||||||
m_body->get_parameters().at(slice_input_description->m_body_parameter_index);
|
m_body->get_parameters().at(slice_input_description->m_body_parameter_index);
|
||||||
auto body_param_partial_shape = body_parameter->get_partial_shape();
|
|
||||||
auto input_partial_shape = inputs().at(index).get_source_output().get_partial_shape();
|
auto input_partial_shape = inputs().at(index).get_source_output().get_partial_shape();
|
||||||
if (input_partial_shape.is_static())
|
if (input_partial_shape.is_static())
|
||||||
{
|
{
|
||||||
@ -128,28 +127,15 @@ void op::v0::TensorIterator::validate_and_infer_types()
|
|||||||
|
|
||||||
// +1 because the left and right borders are included [start, end]
|
// +1 because the left and right borders are included [start, end]
|
||||||
m_num_iterations = (abs(end - start) + 1) / part_size;
|
m_num_iterations = (abs(end - start) + 1) / part_size;
|
||||||
if (body_param_partial_shape.is_static())
|
// infer type for m_body_parameter
|
||||||
{
|
Shape out_shape{input_shape};
|
||||||
// validate
|
out_shape[axis] = part_size;
|
||||||
auto body_param_shape = body_param_partial_shape.to_shape();
|
body_parameter->set_partial_shape(out_shape);
|
||||||
for (auto i = 0; i < input_shape.size(); i++)
|
}
|
||||||
{
|
else
|
||||||
if (i != axis)
|
{
|
||||||
{
|
body_parameter->set_partial_shape(
|
||||||
NODE_VALIDATION_CHECK(
|
PartialShape::dynamic(input_partial_shape.rank()));
|
||||||
this,
|
|
||||||
input_shape[i] == body_param_shape[i],
|
|
||||||
"Iterator input is not compatible with body param");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
// infer type for m_body_parameter
|
|
||||||
Shape out_shape{input_shape};
|
|
||||||
out_shape[axis] = part_size;
|
|
||||||
body_parameter->set_partial_shape(out_shape);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else if (auto merged_input_description =
|
else if (auto merged_input_description =
|
||||||
@ -171,16 +157,7 @@ void op::v0::TensorIterator::validate_and_infer_types()
|
|||||||
NODE_VALIDATION_CHECK(this,
|
NODE_VALIDATION_CHECK(this,
|
||||||
input_partial_shape.compatible(body_param_partial_shape),
|
input_partial_shape.compatible(body_param_partial_shape),
|
||||||
"Iterator initial value is not compatible with body param");
|
"Iterator initial value is not compatible with body param");
|
||||||
|
body_parameter->set_partial_shape(input_partial_shape);
|
||||||
if (input_partial_shape.is_static())
|
|
||||||
{
|
|
||||||
auto input_shape = input_partial_shape.to_shape();
|
|
||||||
// infer type for body_parameter
|
|
||||||
if (body_param_partial_shape.is_dynamic())
|
|
||||||
{
|
|
||||||
body_parameter->set_partial_shape(input_shape);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
else if (auto invariant_input_description =
|
else if (auto invariant_input_description =
|
||||||
as_type_ptr<InvariantInputDescription>(input_description))
|
as_type_ptr<InvariantInputDescription>(input_description))
|
||||||
@ -193,16 +170,7 @@ void op::v0::TensorIterator::validate_and_infer_types()
|
|||||||
NODE_VALIDATION_CHECK(this,
|
NODE_VALIDATION_CHECK(this,
|
||||||
input_partial_shape.compatible(body_param_partial_shape),
|
input_partial_shape.compatible(body_param_partial_shape),
|
||||||
"Iterator initial value is not compatible with body param");
|
"Iterator initial value is not compatible with body param");
|
||||||
|
body_parameter->set_partial_shape(input_partial_shape);
|
||||||
if (input_partial_shape.is_static())
|
|
||||||
{
|
|
||||||
auto input_shape = input_partial_shape.to_shape();
|
|
||||||
// infer type for m_body_parameter
|
|
||||||
if (body_param_partial_shape.is_dynamic())
|
|
||||||
{
|
|
||||||
body_parameter->set_partial_shape(input_shape);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -251,6 +219,12 @@ void op::v0::TensorIterator::validate_and_infer_types()
|
|||||||
set_output_type(index, body_value.get_element_type(), out_shape);
|
set_output_type(index, body_value.get_element_type(), out_shape);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
set_output_type(index,
|
||||||
|
body_value.get_element_type(),
|
||||||
|
PartialShape::dynamic(body_value.get_partial_shape().rank()));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
else if (auto body_output_description =
|
else if (auto body_output_description =
|
||||||
as_type_ptr<BodyOutputDescription>(output_description))
|
as_type_ptr<BodyOutputDescription>(output_description))
|
||||||
|
Loading…
Reference in New Issue
Block a user