Corrected reshape paterns
This commit is contained in:
parent
0552dfe537
commit
83f4428f48
@ -163,8 +163,8 @@ bool IsTailFlatten(const Output<Node>& output) {
|
|||||||
if (reshape_node->get_output_partial_shape(0).rank().is_dynamic() ||
|
if (reshape_node->get_output_partial_shape(0).rank().is_dynamic() ||
|
||||||
reshape_node->get_input_partial_shape(0).rank().is_dynamic())
|
reshape_node->get_input_partial_shape(0).rank().is_dynamic())
|
||||||
return false;
|
return false;
|
||||||
const Shape& input_shape = reshape_node->get_input_shape(0);
|
const Shape& input_shape = helper::SqueezeShape(reshape_node->get_input_shape(0));
|
||||||
const Shape& output_shape = reshape_node->get_output_shape(0);
|
const Shape& output_shape = helper::SqueezeShape(reshape_node->get_output_shape(0));
|
||||||
return output_shape.size() < input_shape.size() && AreFlattenShapes(input_shape, output_shape);
|
return output_shape.size() < input_shape.size() && AreFlattenShapes(input_shape, output_shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -173,8 +173,8 @@ bool IsTailUnflatten(const Output<Node>& output) {
|
|||||||
if (reshape_node->get_output_partial_shape(0).rank().is_dynamic() ||
|
if (reshape_node->get_output_partial_shape(0).rank().is_dynamic() ||
|
||||||
reshape_node->get_input_partial_shape(0).rank().is_dynamic())
|
reshape_node->get_input_partial_shape(0).rank().is_dynamic())
|
||||||
return false;
|
return false;
|
||||||
const Shape& input_shape = reshape_node->get_input_shape(0);
|
const Shape& input_shape = helper::SqueezeShape(reshape_node->get_input_shape(0));
|
||||||
const Shape& output_shape = reshape_node->get_output_shape(0);
|
const Shape& output_shape = helper::SqueezeShape(reshape_node->get_output_shape(0));
|
||||||
return output_shape.size() > input_shape.size() && AreFlattenShapes(input_shape, output_shape);
|
return output_shape.size() > input_shape.size() && AreFlattenShapes(input_shape, output_shape);
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
Loading…
Reference in New Issue
Block a user