Corrected reshape paterns

This commit is contained in:
Mikhail Ryzhov 2023-03-20 18:01:35 +01:00
parent 0552dfe537
commit 83f4428f48

View File

@ -163,8 +163,8 @@ bool IsTailFlatten(const Output<Node>& output) {
if (reshape_node->get_output_partial_shape(0).rank().is_dynamic() ||
reshape_node->get_input_partial_shape(0).rank().is_dynamic())
return false;
const Shape& input_shape = reshape_node->get_input_shape(0);
const Shape& output_shape = reshape_node->get_output_shape(0);
const Shape& input_shape = helper::SqueezeShape(reshape_node->get_input_shape(0));
const Shape& output_shape = helper::SqueezeShape(reshape_node->get_output_shape(0));
return output_shape.size() < input_shape.size() && AreFlattenShapes(input_shape, output_shape);
}
@ -173,8 +173,8 @@ bool IsTailUnflatten(const Output<Node>& output) {
if (reshape_node->get_output_partial_shape(0).rank().is_dynamic() ||
reshape_node->get_input_partial_shape(0).rank().is_dynamic())
return false;
const Shape& input_shape = reshape_node->get_input_shape(0);
const Shape& output_shape = reshape_node->get_output_shape(0);
const Shape& input_shape = helper::SqueezeShape(reshape_node->get_input_shape(0));
const Shape& output_shape = helper::SqueezeShape(reshape_node->get_output_shape(0));
return output_shape.size() > input_shape.size() && AreFlattenShapes(input_shape, output_shape);
}
} // namespace