From 2932e9e9381c8f07cef200931470fe19044590d5 Mon Sep 17 00:00:00 2001 From: Vladislav Golubev Date: Tue, 31 Oct 2023 12:00:52 +0100 Subject: [PATCH] ReshapeBMatMul and ReshapeAMatMul: avoid circular dependencies creation (#20771) --- .../smart_reshape/matmul_sr.cpp | 10 ++++-- .../tests/functional/matmul_sr_tests.cpp | 35 +++++++++++++++++++ 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/src/common/transformations/src/transformations/smart_reshape/matmul_sr.cpp b/src/common/transformations/src/transformations/smart_reshape/matmul_sr.cpp index 870b69d9a55..ff7dc8c927d 100644 --- a/src/common/transformations/src/transformations/smart_reshape/matmul_sr.cpp +++ b/src/common/transformations/src/transformations/smart_reshape/matmul_sr.cpp @@ -61,8 +61,11 @@ ov::pass::ReshapeAMatMul::ReshapeAMatMul() { auto other_input_label = pattern::any_input(); auto reshape_input_label = pattern::any_input(); auto reshape_pattern_label = pattern::any_input(); + auto reshape_predicate = [](ov::Output output) -> bool { + return ov::pass::pattern::rank_equals(2)(output) && ov::pass::pattern::consumers_count(1)(output); + }; auto reshape_label = ov::pass::pattern::wrap_type({reshape_input_label, reshape_pattern_label}, - ov::pass::pattern::rank_equals(2)); + reshape_predicate); auto matmul_label = ov::pass::pattern::wrap_type({reshape_label, other_input_label}); matcher_pass_callback callback = [=](pattern::Matcher& m) -> bool { @@ -83,8 +86,11 @@ ov::pass::ReshapeBMatMul::ReshapeBMatMul() { auto other_input_label = pattern::any_input(); auto reshape_input_label = pattern::any_input(); auto reshape_pattern_label = pattern::any_input(); + auto reshape_predicate = [](ov::Output output) -> bool { + return ov::pass::pattern::rank_equals(2)(output) && ov::pass::pattern::consumers_count(1)(output); + }; auto reshape_label = ov::pass::pattern::wrap_type({reshape_input_label, reshape_pattern_label}, - ov::pass::pattern::rank_equals(2)); + reshape_predicate); auto matmul_label = ov::pass::pattern::wrap_type({other_input_label, reshape_label}); matcher_pass_callback callback = [=](pattern::Matcher& m) -> bool { diff --git a/src/inference/tests/functional/matmul_sr_tests.cpp b/src/inference/tests/functional/matmul_sr_tests.cpp index 27a294e656e..3d17cfd915f 100644 --- a/src/inference/tests/functional/matmul_sr_tests.cpp +++ b/src/inference/tests/functional/matmul_sr_tests.cpp @@ -10,11 +10,14 @@ #include "cnn_network_ngraph_impl.hpp" #include "common_test_utils/graph_comparator.hpp" +#include "common_test_utils/ov_test_utils.hpp" #include "common_test_utils/test_common.hpp" #include "ie_common.h" +#include "openvino/op/add.hpp" #include "openvino/op/constant.hpp" #include "openvino/op/matmul.hpp" #include "openvino/op/parameter.hpp" +#include "openvino/op/reduce_max.hpp" #include "openvino/op/reshape.hpp" #include "openvino/op/transpose.hpp" #include "openvino/op/variadic_split.hpp" @@ -357,3 +360,35 @@ TEST(SmartReshapeTransposeMatMulTests, TransposeBothMatMulWithAttrFuse) { auto res = compare_functions(f, f_ref); ASSERT_TRUE(res.first) << res.second; } + +TEST_F(TransformationTestsF, SmartReshapeReshapeAMatMulSeveralConsumers) { + // Reshape has 2 consumers: matmul and reduce. + // Since reshape movement leads to loop creation (circular dependencies), the transformation can't be applied + auto data_A = std::make_shared(ov::element::f32, ov::Shape{3, 2, 3}); + auto reshape_const = ov::op::v0::Constant::create(ov::element::i32, {2}, {3, 6}); + auto reshape = std::make_shared(data_A, reshape_const, false); + + auto data_B = std::make_shared(ov::element::f32, ov::Shape{6, 12}); + auto reduction_axes = ov::op::v0::Constant::create(ov::element::i32, {2}, {0, 1}); + auto reduce = std::make_shared(reshape, reduction_axes); + auto sum = std::make_shared(data_B, reduce); + auto matmul = std::make_shared(reshape, sum); + model = std::make_shared(ov::NodeVector{matmul}, ov::ParameterVector{data_A, data_B}); + manager.register_pass(); +} + +TEST_F(TransformationTestsF, SmartReshapeReshapeBMatMulSeveralConsumers) { + // Reshape has 2 consumers: matmul and reduce. + // Since reshape movement leads to loop creation (circular dependencies), the transformation can't be applied + auto data_B = std::make_shared(ov::element::f32, ov::Shape{3, 2, 3}); + auto reshape_const = ov::op::v0::Constant::create(ov::element::i32, {2}, {6, 3}); + auto reshape = std::make_shared(data_B, reshape_const, false); + + auto data_A = std::make_shared(ov::element::f32, ov::Shape{12, 6}); + auto reduction_axes = ov::op::v0::Constant::create(ov::element::i32, {2}, {0, 1}); + auto reduce = std::make_shared(reshape, reduction_axes); + auto sum = std::make_shared(data_A, reduce); + auto matmul = std::make_shared(sum, reshape); + model = std::make_shared(ov::NodeVector{matmul}, ov::ParameterVector{data_A, data_B}); + manager.register_pass(); +}