SmartReshape: ReshapeMatMul transformations are fixed (#19987)

* SmartReshape: ReshapeMatMul transformations are fixed

* clang-format fixes
This commit is contained in:
Vladislav Golubev 2023-09-21 19:49:07 +02:00 committed by GitHub
parent d0ef28e541
commit 0e0e1b0ee6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 24 additions and 28 deletions

View File

@ -28,10 +28,8 @@ bool relax_hc_reshape_followed_by_matmul(const ov::pass::pattern::PatternValueMa
const std::shared_ptr<ov::Node>& other_input_label,
const std::shared_ptr<ov::Node>& reshape_pattern_label,
bool reshape_is_A_input) {
const auto& reshape_rank = pattern_to_output.at(reshape_label).get_partial_shape().rank();
const auto& matmul =
std::dynamic_pointer_cast<ov::op::v0::MatMul>(pattern_to_output.at(matmul_label).get_node_shared_ptr());
if (!matmul || reshape_rank.is_dynamic() || reshape_rank.get_length() != 2)
const auto matmul = ov::as_type_ptr<ov::op::v0::MatMul>(pattern_to_output.at(matmul_label).get_node_shared_ptr());
if (!matmul)
return false;
const auto& shape_source = pattern_to_output.at(other_input_label);
if (ov::is_type<ov::op::v1::Transpose>(shape_source.get_node_shared_ptr()) ||
@ -39,20 +37,15 @@ bool relax_hc_reshape_followed_by_matmul(const ov::pass::pattern::PatternValueMa
// avoiding loop creation
return false;
const auto& raw_idx =
reshape_is_A_input ? (matmul->get_transpose_b() ? -1 : -2) : (matmul->get_transpose_a() ? -2 : -1);
OPENVINO_SUPPRESS_DEPRECATED_START
const auto& idx = ov::normalize_axes(matmul->description(), {raw_idx}, reshape_rank);
OPENVINO_SUPPRESS_DEPRECATED_END
const auto& C =
std::make_shared<ov::op::v1::Gather>(std::make_shared<ov::op::v3::ShapeOf>(shape_source),
ov::op::v0::Constant::create(ov::element::i64, {idx.size()}, idx),
ov::op::v0::Constant::create(ov::element::i64, {}, {0}));
const auto& N = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1});
const auto& pattern_vector =
reshape_is_A_input ? (matmul->get_transpose_a() ? ov::OutputVector({C, N}) : ov::OutputVector({N, C}))
: (matmul->get_transpose_b() ? ov::OutputVector({N, C}) : ov::OutputVector({C, N}));
const auto& new_reshape_pattern = std::make_shared<ov::op::v0::Concat>(pattern_vector, 0);
const auto idx = reshape_is_A_input ? (matmul->get_transpose_b() ? -1 : -2) : (matmul->get_transpose_a() ? -2 : -1);
const auto C = std::make_shared<ov::op::v8::Gather>(std::make_shared<ov::op::v3::ShapeOf>(shape_source),
ov::op::v0::Constant::create(ov::element::i64, {1}, {idx}),
ov::op::v0::Constant::create(ov::element::i64, {}, {0}));
const auto N = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1});
const auto pattern_vector = reshape_is_A_input
? (matmul->get_transpose_a() ? ov::OutputVector({C, N}) : ov::OutputVector({N, C}))
: (matmul->get_transpose_b() ? ov::OutputVector({N, C}) : ov::OutputVector({C, N}));
const auto new_reshape_pattern = std::make_shared<ov::op::v0::Concat>(pattern_vector, 0);
auto reshape_pattern = pattern_to_output.at(reshape_pattern_label).get_node_shared_ptr();
new_reshape_pattern->set_friendly_name(reshape_pattern->get_friendly_name());
@ -68,8 +61,8 @@ ov::pass::ReshapeAMatMul::ReshapeAMatMul() {
auto other_input_label = pattern::any_input();
auto reshape_input_label = pattern::any_input();
auto reshape_pattern_label = pattern::any_input();
auto reshape_label =
ov::pass::pattern::wrap_type<ov::op::v1::Reshape>({reshape_input_label, reshape_pattern_label});
auto reshape_label = ov::pass::pattern::wrap_type<ov::op::v1::Reshape>({reshape_input_label, reshape_pattern_label},
ov::pass::pattern::rank_equals(2));
auto matmul_label = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({reshape_label, other_input_label});
matcher_pass_callback callback = [=](pattern::Matcher& m) -> bool {
@ -90,8 +83,8 @@ ov::pass::ReshapeBMatMul::ReshapeBMatMul() {
auto other_input_label = pattern::any_input();
auto reshape_input_label = pattern::any_input();
auto reshape_pattern_label = pattern::any_input();
auto reshape_label =
ov::pass::pattern::wrap_type<ov::op::v1::Reshape>({reshape_input_label, reshape_pattern_label});
auto reshape_label = ov::pass::pattern::wrap_type<ov::op::v1::Reshape>({reshape_input_label, reshape_pattern_label},
ov::pass::pattern::rank_equals(2));
auto matmul_label = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({other_input_label, reshape_label});
matcher_pass_callback callback = [=](pattern::Matcher& m) -> bool {

View File

@ -26,7 +26,7 @@ using namespace testing;
namespace {
using reshape_map = std::map<std::string, std::vector<size_t>>;
using reshape_map = std::map<std::string, ov::PartialShape>;
struct ReshapeMatMulTestCase {
bool reshape_is_A_input;
@ -75,8 +75,10 @@ public:
{
auto input_A = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, test_case.A_shape);
input_A->set_friendly_name("input_A");
input_A->output(0).set_names({"input_A"});
auto input_B = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, test_case.B_shape);
input_B->set_friendly_name("input_B");
input_B->output(0).set_names({"input_B"});
auto reshape_pattern = std::make_shared<ov::op::v0::Constant>(ov::element::i64,
ov::Shape{test_case.reshape_pattern.size()},
@ -99,15 +101,13 @@ public:
ov::ResultVector results = {result};
model = std::make_shared<ov::Model>(results, params);
}
InferenceEngine::details::CNNNetworkNGraphImpl network(model);
const auto& resp = network.reshape(test_case.new_shapes, nullptr);
ASSERT_EQ(resp, InferenceEngine::StatusCode::OK);
ASSERT_NO_THROW(model->reshape(test_case.new_shapes));
}
};
TEST_P(SmartReshapeMatMulTests, ReshapeMatMul) {}
// clang-format off
INSTANTIATE_TEST_SUITE_P(
OVModel,
SmartReshapeMatMulTests,
@ -116,11 +116,14 @@ INSTANTIATE_TEST_SUITE_P(
ReshapeMatMulTestCase{true, {1, 20, 30}, {40, 30}, {20, -1}, false, true, {{"input_A", {2, 20, 30}}}},
ReshapeMatMulTestCase{true, {1, 30, 20}, {30, 20}, {-1, 20}, true, false, {{"input_A", {2, 30, 20}}}},
ReshapeMatMulTestCase{true, {1, 30, 20}, {40, 30}, {-1, 20}, true, true, {{"input_A", {2, 30, 20}}}},
ReshapeMatMulTestCase{true, {-1, 30, 40}, {-1, 1, 1200}, {1200, 1200}, false, true, {{"input_A", {1200, 30, 40}}}},
ReshapeMatMulTestCase{false, {20, 30}, {1, 30, 40}, {-1, 40}, false, false, {{"input_B", {2, 30, 40}}}},
ReshapeMatMulTestCase{false, {20, 30}, {1, 40, 30}, {40, -1}, false, true, {{"input_B", {2, 40, 30}}}},
ReshapeMatMulTestCase{false, {30, 20}, {1, 30, 40}, {-1, 40}, true, false, {{"input_B", {2, 30, 40}}}},
ReshapeMatMulTestCase{false, {30, 20}, {1, 40, 30}, {40, -1}, true, true, {{"input_B", {2, 40, 30}}}}),
ReshapeMatMulTestCase{false, {30, 20}, {1, 40, 30}, {40, -1}, true, true, {{"input_B", {2, 40, 30}}}},
ReshapeMatMulTestCase{false, {-1, 1, 1200}, {-1, 30, 40}, {1200, 1200}, false, false, {{"input_B", {1200, 30, 40}}}}),
SmartReshapeMatMulTests::getTestCaseName);
// clang-format on
} // namespace
TEST(SmartReshapeTransposeMatMulTests, TransposeAMatMulFuse) {