Reducing ShapeOf sub-graphs for Flatten-like Reshapes (#21365)

* Extends SharedOpsOptimization with ScatterUpdate support

* Introduces Flatten operation symbolic optimization

* Code Style: tests

* Test fix

* Refactor
This commit is contained in:
Evgenya Nugmanova 2023-11-30 12:21:18 +04:00 committed by GitHub
parent 51b5bc5ec4
commit 243898560f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 150 additions and 0 deletions

View File

@ -0,0 +1,25 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <openvino/pass/graph_rewrite.hpp>
#include <transformations_visibility.hpp>
namespace ov {
namespace pass {
class TRANSFORMATIONS_API ReshapeOptimizations;
} // namespace pass
} // namespace ov
/**
* @ingroup ie_transformation_common_api
* @brief Searches for Flatten-like Reshape operations and simplifies 2nd input of such Reshape using special zero
* feature
*/
class ov::pass::ReshapeOptimizations : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ReshapeOptimizations", "0");
ReshapeOptimizations();
};

View File

@ -11,6 +11,7 @@
#include "openvino/op/gather.hpp"
#include "openvino/op/gather_elements.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/scatter_update.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/squeeze.hpp"
@ -195,6 +196,7 @@ bool pass::SharedOpOptimization::run_on_model(const shared_ptr<Model>& model) {
RECORD_NO_ATTRIBUTES(v0::Squeeze),
RECORD_NO_ATTRIBUTES(v0::Tile),
RECORD_NO_ATTRIBUTES(v0::Unsqueeze),
RECORD_NO_ATTRIBUTES(v3::ScatterUpdate),
// with attributes
RECORD(v0::Concat, concats_are_equal),

View File

@ -0,0 +1,54 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/symbolic_transformations/reshape_optimizations.hpp"
#include <openvino/core/dimension_tracker.hpp>
#include <openvino/pass/pattern/op/wrap_type.hpp>
#include "compare.hpp"
#include "itt.hpp"
#include "transformations/symbolic_transformations/utils.hpp"
#include "transformations/utils/utils.hpp"
using namespace std;
using namespace ov;
using namespace ov::op;
using namespace ov::symbol::util;
ov::pass::ReshapeOptimizations::ReshapeOptimizations() {
MATCHER_SCOPE(ReshapeOptimizations);
auto data_label = pattern::any_input(pattern::has_static_rank());
auto pattern_label = pattern::any_input([](const Output<Node>& output) {
return pattern::has_static_shape() && !ov::as_type_ptr<v0::Constant>(output.get_node_shared_ptr());
});
auto reshape_label = ov::pass::pattern::wrap_type<op::v1::Reshape>({data_label, pattern_label});
ov::matcher_pass_callback matcher_pass_callback = [](pattern::Matcher& m) {
const auto& reshape = ov::as_type_ptr<v1::Reshape>(m.get_match_root());
if (!reshape)
return false;
const auto& in_shape = reshape->get_input_partial_shape(0);
const auto& out_shape = reshape->get_output_partial_shape(0);
if (in_shape.size() > out_shape.size()) {
std::vector<int64_t> output_pattern(out_shape.size(), -1);
for (size_t i = 0; i < out_shape.size(); ++i)
if (dims_are_equal(in_shape[i], out_shape[i]))
output_pattern[i] = 0;
if (std::count(output_pattern.begin(), output_pattern.end(), -1) == 1) {
auto new_pattern =
ov::op::v0::Constant::create(element::i64, Shape{output_pattern.size()}, output_pattern);
ov::copy_runtime_info(reshape->get_input_node_shared_ptr(1), new_pattern);
reshape->set_special_zero(true);
reshape->input(1).replace_source_output(new_pattern->output(0));
return true;
}
}
return false;
};
auto m = std::make_shared<pattern::Matcher>(reshape_label, matcher_name);
register_matcher(m, matcher_pass_callback);
}

View File

@ -22,6 +22,7 @@
#include "transformations/symbolic_transformations/dereshape_matmul.hpp"
#include "transformations/symbolic_transformations/label_optimization.hpp"
#include "transformations/symbolic_transformations/nop_broadcast.hpp"
#include "transformations/symbolic_transformations/reshape_optimizations.hpp"
#include "transformations/symbolic_transformations/utils.hpp"
using namespace ov::pass;
@ -201,6 +202,7 @@ ov::pass::SymbolicOptimizations::SymbolicOptimizations(bool full_run) {
REGISTER_SYMBOLIC(OptimizeLabelsUsedAsValues) // reduce shape sub-graphs
REGISTER_SYMBOLIC(LabelResolvingThroughSelect) // figures out that broadcasting didn't happen through Select op
REGISTER_SYMBOLIC(DeReshapeMatMul)
REGISTER_SYMBOLIC(ReshapeOptimizations)
REGISTER_SYMBOLIC(SimplifyShapeOfSubGraph)
}
}

View File

@ -0,0 +1,67 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/symbolic_transformations/reshape_optimizations.hpp"
#include <gtest/gtest.h>
#include <openvino/core/model.hpp>
#include <openvino/op/concat.hpp>
#include <openvino/op/gather.hpp>
#include <openvino/op/multiply.hpp>
#include <openvino/op/parameter.hpp>
#include <openvino/op/reshape.hpp>
#include <openvino/op/shape_of.hpp>
#include "common_test_utils/ov_test_utils.hpp"
#include "openvino/core/dimension_tracker.hpp"
using namespace ov;
using namespace ov::op;
using namespace std;
namespace {
void label_shape(ov::PartialShape& shape) {
auto table = std::make_shared<ov::TableOfEquivalence>(42);
auto tracker = ov::DimensionTracker(table);
tracker.set_up_for_tracking(shape);
}
} // namespace
TEST_F(TransformationTestsF, FlattenOptimization) {
{
auto shape = PartialShape::dynamic(4);
label_shape(shape); // we label shape with consecutive labels: 42, 43, 44, 45
auto data = make_shared<v0::Parameter>(element::f32, shape);
auto shape_of = make_shared<v3::ShapeOf>(data);
auto indices = ov::op::v0::Constant::create(element::i64, {2}, {0, 1});
auto axis = ov::op::v0::Constant::create(element::i64, {}, {0});
auto as_is_dims = make_shared<v1::Gather>(shape_of, indices, axis);
auto merged_dim = make_shared<v1::Multiply>(
make_shared<v1::Gather>(shape_of, ov::op::v0::Constant::create(element::i64, {1}, {2}), axis),
make_shared<v1::Gather>(shape_of, ov::op::v0::Constant::create(element::i64, {1}, {3}), axis));
auto pattern = make_shared<v0::Concat>(OutputVector{as_is_dims, merged_dim}, 0);
auto reshape = make_shared<v1::Reshape>(data, pattern, false);
model = make_shared<Model>(NodeVector{reshape}, ParameterVector{data});
manager.register_pass<pass::ReshapeOptimizations>();
}
{
auto shape = PartialShape::dynamic(4);
label_shape(shape); // we label shape with consecutive labels: 42, 43, 44, 45
auto data = make_shared<v0::Parameter>(element::f32, shape);
auto pattern = ov::op::v0::Constant::create(element::i64, {3}, {0, 0, -1});
auto reshape = make_shared<v1::Reshape>(data, pattern, true);
model_ref = make_shared<Model>(NodeVector{reshape}, ParameterVector{data});
}
}