Corrected reshape paterns
This commit is contained in:
parent
0552dfe537
commit
83f4428f48
@ -163,8 +163,8 @@ bool IsTailFlatten(const Output<Node>& output) {
|
||||
if (reshape_node->get_output_partial_shape(0).rank().is_dynamic() ||
|
||||
reshape_node->get_input_partial_shape(0).rank().is_dynamic())
|
||||
return false;
|
||||
const Shape& input_shape = reshape_node->get_input_shape(0);
|
||||
const Shape& output_shape = reshape_node->get_output_shape(0);
|
||||
const Shape& input_shape = helper::SqueezeShape(reshape_node->get_input_shape(0));
|
||||
const Shape& output_shape = helper::SqueezeShape(reshape_node->get_output_shape(0));
|
||||
return output_shape.size() < input_shape.size() && AreFlattenShapes(input_shape, output_shape);
|
||||
}
|
||||
|
||||
@ -173,8 +173,8 @@ bool IsTailUnflatten(const Output<Node>& output) {
|
||||
if (reshape_node->get_output_partial_shape(0).rank().is_dynamic() ||
|
||||
reshape_node->get_input_partial_shape(0).rank().is_dynamic())
|
||||
return false;
|
||||
const Shape& input_shape = reshape_node->get_input_shape(0);
|
||||
const Shape& output_shape = reshape_node->get_output_shape(0);
|
||||
const Shape& input_shape = helper::SqueezeShape(reshape_node->get_input_shape(0));
|
||||
const Shape& output_shape = helper::SqueezeShape(reshape_node->get_output_shape(0));
|
||||
return output_shape.size() > input_shape.size() && AreFlattenShapes(input_shape, output_shape);
|
||||
}
|
||||
} // namespace
|
||||
|
Loading…
Reference in New Issue
Block a user