Reducing ShapeOf sub-graphs for Flatten-like Reshapes (#21365)
* Extends SharedOpsOptimization with ScatterUpdate support * Introduces Flatten operation symbolic optimization * Code Style: tests * Test fix * Refactor
This commit is contained in:
parent
51b5bc5ec4
commit
243898560f
@ -0,0 +1,25 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <openvino/pass/graph_rewrite.hpp>
|
||||
#include <transformations_visibility.hpp>
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
class TRANSFORMATIONS_API ReshapeOptimizations;
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief Searches for Flatten-like Reshape operations and simplifies 2nd input of such Reshape using special zero
|
||||
* feature
|
||||
*/
|
||||
class ov::pass::ReshapeOptimizations : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ReshapeOptimizations", "0");
|
||||
ReshapeOptimizations();
|
||||
};
|
@ -11,6 +11,7 @@
|
||||
#include "openvino/op/gather.hpp"
|
||||
#include "openvino/op/gather_elements.hpp"
|
||||
#include "openvino/op/reshape.hpp"
|
||||
#include "openvino/op/scatter_update.hpp"
|
||||
#include "openvino/op/shape_of.hpp"
|
||||
#include "openvino/op/slice.hpp"
|
||||
#include "openvino/op/squeeze.hpp"
|
||||
@ -195,6 +196,7 @@ bool pass::SharedOpOptimization::run_on_model(const shared_ptr<Model>& model) {
|
||||
RECORD_NO_ATTRIBUTES(v0::Squeeze),
|
||||
RECORD_NO_ATTRIBUTES(v0::Tile),
|
||||
RECORD_NO_ATTRIBUTES(v0::Unsqueeze),
|
||||
RECORD_NO_ATTRIBUTES(v3::ScatterUpdate),
|
||||
|
||||
// with attributes
|
||||
RECORD(v0::Concat, concats_are_equal),
|
||||
|
@ -0,0 +1,54 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/symbolic_transformations/reshape_optimizations.hpp"
|
||||
|
||||
#include <openvino/core/dimension_tracker.hpp>
|
||||
#include <openvino/pass/pattern/op/wrap_type.hpp>
|
||||
|
||||
#include "compare.hpp"
|
||||
#include "itt.hpp"
|
||||
#include "transformations/symbolic_transformations/utils.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ov;
|
||||
using namespace ov::op;
|
||||
using namespace ov::symbol::util;
|
||||
|
||||
ov::pass::ReshapeOptimizations::ReshapeOptimizations() {
|
||||
MATCHER_SCOPE(ReshapeOptimizations);
|
||||
auto data_label = pattern::any_input(pattern::has_static_rank());
|
||||
auto pattern_label = pattern::any_input([](const Output<Node>& output) {
|
||||
return pattern::has_static_shape() && !ov::as_type_ptr<v0::Constant>(output.get_node_shared_ptr());
|
||||
});
|
||||
auto reshape_label = ov::pass::pattern::wrap_type<op::v1::Reshape>({data_label, pattern_label});
|
||||
|
||||
ov::matcher_pass_callback matcher_pass_callback = [](pattern::Matcher& m) {
|
||||
const auto& reshape = ov::as_type_ptr<v1::Reshape>(m.get_match_root());
|
||||
if (!reshape)
|
||||
return false;
|
||||
const auto& in_shape = reshape->get_input_partial_shape(0);
|
||||
const auto& out_shape = reshape->get_output_partial_shape(0);
|
||||
|
||||
if (in_shape.size() > out_shape.size()) {
|
||||
std::vector<int64_t> output_pattern(out_shape.size(), -1);
|
||||
for (size_t i = 0; i < out_shape.size(); ++i)
|
||||
if (dims_are_equal(in_shape[i], out_shape[i]))
|
||||
output_pattern[i] = 0;
|
||||
if (std::count(output_pattern.begin(), output_pattern.end(), -1) == 1) {
|
||||
auto new_pattern =
|
||||
ov::op::v0::Constant::create(element::i64, Shape{output_pattern.size()}, output_pattern);
|
||||
ov::copy_runtime_info(reshape->get_input_node_shared_ptr(1), new_pattern);
|
||||
reshape->set_special_zero(true);
|
||||
reshape->input(1).replace_source_output(new_pattern->output(0));
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<pattern::Matcher>(reshape_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
@ -22,6 +22,7 @@
|
||||
#include "transformations/symbolic_transformations/dereshape_matmul.hpp"
|
||||
#include "transformations/symbolic_transformations/label_optimization.hpp"
|
||||
#include "transformations/symbolic_transformations/nop_broadcast.hpp"
|
||||
#include "transformations/symbolic_transformations/reshape_optimizations.hpp"
|
||||
#include "transformations/symbolic_transformations/utils.hpp"
|
||||
|
||||
using namespace ov::pass;
|
||||
@ -201,6 +202,7 @@ ov::pass::SymbolicOptimizations::SymbolicOptimizations(bool full_run) {
|
||||
REGISTER_SYMBOLIC(OptimizeLabelsUsedAsValues) // reduce shape sub-graphs
|
||||
REGISTER_SYMBOLIC(LabelResolvingThroughSelect) // figures out that broadcasting didn't happen through Select op
|
||||
REGISTER_SYMBOLIC(DeReshapeMatMul)
|
||||
REGISTER_SYMBOLIC(ReshapeOptimizations)
|
||||
REGISTER_SYMBOLIC(SimplifyShapeOfSubGraph)
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,67 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/symbolic_transformations/reshape_optimizations.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <openvino/core/model.hpp>
|
||||
#include <openvino/op/concat.hpp>
|
||||
#include <openvino/op/gather.hpp>
|
||||
#include <openvino/op/multiply.hpp>
|
||||
#include <openvino/op/parameter.hpp>
|
||||
#include <openvino/op/reshape.hpp>
|
||||
#include <openvino/op/shape_of.hpp>
|
||||
|
||||
#include "common_test_utils/ov_test_utils.hpp"
|
||||
#include "openvino/core/dimension_tracker.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::op;
|
||||
using namespace std;
|
||||
|
||||
namespace {
|
||||
void label_shape(ov::PartialShape& shape) {
|
||||
auto table = std::make_shared<ov::TableOfEquivalence>(42);
|
||||
auto tracker = ov::DimensionTracker(table);
|
||||
tracker.set_up_for_tracking(shape);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TEST_F(TransformationTestsF, FlattenOptimization) {
|
||||
{
|
||||
auto shape = PartialShape::dynamic(4);
|
||||
label_shape(shape); // we label shape with consecutive labels: 42, 43, 44, 45
|
||||
|
||||
auto data = make_shared<v0::Parameter>(element::f32, shape);
|
||||
|
||||
auto shape_of = make_shared<v3::ShapeOf>(data);
|
||||
auto indices = ov::op::v0::Constant::create(element::i64, {2}, {0, 1});
|
||||
auto axis = ov::op::v0::Constant::create(element::i64, {}, {0});
|
||||
|
||||
auto as_is_dims = make_shared<v1::Gather>(shape_of, indices, axis);
|
||||
|
||||
auto merged_dim = make_shared<v1::Multiply>(
|
||||
make_shared<v1::Gather>(shape_of, ov::op::v0::Constant::create(element::i64, {1}, {2}), axis),
|
||||
make_shared<v1::Gather>(shape_of, ov::op::v0::Constant::create(element::i64, {1}, {3}), axis));
|
||||
|
||||
auto pattern = make_shared<v0::Concat>(OutputVector{as_is_dims, merged_dim}, 0);
|
||||
|
||||
auto reshape = make_shared<v1::Reshape>(data, pattern, false);
|
||||
|
||||
model = make_shared<Model>(NodeVector{reshape}, ParameterVector{data});
|
||||
manager.register_pass<pass::ReshapeOptimizations>();
|
||||
}
|
||||
{
|
||||
auto shape = PartialShape::dynamic(4);
|
||||
label_shape(shape); // we label shape with consecutive labels: 42, 43, 44, 45
|
||||
|
||||
auto data = make_shared<v0::Parameter>(element::f32, shape);
|
||||
auto pattern = ov::op::v0::Constant::create(element::i64, {3}, {0, 0, -1});
|
||||
|
||||
auto reshape = make_shared<v1::Reshape>(data, pattern, true);
|
||||
|
||||
model_ref = make_shared<Model>(NodeVector{reshape}, ParameterVector{data});
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user