ReshapeBMatMul and ReshapeAMatMul: avoid circular dependencies creation (#20771)

This commit is contained in:
Vladislav Golubev 2023-10-31 12:00:52 +01:00 committed by GitHub
parent da1f0199a0
commit 2932e9e938
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 43 additions and 2 deletions

View File

@ -61,8 +61,11 @@ ov::pass::ReshapeAMatMul::ReshapeAMatMul() {
auto other_input_label = pattern::any_input();
auto reshape_input_label = pattern::any_input();
auto reshape_pattern_label = pattern::any_input();
auto reshape_predicate = [](ov::Output<ov::Node> output) -> bool {
return ov::pass::pattern::rank_equals(2)(output) && ov::pass::pattern::consumers_count(1)(output);
};
auto reshape_label = ov::pass::pattern::wrap_type<ov::op::v1::Reshape>({reshape_input_label, reshape_pattern_label},
ov::pass::pattern::rank_equals(2));
reshape_predicate);
auto matmul_label = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({reshape_label, other_input_label});
matcher_pass_callback callback = [=](pattern::Matcher& m) -> bool {
@ -83,8 +86,11 @@ ov::pass::ReshapeBMatMul::ReshapeBMatMul() {
auto other_input_label = pattern::any_input();
auto reshape_input_label = pattern::any_input();
auto reshape_pattern_label = pattern::any_input();
auto reshape_predicate = [](ov::Output<ov::Node> output) -> bool {
return ov::pass::pattern::rank_equals(2)(output) && ov::pass::pattern::consumers_count(1)(output);
};
auto reshape_label = ov::pass::pattern::wrap_type<ov::op::v1::Reshape>({reshape_input_label, reshape_pattern_label},
ov::pass::pattern::rank_equals(2));
reshape_predicate);
auto matmul_label = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({other_input_label, reshape_label});
matcher_pass_callback callback = [=](pattern::Matcher& m) -> bool {

View File

@ -10,11 +10,14 @@
#include "cnn_network_ngraph_impl.hpp"
#include "common_test_utils/graph_comparator.hpp"
#include "common_test_utils/ov_test_utils.hpp"
#include "common_test_utils/test_common.hpp"
#include "ie_common.h"
#include "openvino/op/add.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/reduce_max.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/variadic_split.hpp"
@ -357,3 +360,35 @@ TEST(SmartReshapeTransposeMatMulTests, TransposeBothMatMulWithAttrFuse) {
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST_F(TransformationTestsF, SmartReshapeReshapeAMatMulSeveralConsumers) {
// Reshape has 2 consumers: matmul and reduce.
// Since reshape movement leads to loop creation (circular dependencies), the transformation can't be applied
auto data_A = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{3, 2, 3});
auto reshape_const = ov::op::v0::Constant::create(ov::element::i32, {2}, {3, 6});
auto reshape = std::make_shared<ov::op::v1::Reshape>(data_A, reshape_const, false);
auto data_B = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{6, 12});
auto reduction_axes = ov::op::v0::Constant::create(ov::element::i32, {2}, {0, 1});
auto reduce = std::make_shared<ov::op::v1::ReduceMax>(reshape, reduction_axes);
auto sum = std::make_shared<ov::op::v1::Add>(data_B, reduce);
auto matmul = std::make_shared<ov::op::v0::MatMul>(reshape, sum);
model = std::make_shared<ov::Model>(ov::NodeVector{matmul}, ov::ParameterVector{data_A, data_B});
manager.register_pass<ov::pass::ReshapeAMatMul>();
}
TEST_F(TransformationTestsF, SmartReshapeReshapeBMatMulSeveralConsumers) {
// Reshape has 2 consumers: matmul and reduce.
// Since reshape movement leads to loop creation (circular dependencies), the transformation can't be applied
auto data_B = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{3, 2, 3});
auto reshape_const = ov::op::v0::Constant::create(ov::element::i32, {2}, {6, 3});
auto reshape = std::make_shared<ov::op::v1::Reshape>(data_B, reshape_const, false);
auto data_A = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{12, 6});
auto reduction_axes = ov::op::v0::Constant::create(ov::element::i32, {2}, {0, 1});
auto reduce = std::make_shared<ov::op::v1::ReduceMax>(reshape, reduction_axes);
auto sum = std::make_shared<ov::op::v1::Add>(data_A, reduce);
auto matmul = std::make_shared<ov::op::v0::MatMul>(sum, reshape);
model = std::make_shared<ov::Model>(ov::NodeVector{matmul}, ov::ParameterVector{data_A, data_B});
manager.register_pass<ov::pass::ReshapeBMatMul>();
}