ReshapeBMatMul and ReshapeAMatMul: avoid circular dependencies creation (#20771)
This commit is contained in:
parent
da1f0199a0
commit
2932e9e938
@ -61,8 +61,11 @@ ov::pass::ReshapeAMatMul::ReshapeAMatMul() {
|
||||
auto other_input_label = pattern::any_input();
|
||||
auto reshape_input_label = pattern::any_input();
|
||||
auto reshape_pattern_label = pattern::any_input();
|
||||
auto reshape_predicate = [](ov::Output<ov::Node> output) -> bool {
|
||||
return ov::pass::pattern::rank_equals(2)(output) && ov::pass::pattern::consumers_count(1)(output);
|
||||
};
|
||||
auto reshape_label = ov::pass::pattern::wrap_type<ov::op::v1::Reshape>({reshape_input_label, reshape_pattern_label},
|
||||
ov::pass::pattern::rank_equals(2));
|
||||
reshape_predicate);
|
||||
auto matmul_label = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({reshape_label, other_input_label});
|
||||
|
||||
matcher_pass_callback callback = [=](pattern::Matcher& m) -> bool {
|
||||
@ -83,8 +86,11 @@ ov::pass::ReshapeBMatMul::ReshapeBMatMul() {
|
||||
auto other_input_label = pattern::any_input();
|
||||
auto reshape_input_label = pattern::any_input();
|
||||
auto reshape_pattern_label = pattern::any_input();
|
||||
auto reshape_predicate = [](ov::Output<ov::Node> output) -> bool {
|
||||
return ov::pass::pattern::rank_equals(2)(output) && ov::pass::pattern::consumers_count(1)(output);
|
||||
};
|
||||
auto reshape_label = ov::pass::pattern::wrap_type<ov::op::v1::Reshape>({reshape_input_label, reshape_pattern_label},
|
||||
ov::pass::pattern::rank_equals(2));
|
||||
reshape_predicate);
|
||||
auto matmul_label = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({other_input_label, reshape_label});
|
||||
|
||||
matcher_pass_callback callback = [=](pattern::Matcher& m) -> bool {
|
||||
|
@ -10,11 +10,14 @@
|
||||
|
||||
#include "cnn_network_ngraph_impl.hpp"
|
||||
#include "common_test_utils/graph_comparator.hpp"
|
||||
#include "common_test_utils/ov_test_utils.hpp"
|
||||
#include "common_test_utils/test_common.hpp"
|
||||
#include "ie_common.h"
|
||||
#include "openvino/op/add.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/matmul.hpp"
|
||||
#include "openvino/op/parameter.hpp"
|
||||
#include "openvino/op/reduce_max.hpp"
|
||||
#include "openvino/op/reshape.hpp"
|
||||
#include "openvino/op/transpose.hpp"
|
||||
#include "openvino/op/variadic_split.hpp"
|
||||
@ -357,3 +360,35 @@ TEST(SmartReshapeTransposeMatMulTests, TransposeBothMatMulWithAttrFuse) {
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, SmartReshapeReshapeAMatMulSeveralConsumers) {
|
||||
// Reshape has 2 consumers: matmul and reduce.
|
||||
// Since reshape movement leads to loop creation (circular dependencies), the transformation can't be applied
|
||||
auto data_A = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{3, 2, 3});
|
||||
auto reshape_const = ov::op::v0::Constant::create(ov::element::i32, {2}, {3, 6});
|
||||
auto reshape = std::make_shared<ov::op::v1::Reshape>(data_A, reshape_const, false);
|
||||
|
||||
auto data_B = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{6, 12});
|
||||
auto reduction_axes = ov::op::v0::Constant::create(ov::element::i32, {2}, {0, 1});
|
||||
auto reduce = std::make_shared<ov::op::v1::ReduceMax>(reshape, reduction_axes);
|
||||
auto sum = std::make_shared<ov::op::v1::Add>(data_B, reduce);
|
||||
auto matmul = std::make_shared<ov::op::v0::MatMul>(reshape, sum);
|
||||
model = std::make_shared<ov::Model>(ov::NodeVector{matmul}, ov::ParameterVector{data_A, data_B});
|
||||
manager.register_pass<ov::pass::ReshapeAMatMul>();
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, SmartReshapeReshapeBMatMulSeveralConsumers) {
|
||||
// Reshape has 2 consumers: matmul and reduce.
|
||||
// Since reshape movement leads to loop creation (circular dependencies), the transformation can't be applied
|
||||
auto data_B = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{3, 2, 3});
|
||||
auto reshape_const = ov::op::v0::Constant::create(ov::element::i32, {2}, {6, 3});
|
||||
auto reshape = std::make_shared<ov::op::v1::Reshape>(data_B, reshape_const, false);
|
||||
|
||||
auto data_A = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{12, 6});
|
||||
auto reduction_axes = ov::op::v0::Constant::create(ov::element::i32, {2}, {0, 1});
|
||||
auto reduce = std::make_shared<ov::op::v1::ReduceMax>(reshape, reduction_axes);
|
||||
auto sum = std::make_shared<ov::op::v1::Add>(data_A, reduce);
|
||||
auto matmul = std::make_shared<ov::op::v0::MatMul>(sum, reshape);
|
||||
model = std::make_shared<ov::Model>(ov::NodeVector{matmul}, ov::ParameterVector{data_A, data_B});
|
||||
manager.register_pass<ov::pass::ReshapeBMatMul>();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user