Handle Reshape's special zero in SimplifySecondInputOfReshape (#20785)
* Handle Reshape's special zero in SimplifySecondInputOfReshape SimplifySecondInputOfReshape detects ShapeOf->Gather->Concat subgraphs on Reshape's second input and replaces ShapeOf->Gather with a Constant with zero(s). Currently it works only with Reshapes that have special_zero set to true, but it can work for Reshapes with special_zero == false if non-Gather inputs to Concat are Constants and don't contain any zero. Ticket: CVS-123434 * fix no default output
This commit is contained in:
parent
e6ff5edc3d
commit
4084e9acc0
@ -200,7 +200,7 @@ pass::SimplifySecondInputOfReshape::SimplifySecondInputOfReshape() {
|
||||
matcher_pass_callback callback = [=](Matcher& m) {
|
||||
auto node = m.get_match_root();
|
||||
const auto reshape = as_type_ptr<v1::Reshape>(node);
|
||||
if (!reshape || reshape->get_special_zero() == false) {
|
||||
if (!reshape) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -219,7 +219,7 @@ pass::SimplifySecondInputOfReshape::SimplifySecondInputOfReshape() {
|
||||
|
||||
auto check_shape_of_gather = [&](const std::shared_ptr<Node>& gather) {
|
||||
auto shape_of = gather->get_input_node_shared_ptr(0);
|
||||
if (!is_type<v3::ShapeOf>(shape_of) && !is_type<v0::ShapeOf>(shape_of)) {
|
||||
if (!is_type<op::util::ShapeOfBase>(shape_of)) {
|
||||
return false;
|
||||
}
|
||||
return shape_of->input_value(0) == data;
|
||||
@ -237,16 +237,15 @@ pass::SimplifySecondInputOfReshape::SimplifySecondInputOfReshape() {
|
||||
gather_dims_expected_location += concat_input_shape[0];
|
||||
};
|
||||
|
||||
bool special_zero = reshape->get_special_zero();
|
||||
|
||||
// We need this check to avoid sequences shapeOf -> gather -> concat
|
||||
// that change the arrangement of dimensions in the reshape pattern
|
||||
for (auto& concat_input : new_concat_inputs) {
|
||||
if (const auto gather = as_type_ptr<op::util::GatherBase>(concat_input.get_node_shared_ptr())) {
|
||||
auto indices_constant = as_type_ptr<v0::Constant>(gather->get_input_node_shared_ptr(1));
|
||||
if (!indices_constant || !check_shape_of_gather(gather)) {
|
||||
update_expected_gather_location(gather);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto node = concat_input.get_node_shared_ptr();
|
||||
if (ov::is_type<op::util::GatherBase>(node) &&
|
||||
ov::is_type<v0::Constant>(node->get_input_node_shared_ptr(1)) && check_shape_of_gather(node)) {
|
||||
auto indices_constant = as_type_ptr<v0::Constant>(node->get_input_node_shared_ptr(1));
|
||||
bool gather_can_be_fused = true;
|
||||
const auto indices = indices_constant->cast_vector<std::int64_t>();
|
||||
for (size_t i = 0; i < indices.size(); ++i) {
|
||||
@ -258,11 +257,21 @@ pass::SimplifySecondInputOfReshape::SimplifySecondInputOfReshape() {
|
||||
|
||||
if (gather_can_be_fused) {
|
||||
const size_t num_of_unchanged_dimensions = indices.size();
|
||||
const auto subgraph_et = gather->get_input_element_type(0);
|
||||
const auto subgraph_et = node->get_input_element_type(0);
|
||||
concat_input = v0::Constant::create(subgraph_et, Shape{num_of_unchanged_dimensions}, {0});
|
||||
gather_folded = true;
|
||||
}
|
||||
} else {
|
||||
if (!special_zero) {
|
||||
// If special zero is false - check if other inputs to Concat are Constants.
|
||||
// If any of those Constants contain zero - return false.
|
||||
auto constant = as_type_ptr<v0::Constant>(node);
|
||||
if (!constant)
|
||||
return false;
|
||||
auto values = constant->cast_vector<int64_t>();
|
||||
if (std::find(values.begin(), values.end(), 0) != values.end())
|
||||
return false;
|
||||
}
|
||||
update_expected_gather_location(concat_input);
|
||||
}
|
||||
}
|
||||
@ -275,7 +284,7 @@ pass::SimplifySecondInputOfReshape::SimplifySecondInputOfReshape() {
|
||||
new_concat->set_friendly_name(concat->get_friendly_name());
|
||||
copy_runtime_info(concat, new_concat);
|
||||
|
||||
const auto new_reshape = reshape->clone_with_new_inputs({reshape->input_value(0), new_concat});
|
||||
const auto new_reshape = std::make_shared<v1::Reshape>(reshape->input_value(0), new_concat, true);
|
||||
new_reshape->set_friendly_name(reshape->get_friendly_name());
|
||||
|
||||
copy_runtime_info(reshape, new_reshape);
|
||||
|
@ -611,3 +611,53 @@ TEST_F(TransformationTestsF, SimplifySecondInputOfReshapeTest21) {
|
||||
}
|
||||
comparator.enable(FunctionsComparator::CONST_VALUES);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, SimplifySecondInputOfReshapeTestFalseSpecialZero) {
|
||||
PartialShape data_shape{1, 128, 12, 64};
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
|
||||
auto shape_of = std::make_shared<opset7::ShapeOf>(data);
|
||||
auto gather_op = gather(shape_of, std::vector<int64_t>{0, 1});
|
||||
auto constant = opset7::Constant::create(element::i64, Shape{1}, {768});
|
||||
auto concat = std::make_shared<opset7::Concat>(OutputVector{gather_op, constant}, -1);
|
||||
|
||||
auto reshape = std::make_shared<opset7::Reshape>(data, concat, false);
|
||||
model = std::make_shared<Model>(NodeVector{reshape}, ParameterVector{data});
|
||||
|
||||
manager.register_pass<ov::pass::SimplifySecondInputOfReshape>();
|
||||
}
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
auto reshape_pattern = opset7::Constant::create(element::i64, Shape{3}, {0, 0, 768});
|
||||
auto reshape = std::make_shared<opset7::Reshape>(data, reshape_pattern, true);
|
||||
model_ref = std::make_shared<Model>(NodeVector{reshape}, ParameterVector{data});
|
||||
}
|
||||
comparator.enable(FunctionsComparator::ATTRIBUTES);
|
||||
comparator.enable(FunctionsComparator::CONST_VALUES);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, SimplifySecondInputOfReshapeTestFalseSpecialZeroZeroDim) {
|
||||
PartialShape data_shape{1, 0, 12, 64};
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
|
||||
auto shape_of = std::make_shared<opset7::ShapeOf>(data);
|
||||
auto gather_op = gather(shape_of, std::vector<int64_t>{0, 1});
|
||||
auto constant = opset7::Constant::create(element::i64, Shape{1}, {768});
|
||||
auto concat = std::make_shared<opset7::Concat>(OutputVector{gather_op, constant}, -1);
|
||||
|
||||
auto reshape = std::make_shared<opset7::Reshape>(data, concat, false);
|
||||
model = std::make_shared<Model>(NodeVector{reshape}, ParameterVector{data});
|
||||
|
||||
manager.register_pass<ov::pass::SimplifySecondInputOfReshape>();
|
||||
}
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
auto reshape_pattern = opset7::Constant::create(element::i64, Shape{3}, {0, 0, 768});
|
||||
auto reshape = std::make_shared<opset7::Reshape>(data, reshape_pattern, true);
|
||||
model_ref = std::make_shared<Model>(NodeVector{reshape}, ParameterVector{data});
|
||||
}
|
||||
comparator.enable(FunctionsComparator::ATTRIBUTES);
|
||||
comparator.enable(FunctionsComparator::CONST_VALUES);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user