SmartReshape: ReshapeMatMul transformations are fixed (#19987)
* SmartReshape: ReshapeMatMul transformations are fixed * clang-format fixes
This commit is contained in:
parent
d0ef28e541
commit
0e0e1b0ee6
@ -28,10 +28,8 @@ bool relax_hc_reshape_followed_by_matmul(const ov::pass::pattern::PatternValueMa
|
||||
const std::shared_ptr<ov::Node>& other_input_label,
|
||||
const std::shared_ptr<ov::Node>& reshape_pattern_label,
|
||||
bool reshape_is_A_input) {
|
||||
const auto& reshape_rank = pattern_to_output.at(reshape_label).get_partial_shape().rank();
|
||||
const auto& matmul =
|
||||
std::dynamic_pointer_cast<ov::op::v0::MatMul>(pattern_to_output.at(matmul_label).get_node_shared_ptr());
|
||||
if (!matmul || reshape_rank.is_dynamic() || reshape_rank.get_length() != 2)
|
||||
const auto matmul = ov::as_type_ptr<ov::op::v0::MatMul>(pattern_to_output.at(matmul_label).get_node_shared_ptr());
|
||||
if (!matmul)
|
||||
return false;
|
||||
const auto& shape_source = pattern_to_output.at(other_input_label);
|
||||
if (ov::is_type<ov::op::v1::Transpose>(shape_source.get_node_shared_ptr()) ||
|
||||
@ -39,20 +37,15 @@ bool relax_hc_reshape_followed_by_matmul(const ov::pass::pattern::PatternValueMa
|
||||
// avoiding loop creation
|
||||
return false;
|
||||
|
||||
const auto& raw_idx =
|
||||
reshape_is_A_input ? (matmul->get_transpose_b() ? -1 : -2) : (matmul->get_transpose_a() ? -2 : -1);
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
const auto& idx = ov::normalize_axes(matmul->description(), {raw_idx}, reshape_rank);
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
const auto& C =
|
||||
std::make_shared<ov::op::v1::Gather>(std::make_shared<ov::op::v3::ShapeOf>(shape_source),
|
||||
ov::op::v0::Constant::create(ov::element::i64, {idx.size()}, idx),
|
||||
ov::op::v0::Constant::create(ov::element::i64, {}, {0}));
|
||||
const auto& N = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1});
|
||||
const auto& pattern_vector =
|
||||
reshape_is_A_input ? (matmul->get_transpose_a() ? ov::OutputVector({C, N}) : ov::OutputVector({N, C}))
|
||||
: (matmul->get_transpose_b() ? ov::OutputVector({N, C}) : ov::OutputVector({C, N}));
|
||||
const auto& new_reshape_pattern = std::make_shared<ov::op::v0::Concat>(pattern_vector, 0);
|
||||
const auto idx = reshape_is_A_input ? (matmul->get_transpose_b() ? -1 : -2) : (matmul->get_transpose_a() ? -2 : -1);
|
||||
const auto C = std::make_shared<ov::op::v8::Gather>(std::make_shared<ov::op::v3::ShapeOf>(shape_source),
|
||||
ov::op::v0::Constant::create(ov::element::i64, {1}, {idx}),
|
||||
ov::op::v0::Constant::create(ov::element::i64, {}, {0}));
|
||||
const auto N = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1});
|
||||
const auto pattern_vector = reshape_is_A_input
|
||||
? (matmul->get_transpose_a() ? ov::OutputVector({C, N}) : ov::OutputVector({N, C}))
|
||||
: (matmul->get_transpose_b() ? ov::OutputVector({N, C}) : ov::OutputVector({C, N}));
|
||||
const auto new_reshape_pattern = std::make_shared<ov::op::v0::Concat>(pattern_vector, 0);
|
||||
|
||||
auto reshape_pattern = pattern_to_output.at(reshape_pattern_label).get_node_shared_ptr();
|
||||
new_reshape_pattern->set_friendly_name(reshape_pattern->get_friendly_name());
|
||||
@ -68,8 +61,8 @@ ov::pass::ReshapeAMatMul::ReshapeAMatMul() {
|
||||
auto other_input_label = pattern::any_input();
|
||||
auto reshape_input_label = pattern::any_input();
|
||||
auto reshape_pattern_label = pattern::any_input();
|
||||
auto reshape_label =
|
||||
ov::pass::pattern::wrap_type<ov::op::v1::Reshape>({reshape_input_label, reshape_pattern_label});
|
||||
auto reshape_label = ov::pass::pattern::wrap_type<ov::op::v1::Reshape>({reshape_input_label, reshape_pattern_label},
|
||||
ov::pass::pattern::rank_equals(2));
|
||||
auto matmul_label = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({reshape_label, other_input_label});
|
||||
|
||||
matcher_pass_callback callback = [=](pattern::Matcher& m) -> bool {
|
||||
@ -90,8 +83,8 @@ ov::pass::ReshapeBMatMul::ReshapeBMatMul() {
|
||||
auto other_input_label = pattern::any_input();
|
||||
auto reshape_input_label = pattern::any_input();
|
||||
auto reshape_pattern_label = pattern::any_input();
|
||||
auto reshape_label =
|
||||
ov::pass::pattern::wrap_type<ov::op::v1::Reshape>({reshape_input_label, reshape_pattern_label});
|
||||
auto reshape_label = ov::pass::pattern::wrap_type<ov::op::v1::Reshape>({reshape_input_label, reshape_pattern_label},
|
||||
ov::pass::pattern::rank_equals(2));
|
||||
auto matmul_label = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({other_input_label, reshape_label});
|
||||
|
||||
matcher_pass_callback callback = [=](pattern::Matcher& m) -> bool {
|
||||
|
@ -26,7 +26,7 @@ using namespace testing;
|
||||
|
||||
namespace {
|
||||
|
||||
using reshape_map = std::map<std::string, std::vector<size_t>>;
|
||||
using reshape_map = std::map<std::string, ov::PartialShape>;
|
||||
|
||||
struct ReshapeMatMulTestCase {
|
||||
bool reshape_is_A_input;
|
||||
@ -75,8 +75,10 @@ public:
|
||||
{
|
||||
auto input_A = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, test_case.A_shape);
|
||||
input_A->set_friendly_name("input_A");
|
||||
input_A->output(0).set_names({"input_A"});
|
||||
auto input_B = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, test_case.B_shape);
|
||||
input_B->set_friendly_name("input_B");
|
||||
input_B->output(0).set_names({"input_B"});
|
||||
|
||||
auto reshape_pattern = std::make_shared<ov::op::v0::Constant>(ov::element::i64,
|
||||
ov::Shape{test_case.reshape_pattern.size()},
|
||||
@ -99,15 +101,13 @@ public:
|
||||
ov::ResultVector results = {result};
|
||||
model = std::make_shared<ov::Model>(results, params);
|
||||
}
|
||||
|
||||
InferenceEngine::details::CNNNetworkNGraphImpl network(model);
|
||||
const auto& resp = network.reshape(test_case.new_shapes, nullptr);
|
||||
ASSERT_EQ(resp, InferenceEngine::StatusCode::OK);
|
||||
ASSERT_NO_THROW(model->reshape(test_case.new_shapes));
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(SmartReshapeMatMulTests, ReshapeMatMul) {}
|
||||
|
||||
// clang-format off
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
OVModel,
|
||||
SmartReshapeMatMulTests,
|
||||
@ -116,11 +116,14 @@ INSTANTIATE_TEST_SUITE_P(
|
||||
ReshapeMatMulTestCase{true, {1, 20, 30}, {40, 30}, {20, -1}, false, true, {{"input_A", {2, 20, 30}}}},
|
||||
ReshapeMatMulTestCase{true, {1, 30, 20}, {30, 20}, {-1, 20}, true, false, {{"input_A", {2, 30, 20}}}},
|
||||
ReshapeMatMulTestCase{true, {1, 30, 20}, {40, 30}, {-1, 20}, true, true, {{"input_A", {2, 30, 20}}}},
|
||||
ReshapeMatMulTestCase{true, {-1, 30, 40}, {-1, 1, 1200}, {1200, 1200}, false, true, {{"input_A", {1200, 30, 40}}}},
|
||||
ReshapeMatMulTestCase{false, {20, 30}, {1, 30, 40}, {-1, 40}, false, false, {{"input_B", {2, 30, 40}}}},
|
||||
ReshapeMatMulTestCase{false, {20, 30}, {1, 40, 30}, {40, -1}, false, true, {{"input_B", {2, 40, 30}}}},
|
||||
ReshapeMatMulTestCase{false, {30, 20}, {1, 30, 40}, {-1, 40}, true, false, {{"input_B", {2, 30, 40}}}},
|
||||
ReshapeMatMulTestCase{false, {30, 20}, {1, 40, 30}, {40, -1}, true, true, {{"input_B", {2, 40, 30}}}}),
|
||||
ReshapeMatMulTestCase{false, {30, 20}, {1, 40, 30}, {40, -1}, true, true, {{"input_B", {2, 40, 30}}}},
|
||||
ReshapeMatMulTestCase{false, {-1, 1, 1200}, {-1, 30, 40}, {1200, 1200}, false, false, {{"input_B", {1200, 30, 40}}}}),
|
||||
SmartReshapeMatMulTests::getTestCaseName);
|
||||
// clang-format on
|
||||
} // namespace
|
||||
|
||||
TEST(SmartReshapeTransposeMatMulTests, TransposeAMatMulFuse) {
|
||||
|
Loading…
Reference in New Issue
Block a user