ReshapeBMatMul and ReshapeAMatMul: avoid circular dependencies creation (#20771)
This commit is contained in:
parent
da1f0199a0
commit
2932e9e938
@ -61,8 +61,11 @@ ov::pass::ReshapeAMatMul::ReshapeAMatMul() {
|
|||||||
auto other_input_label = pattern::any_input();
|
auto other_input_label = pattern::any_input();
|
||||||
auto reshape_input_label = pattern::any_input();
|
auto reshape_input_label = pattern::any_input();
|
||||||
auto reshape_pattern_label = pattern::any_input();
|
auto reshape_pattern_label = pattern::any_input();
|
||||||
|
auto reshape_predicate = [](ov::Output<ov::Node> output) -> bool {
|
||||||
|
return ov::pass::pattern::rank_equals(2)(output) && ov::pass::pattern::consumers_count(1)(output);
|
||||||
|
};
|
||||||
auto reshape_label = ov::pass::pattern::wrap_type<ov::op::v1::Reshape>({reshape_input_label, reshape_pattern_label},
|
auto reshape_label = ov::pass::pattern::wrap_type<ov::op::v1::Reshape>({reshape_input_label, reshape_pattern_label},
|
||||||
ov::pass::pattern::rank_equals(2));
|
reshape_predicate);
|
||||||
auto matmul_label = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({reshape_label, other_input_label});
|
auto matmul_label = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({reshape_label, other_input_label});
|
||||||
|
|
||||||
matcher_pass_callback callback = [=](pattern::Matcher& m) -> bool {
|
matcher_pass_callback callback = [=](pattern::Matcher& m) -> bool {
|
||||||
@ -83,8 +86,11 @@ ov::pass::ReshapeBMatMul::ReshapeBMatMul() {
|
|||||||
auto other_input_label = pattern::any_input();
|
auto other_input_label = pattern::any_input();
|
||||||
auto reshape_input_label = pattern::any_input();
|
auto reshape_input_label = pattern::any_input();
|
||||||
auto reshape_pattern_label = pattern::any_input();
|
auto reshape_pattern_label = pattern::any_input();
|
||||||
|
auto reshape_predicate = [](ov::Output<ov::Node> output) -> bool {
|
||||||
|
return ov::pass::pattern::rank_equals(2)(output) && ov::pass::pattern::consumers_count(1)(output);
|
||||||
|
};
|
||||||
auto reshape_label = ov::pass::pattern::wrap_type<ov::op::v1::Reshape>({reshape_input_label, reshape_pattern_label},
|
auto reshape_label = ov::pass::pattern::wrap_type<ov::op::v1::Reshape>({reshape_input_label, reshape_pattern_label},
|
||||||
ov::pass::pattern::rank_equals(2));
|
reshape_predicate);
|
||||||
auto matmul_label = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({other_input_label, reshape_label});
|
auto matmul_label = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({other_input_label, reshape_label});
|
||||||
|
|
||||||
matcher_pass_callback callback = [=](pattern::Matcher& m) -> bool {
|
matcher_pass_callback callback = [=](pattern::Matcher& m) -> bool {
|
||||||
|
@ -10,11 +10,14 @@
|
|||||||
|
|
||||||
#include "cnn_network_ngraph_impl.hpp"
|
#include "cnn_network_ngraph_impl.hpp"
|
||||||
#include "common_test_utils/graph_comparator.hpp"
|
#include "common_test_utils/graph_comparator.hpp"
|
||||||
|
#include "common_test_utils/ov_test_utils.hpp"
|
||||||
#include "common_test_utils/test_common.hpp"
|
#include "common_test_utils/test_common.hpp"
|
||||||
#include "ie_common.h"
|
#include "ie_common.h"
|
||||||
|
#include "openvino/op/add.hpp"
|
||||||
#include "openvino/op/constant.hpp"
|
#include "openvino/op/constant.hpp"
|
||||||
#include "openvino/op/matmul.hpp"
|
#include "openvino/op/matmul.hpp"
|
||||||
#include "openvino/op/parameter.hpp"
|
#include "openvino/op/parameter.hpp"
|
||||||
|
#include "openvino/op/reduce_max.hpp"
|
||||||
#include "openvino/op/reshape.hpp"
|
#include "openvino/op/reshape.hpp"
|
||||||
#include "openvino/op/transpose.hpp"
|
#include "openvino/op/transpose.hpp"
|
||||||
#include "openvino/op/variadic_split.hpp"
|
#include "openvino/op/variadic_split.hpp"
|
||||||
@ -357,3 +360,35 @@ TEST(SmartReshapeTransposeMatMulTests, TransposeBothMatMulWithAttrFuse) {
|
|||||||
auto res = compare_functions(f, f_ref);
|
auto res = compare_functions(f, f_ref);
|
||||||
ASSERT_TRUE(res.first) << res.second;
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(TransformationTestsF, SmartReshapeReshapeAMatMulSeveralConsumers) {
|
||||||
|
// Reshape has 2 consumers: matmul and reduce.
|
||||||
|
// Since reshape movement leads to loop creation (circular dependencies), the transformation can't be applied
|
||||||
|
auto data_A = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{3, 2, 3});
|
||||||
|
auto reshape_const = ov::op::v0::Constant::create(ov::element::i32, {2}, {3, 6});
|
||||||
|
auto reshape = std::make_shared<ov::op::v1::Reshape>(data_A, reshape_const, false);
|
||||||
|
|
||||||
|
auto data_B = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{6, 12});
|
||||||
|
auto reduction_axes = ov::op::v0::Constant::create(ov::element::i32, {2}, {0, 1});
|
||||||
|
auto reduce = std::make_shared<ov::op::v1::ReduceMax>(reshape, reduction_axes);
|
||||||
|
auto sum = std::make_shared<ov::op::v1::Add>(data_B, reduce);
|
||||||
|
auto matmul = std::make_shared<ov::op::v0::MatMul>(reshape, sum);
|
||||||
|
model = std::make_shared<ov::Model>(ov::NodeVector{matmul}, ov::ParameterVector{data_A, data_B});
|
||||||
|
manager.register_pass<ov::pass::ReshapeAMatMul>();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TransformationTestsF, SmartReshapeReshapeBMatMulSeveralConsumers) {
|
||||||
|
// Reshape has 2 consumers: matmul and reduce.
|
||||||
|
// Since reshape movement leads to loop creation (circular dependencies), the transformation can't be applied
|
||||||
|
auto data_B = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{3, 2, 3});
|
||||||
|
auto reshape_const = ov::op::v0::Constant::create(ov::element::i32, {2}, {6, 3});
|
||||||
|
auto reshape = std::make_shared<ov::op::v1::Reshape>(data_B, reshape_const, false);
|
||||||
|
|
||||||
|
auto data_A = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{12, 6});
|
||||||
|
auto reduction_axes = ov::op::v0::Constant::create(ov::element::i32, {2}, {0, 1});
|
||||||
|
auto reduce = std::make_shared<ov::op::v1::ReduceMax>(reshape, reduction_axes);
|
||||||
|
auto sum = std::make_shared<ov::op::v1::Add>(data_A, reduce);
|
||||||
|
auto matmul = std::make_shared<ov::op::v0::MatMul>(sum, reshape);
|
||||||
|
model = std::make_shared<ov::Model>(ov::NodeVector{matmul}, ov::ParameterVector{data_A, data_B});
|
||||||
|
manager.register_pass<ov::pass::ReshapeBMatMul>();
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user