From 67dc220d380cf3c59625890d51458004f199edca Mon Sep 17 00:00:00 2001
From: Alexandra Sidorova
Date: Wed, 14 Jun 2023 11:55:24 +0400
Subject: [PATCH] [Snippets] Added support of MatMuls with transposed inputs
 (#17819)

---
 .../snippets/pass/common_optimizations.hpp    |   9 +-
 .../pass/explicit_transpose_matmul_inputs.hpp |  22 ++-
 .../src/pass/common_optimizations.cpp         |  78 ++++++++--
 .../pass/explicit_transpose_matmul_inputs.cpp | 138 ++++++++++--------
 .../snippets/src/pass/mha_tokenization.cpp    | 108 ++++++--------
 .../tests/src/pass/mha_tokenization.cpp       |  30 +++-
 .../shared_tests_instances/snippets/mha.cpp   |  15 ++
 .../plugin/shared/include/snippets/mha.hpp    |  28 ++--
 .../plugin/shared/src/snippets/mha.cpp        |  58 ++++----
 .../include/subgraph_mha.hpp                  |  25 ++++
 .../src/subgraph_mha.cpp                      |  66 +++++++++
 11 files changed, 394 insertions(+), 183 deletions(-)

diff --git a/src/common/snippets/include/snippets/pass/common_optimizations.hpp b/src/common/snippets/include/snippets/pass/common_optimizations.hpp
index 08b07339f5f..2961603077f 100644
--- a/src/common/snippets/include/snippets/pass/common_optimizations.hpp
+++ b/src/common/snippets/include/snippets/pass/common_optimizations.hpp
@@ -5,7 +5,8 @@
 #pragma once
 
 #include "openvino/pass/graph_rewrite.hpp"
-#include "openvino/pass/pattern/matcher.hpp"
+
+#include "snippets/op/subgraph.hpp"
 
 namespace ov {
 namespace snippets {
@@ -15,6 +16,12 @@ class CommonOptimizations : public ov::pass::MatcherPass {
 public:
     OPENVINO_RTTI("CommonOptimizations", "0");
     CommonOptimizations();
+
+private:
+    // Move up Constants which aren't scalars from body to Subgraph and replace them with Parameters inside body
+    void ExtractConstants(const std::shared_ptr<op::Subgraph>& subgraph);
+    // Move up unsupported Transposes on Parameter outputs from body
+    void ExtractUnsupportedTransposes(const std::shared_ptr<op::Subgraph>& subgraph);
 };
 
 } // namespace pass
diff --git a/src/common/snippets/include/snippets/pass/explicit_transpose_matmul_inputs.hpp b/src/common/snippets/include/snippets/pass/explicit_transpose_matmul_inputs.hpp
index dbad1a714b8..378128d9014 100644
--- a/src/common/snippets/include/snippets/pass/explicit_transpose_matmul_inputs.hpp
+++ b/src/common/snippets/include/snippets/pass/explicit_transpose_matmul_inputs.hpp
@@ -5,7 +5,6 @@
 #pragma once
 
 #include "openvino/pass/graph_rewrite.hpp"
-#include "openvino/pass/pattern/matcher.hpp"
 
 namespace ov {
 namespace snippets {
@@ -13,18 +12,25 @@ namespace pass {
 
 /**
  * @interface ExplicitTransposeMatMulInputs
- * @brief At the moment Snippets supports Transpose only with order {0, 2, 3, 1},
- *        so if there is pattern in graph:
- *            in0        Transpose{0, 2, 1, 3}
- *                \     /
- *               MatMul[false, true]
- *        We can set false in MatMul parameter `transposed_b` and
- *        change Transpose order to {0, 2, 3, 1} which is supported by Snippets
+ * @brief The pass extracts an explicit Transpose node from a MatMul with `transposed_a`/`transposed_b` set
+ *        and moves it to a Parameter. If there is already another Transpose there,
+ *        the pass fuses the extracted Transpose with the existing one.
+ *        For example, at the moment Snippets supports Transpose only with order {0, 2, 3, 1},
+ *        so if there is the following pattern in the graph:
+ *            in0        Transpose{0, 2, 1, 3}
+ *                \     /
+ *               MatMul[false, true]
+ *        we can set `false` in the MatMul parameter `transposed_b` and
+ *        change the Transpose order to {0, 2, 3, 1}, which is supported by Snippets
 * @ingroup snippets
 */
 class ExplicitTransposeMatMulInputs: public ov::pass::MatcherPass {
 public:
+    OPENVINO_RTTI("ExplicitTransposeMatMulInputs", "0");
     ExplicitTransposeMatMulInputs();
+
+    // Returns `true` if all inputs (except the 0th one) have scalar shapes; otherwise returns `false`
+    static bool are_weights_scalar(const std::shared_ptr<ov::Node>& node);
+
+private:
+    static void extract(const ov::Input<ov::Node>& input);
 };
 
 } // namespace pass
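
To make the header comment concrete: swapping the two innermost entries of the Transpose order is exactly what absorbs MatMul's implicit transposition. A minimal standalone sketch of that rewrite (plain C++, not part of the patch; `absorb_matmul_transpose` is a hypothetical name):

    #include <algorithm>
    #include <cassert>
    #include <utility>
    #include <vector>

    // Compose a Transpose order with the implicit last-two-dims transpose
    // that MatMul's transposed_a/transposed_b flag would otherwise perform.
    std::vector<int> absorb_matmul_transpose(std::vector<int> order) {
        assert(order.size() > 2);  // mirrors the assert in the pass
        std::swap(*order.rbegin(), *(order.rbegin() + 1));
        return order;
    }

    int main() {
        // Transpose{0, 2, 1, 3} + MatMul[transposed_b=true]
        // == Transpose{0, 2, 3, 1} + MatMul[transposed_b=false]
        assert(absorb_matmul_transpose({0, 2, 1, 3}) == (std::vector<int>{0, 2, 3, 1}));
        return 0;
    }
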
diff --git a/src/common/snippets/src/pass/common_optimizations.cpp b/src/common/snippets/src/pass/common_optimizations.cpp
index da55629055a..180207aa841 100644
--- a/src/common/snippets/src/pass/common_optimizations.cpp
+++ b/src/common/snippets/src/pass/common_optimizations.cpp
@@ -4,27 +4,24 @@
 
 #include "snippets/pass/common_optimizations.hpp"
 
-#include
-#include "openvino/opsets/opset1.hpp"
-#include
-#include "openvino/pass/pattern/op/wrap_type.hpp"
-
-#include "transformations/utils/utils.hpp"
 #include "snippets/pass/fq_decomposition.hpp"
 #include "snippets/pass/softmax_reshape_elimination.hpp"
 #include "snippets/pass/explicit_transpose_matmul_inputs.hpp"
+#include "snippets/pass/transpose_decomposition.hpp"
+#include "snippets/pass/fuse_transpose_brgemm.hpp"
 #include "snippets/op/subgraph.hpp"
-#include "snippets/utils.hpp"
 #include "snippets/itt.hpp"
 
+#include "openvino/pass/pattern/op/wrap_type.hpp"
+#include "transformations/utils/utils.hpp"
+
 namespace ov {
 namespace snippets {
 namespace pass {
 
-// Move up Constants which aren't scalars from body to Subgraph and replace them with Parameters inside body
-void ConvertConstantsToParameters(const std::shared_ptr<op::Subgraph>& subgraph) {
-    OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::ConvertConstantsToParameters");
+void CommonOptimizations::ExtractConstants(const std::shared_ptr<op::Subgraph>& subgraph) {
+    OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::ExtractConstants");
     auto body = subgraph->body_ptr();
 
     ParameterVector new_parameters;
@@ -55,6 +52,52 @@ void ConvertConstantsToParameters(const std::shared_ptr<op::Subgraph>& subgraph) {
+void CommonOptimizations::ExtractUnsupportedTransposes(const std::shared_ptr<op::Subgraph>& subgraph) {
+    OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::ExtractUnsupportedTransposes");
+    const auto& body = subgraph->body_ptr();
+    const auto parameters = body->get_parameters();
+    // [107806]: If the count of Parameters isn't equal to the count of Subgraph inputs,
+    // we cannot guarantee correct extraction since we don't have correct connections between body I/O and Subgraph I/O.
+    OPENVINO_ASSERT(parameters.size() == subgraph->input_values().size(),
+                    "Failed to extract unsupported transposes: the count of Parameters isn't equal to the count of Subgraph inputs");
+
+    bool updated = false;
+    for (size_t i = 0; i < parameters.size(); ++i) {
+        const auto& parameter = parameters[i];
+        const auto& consumers = parameter->get_output_target_inputs(0);
+        if (consumers.size() != 1)
+            continue;
+
+        const auto transpose = ov::as_type_ptr<ov::opset1::Transpose>(consumers.begin()->get_node()->shared_from_this());
+        if (!transpose)
+            continue;
+
+        const auto& order = ov::as_type_ptr<ov::opset1::Constant>(transpose->get_input_node_shared_ptr(1));
+        if (!order)
+            continue;
+
+        const auto order_value = order->cast_vector<int>();
+        const auto transpose_child = *(transpose->get_output_target_inputs(0).begin());
+        const auto is_brgemm_case = ov::is_type<ov::opset1::MatMul>(transpose_child.get_node()->shared_from_this());
+        // If the Transpose is supported (can be decomposed or fused into Brgemm), skip it
+        if ((is_brgemm_case && FuseTransposeBrgemm::supported_cases.count(order_value) != 0) ||
+            (TransposeDecomposition::supported_cases.count(order_value) != 0))
+            continue;
+
+        // If the Transpose isn't supported, we have to extract it from the Subgraph
+        transpose->set_argument(0, subgraph->input_value(i));
+        subgraph->set_argument(i, transpose);
+        transpose_child.replace_source_output(parameter);
+        // Update the shape
+        parameter->set_partial_shape(transpose->get_output_partial_shape(0));
+        updated = true;
+    }
+
+    if (updated) {
+        subgraph->validate_and_infer_types();
+    }
+}
+
 CommonOptimizations::CommonOptimizations() {
     MATCHER_SCOPE(CommonOptimizations);
     ov::graph_rewrite_callback callback = [this](ov::pass::pattern::Matcher& m) {
@@ -65,10 +108,10 @@ CommonOptimizations::CommonOptimizations() {
             return false;
         }
 
-        auto body = subgraph->body_ptr();
+        const auto& body = subgraph->body_ptr();
         const auto is_quantized = subgraph->is_quantized();
 
-        // Firsly we should transform all original Converts inside body to ConvertTruncation to save original behavior.
+        // Firstly, we should transform all original Converts inside body to ConvertTruncation to save original behavior.
         // Then if Subgraph contains FakeQuantize we enable specific transformation for quantized subgraphs.
         ov::pass::Manager manager;
         manager.register_pass();
@@ -80,15 +123,18 @@ CommonOptimizations::CommonOptimizations() {
         manager.run_passes(body);
 
         // At the moment only non-scalar Constants of FakeQuantize can be inside Subgraph
-        // so we can enable ConvertConstantsToParameters pass for quantized models
+        // so we can enable ExtractConstants pass for quantized models
        if (is_quantized) {
-            ConvertConstantsToParameters(subgraph);
+            ExtractConstants(subgraph);
+        }
+        // Extract unsupported Transposes from the body
+        if (subgraph->has_domain_sensitive_ops()) {
+            ExtractUnsupportedTransposes(subgraph);
         }
         return true;
     };
 
-    auto m = std::make_shared<ov::pass::pattern::Matcher>(ov::pass::pattern::wrap_type<op::Subgraph>(),
-                                                          matcher_name);
+    auto m = std::make_shared<ov::pass::pattern::Matcher>(ov::pass::pattern::wrap_type<op::Subgraph>(), matcher_name);
     this->register_matcher(m, callback);
 }
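
For reference, the decision in ExtractUnsupportedTransposes boils down to a set lookup. A hypothetical mirror in plain C++ (the order sets below are illustrative placeholders, not the actual contents of `FuseTransposeBrgemm::supported_cases` or `TransposeDecomposition::supported_cases`):

    #include <set>
    #include <vector>

    // Returns true when the Transpose must be moved out of the Subgraph body:
    // it is kept only if Brgemm can fuse it (when it feeds a MatMul) or if
    // TransposeDecomposition can lower it.
    bool must_extract(const std::vector<int>& order, bool feeds_matmul) {
        static const std::set<std::vector<int>> brgemm_cases = {{0, 2, 1, 3}};         // placeholder
        static const std::set<std::vector<int>> decomposition_cases = {{0, 2, 3, 1}};  // placeholder
        const bool supported = (feeds_matmul && brgemm_cases.count(order) != 0) ||
                               decomposition_cases.count(order) != 0;
        return !supported;
    }

    int main() {
        // An order supported by neither set, e.g. {0, 3, 1, 2}, must be extracted.
        return must_extract({0, 3, 1, 2}, true) ? 0 : 1;
    }
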
diff --git a/src/common/snippets/src/pass/explicit_transpose_matmul_inputs.cpp b/src/common/snippets/src/pass/explicit_transpose_matmul_inputs.cpp
index 6948f6dfcf3..e98a2c3d57a 100644
--- a/src/common/snippets/src/pass/explicit_transpose_matmul_inputs.cpp
+++ b/src/common/snippets/src/pass/explicit_transpose_matmul_inputs.cpp
@@ -2,79 +2,101 @@
 // SPDX-License-Identifier: Apache-2.0
 //
 
+#include "snippets/pass/explicit_transpose_matmul_inputs.hpp"
+
+#include "snippets/op/subgraph.hpp"
 #include "snippets/itt.hpp"
 
-#include "snippets/pass/explicit_transpose_matmul_inputs.hpp"
-#include "snippets/pass/transpose_decomposition.hpp"
-#include "snippets/op/subgraph.hpp"
-
-#include "openvino/core/rt_info.hpp"
+#include "openvino/pass/pattern/matcher.hpp"
 #include "openvino/pass/pattern/op/wrap_type.hpp"
+#include "openvino/core/rt_info.hpp"
+
+
+bool ov::snippets::pass::ExplicitTransposeMatMulInputs::are_weights_scalar(const std::shared_ptr<ov::Node>& node) {
+    const auto inputs = node->inputs();
+    return std::all_of(inputs.begin() + 1, inputs.end(),
+                       [](const ov::Input<ov::Node>& in) {
+                           return in.get_partial_shape().is_static() && ov::shape_size(in.get_shape()) == 1;
+                       });
+}
+
+void ov::snippets::pass::ExplicitTransposeMatMulInputs::extract(const ov::Input<ov::Node>& input) {
+    auto parent = input.get_source_output().get_node_shared_ptr();
+    auto transpose = ov::as_type_ptr<ov::opset1::Transpose>(parent);
+    while (!transpose && !ov::is_type<ov::opset1::Parameter>(parent)) {
+        // We can set a supported order and transposed_a/b=false only if the ops have scalar shapes, to avoid shape mismatches
+        if (!are_weights_scalar(parent))
+            break;
+
+        parent = parent->get_input_node_shared_ptr(0);
+        transpose = ov::as_type_ptr<ov::opset1::Transpose>(parent);
+    }
+
+    // If there is an existing Transpose, update its order; otherwise a new Transpose has to be created
+    if (transpose) {
+        const auto transpose_pattern = ov::as_type_ptr<ov::opset1::Constant>(transpose->get_input_node_shared_ptr(1));
+        OPENVINO_ASSERT(transpose_pattern,
+                        "ExplicitTransposeMatMulInputs expects existing Transpose with Constant order");
+
+        auto transposed_order = transpose_pattern->cast_vector<int32_t>();
+        OPENVINO_ASSERT(transposed_order.size() > 2, "Incorrect Transpose order for ExplicitTransposeMatMulInputs");
+        std::swap(*transposed_order.rbegin(), *(transposed_order.rbegin() + 1));
+
+        auto new_transpose_order = std::make_shared<ov::opset1::Constant>(transpose_pattern->get_element_type(),
+                                                                          ov::Shape{transposed_order.size()},
+                                                                          transposed_order);
+        new_transpose_order->set_friendly_name(transpose_pattern->get_friendly_name());
+        ov::copy_runtime_info(transpose_pattern, new_transpose_order);
+        transpose->set_argument(1, new_transpose_order);
+        return;
+    }
+
+    // Create a new Transpose before the Parameter
+    OPENVINO_ASSERT(ov::is_type<ov::opset1::Parameter>(parent),
+                    "ExplicitTransposeMatMulInputs expects Parameter in cases when there isn't existing Transpose on input");
+    const auto& consumers = parent->get_output_target_inputs(0);
+    OPENVINO_ASSERT(consumers.size() == 1,
+                    "ExplicitTransposeMatMulInputs expects Parameter with one consumer in cases when there isn't existing Transpose on input");
+
+    // Extract Transpose from MatMul
+    OPENVINO_ASSERT(input.get_partial_shape().is_static(), "ExplicitTransposeMatMulInputs supports only static shapes");
+    const auto rank = input.get_shape().size();
+    std::vector<int32_t> transpose_order(rank, 0);
+    std::iota(transpose_order.begin(), transpose_order.end(), 0);
+    std::swap(transpose_order[rank - 1], transpose_order[rank - 2]);
+
+    const auto constant_order = std::make_shared<ov::opset1::Constant>(ov::element::i32, ov::Shape{rank}, transpose_order);
+    const auto new_transpose = std::make_shared<ov::opset1::Transpose>(parent, constant_order);  // parent is a Parameter
+    const auto consumer_input = *(consumers.begin());
+    consumer_input.replace_source_output(new_transpose);
+}
 
 ov::snippets::pass::ExplicitTransposeMatMulInputs::ExplicitTransposeMatMulInputs() {
     MATCHER_SCOPE(ExplicitTransposeMatMulInputs);
     auto m_matmul0 = std::make_shared<ov::opset1::MatMul>(
         ov::pass::pattern::any_input(ov::pass::pattern::has_static_shape()),
         ov::pass::pattern::any_input(ov::pass::pattern::has_static_shape()));
 
     register_matcher(std::make_shared<ov::pass::pattern::Matcher>(m_matmul0, matcher_name),
-        [=](ov::pass::pattern::Matcher &m) {
-            OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::ExplicitTransposeMatMulInputs")
-            auto root = m.get_match_root();
-            bool rewritten = false;
+                     [=](ov::pass::pattern::Matcher& m) {
+        OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::ExplicitTransposeMatMulInputs")
+        auto root = m.get_match_root();
+        bool rewritten = false;
 
-            auto matmul0 = ov::as_type_ptr<ov::opset1::MatMul>(root);
-            if (!matmul0)
-                return false;
+        auto matmul = ov::as_type_ptr<ov::opset1::MatMul>(root);
+        if (!matmul)
+            return false;
 
-            for (size_t i = 0; i < matmul0->get_input_size(); i++) {
-                if (i == 0 && !matmul0->get_transpose_a())
-                    continue;
-                if (i == 1 && !matmul0->get_transpose_b())
-                    continue;
-
-                auto parent1 = matmul0->get_input_node_shared_ptr(i);
-                auto transpose1 = ov::as_type_ptr<ov::opset1::Transpose>(parent1);
-                while (!transpose1 && !ov::is_type<ov::opset1::Parameter>(parent1)) {
-                    // We can set supported order and transposed_b(false) only if ops have scalar shapes to avoid shape mismatching
-                    const auto parent_count = parent1->inputs().size();
-                    bool are_weights_scalar = true;
-                    for (size_t j = 1; j < parent_count; ++j) {
-                        are_weights_scalar = are_weights_scalar && ov::shape_size(parent1->get_input_shape(j)) == 1;
-                    }
-                    if (!are_weights_scalar)
-                        break;
-
-                    parent1 = parent1->get_input_node_shared_ptr(0);
-                    transpose1 = ov::as_type_ptr<ov::opset1::Transpose>(parent1);
+        if (matmul->get_transpose_a()) {
+            extract(matmul->input(0));
+            matmul->set_transpose_a(false);
+            rewritten |= true;
         }
-                if (!transpose1)
-                    continue;
-
-                const auto transpose_pattern = ov::as_type_ptr<ov::opset1::Constant>(transpose1->get_input_node_shared_ptr(1));
-                if (!transpose_pattern)
-                    continue;
-
-                auto transposed_order = transpose_pattern->cast_vector<int>();
-                std::swap(*transposed_order.rbegin(), *(transposed_order.rbegin() + 1));
-                if (pass::TransposeDecomposition::supported_cases.count(transposed_order) == 0)
-                    continue;
-
-                auto new_transpose_order = std::make_shared<ov::opset1::Constant>(transpose_pattern->get_element_type(),
-                                                                                  ov::Shape{4},
-                                                                                  transposed_order);
-                new_transpose_order->set_friendly_name(transpose_pattern->get_friendly_name());
-                ov::copy_runtime_info(transpose_pattern, new_transpose_order);
-                transpose1->set_argument(1, new_transpose_order);
-                if (i == 0) {
-                    matmul0->set_transpose_a(false);
-                } else {
-                    matmul0->set_transpose_b(false);
+        if (matmul->get_transpose_b()) {
+            extract(matmul->input(1));
+            matmul->set_transpose_b(false);
+            rewritten |= true;
         }
-                rewritten |= true;
-            }
 
-            return rewritten;
-        });
+        return rewritten;
+    });
 }
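
When no Transpose exists on the input, `extract` synthesizes one whose order is the identity permutation with the two innermost dimensions swapped. The same construction in isolation (plain C++ sketch; `make_transposed_order` is a hypothetical name):

    #include <cstdint>
    #include <numeric>
    #include <utility>
    #include <vector>

    std::vector<int32_t> make_transposed_order(size_t rank) {
        std::vector<int32_t> order(rank, 0);
        std::iota(order.begin(), order.end(), 0);      // {0, 1, ..., rank - 1}
        std::swap(order[rank - 1], order[rank - 2]);   // requires rank >= 2
        return order;
    }

    int main() {
        // make_transposed_order(4) yields {0, 1, 3, 2}.
        return make_transposed_order(4) == std::vector<int32_t>{0, 1, 3, 2} ? 0 : 1;
    }
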
diff --git a/src/common/snippets/src/pass/mha_tokenization.cpp b/src/common/snippets/src/pass/mha_tokenization.cpp
index 864341dc417..1f62781c69e 100644
--- a/src/common/snippets/src/pass/mha_tokenization.cpp
+++ b/src/common/snippets/src/pass/mha_tokenization.cpp
@@ -7,6 +7,7 @@
 
 #include "snippets/itt.hpp"
 #include "snippets/pass/collapse_subgraph.hpp"
+#include "snippets/pass/explicit_transpose_matmul_inputs.hpp"
 #include "snippets/op/subgraph.hpp"
 #include "snippets/op/brgemm.hpp"
 #include "snippets/utils.hpp"
@@ -156,12 +157,7 @@ auto update_intermediate_supported_ops(std::shared_ptr<ov::Node>& interm_op, ov::NodeVector& ordered_ops,
             break;
 
         // Add the node only if there are scalar constants on its inputs because of a plugin-specific limitation
-        bool are_weights_scalar = true;
-        const auto parent_count = parent->get_input_size();
-        for (size_t i = 1; i < parent_count; ++i) {
-            are_weights_scalar = are_weights_scalar && ov::shape_size(parent->get_input_shape(i)) == 1;
-        }
-        if (!are_weights_scalar)
+        if (!ov::snippets::pass::ExplicitTransposeMatMulInputs::are_weights_scalar(parent))
             break;
 
         ordered_ops.insert(ordered_ops.begin() + shift, parent);
@@ -321,22 +317,27 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets(const SnippetsTokenization::Config& config) {
         /***** Transposes *****/
         /* There may be Transpose and Reshape ops on inputs and outputs of MHA-pattern skeleton
          * We can add them into Subgraph body
+         *     Transpose0    Transpose1
+         *            \      /
+         *            MatMul0
+         *               |
+         *     [...]  Transpose2
+         *        \      /
+         *        MatMul1
+         *           |
+         *       Transpose3
          */
         // First input branch of MatMul0 should be executed before second input branch of MatMul0,
-        // so firstly we insert Transpose1 on the beginning of ordered_ops and then Transpose1
-        bool are_weights_scalar = true;
+        // so firstly we insert Transpose1 at the beginning of ordered_ops and then Transpose0
+        // Note: If MatMul0 has transposed_b, we should tokenize only scalar ops from the first branch
+        //       to move the extracted Transpose from the MatMul input to a body Parameter
+        auto parent = matmul0->get_input_node_shared_ptr(1);
         // We can support several ops between MatMul0 with transposed_b and Transpose1 with 0213 order (or without this Transpose1)
         // only if these ops have scalar shapes on other inputs.
         // There is the transformation ExplicitTransposeMatMulInputs that sets a supported order and transposed_b(false).
        // We can allow calling this pass only if the ops have scalar shapes, to avoid shape mismatches
         const auto is_transposed_b_0 = matmul0->get_transpose_b();
-        auto parent = matmul0->get_input_node_shared_ptr(1);
         while (is_supported_intermediate_op(parent)) {
             // All supported ops have only one output port
             if (parent->get_output_target_inputs(0).size() != 1)
                 break;
 
             // Only if MatMul0 has transposed_b, we have to tokenize scalar ops
             // to move the explicit Transpose from MatMul0 input_1 to a Parameter of the Subgraph body
-            if (is_transposed_b_0) {
-                const auto parent_count = parent->get_input_size();
-                bool are_weights_scalar = true;
-                for (size_t i = 1; i < parent_count; ++i) {
-                    are_weights_scalar = are_weights_scalar && ov::shape_size(parent->get_input_shape(i)) == 1;
-                }
-                if (!are_weights_scalar) {
-                    break;
-                }
+            if (is_transposed_b_0 && !ov::snippets::pass::ExplicitTransposeMatMulInputs::are_weights_scalar(parent)) {
+                break;
             }
 
             // To avoid unsupported number of non-scalar Constants in the future after FakeQuantize decomposition (plugin-specific limitation)
             if (const auto fq_node = ov::as_type_ptr<ov::opset1::FakeQuantize>(parent)) {
                 hidden_virtual_ports_count += ov::snippets::utils::get_non_scalar_constant_count_for_fq(fq_node);
             }
+            potential_body_params_count += get_potential_body_params(parent);
 
             ordered_ops.insert(ordered_ops.begin(), parent);
-            // TODO [107731] To go always through 0-th port - is it safe?
+            // [107731] To go always through 0-th port - is it safe?
             parent = parent->get_input_node_shared_ptr(0);
         }
 
-        const auto transpose1 = tokenize_transpose(parent);
-        if (is_transposed_b_0) {
-            if (is_valid_transpose(transpose1, {0, 2, 1, 3})) {
-                // We can support several ops between MatMul0 with transposed_b and Transpose1 with 0213 order
-                // only if these ops have scalar shapes on other inputs.
-                // There is transformation ExplicitTransposeMatMulInputs that set supported order and transposed_b(false).
-                // We can allow to call this pass only if ops have scalar shapes to avoid shape mismatching
-                if (are_weights_scalar) {
-                    ordered_ops.insert(ordered_ops.begin(), transpose1);
-                } else {
-                    return false;
-                }
-            } else {
-                return false;
-            }
-        } else {
-            if (is_valid_transpose(transpose1, {0, 2, 3, 1})) {
-                ordered_ops.insert(ordered_ops.begin(), transpose1);
-            }
-        }
+        auto tokenize_transpose = [&](const std::shared_ptr<ov::opset1::Transpose>& transpose,
+                                      bool is_input_transposed, std::vector<int64_t> order,
+                                      const ov::NodeVector::const_iterator& pos) {
+            // If the Transpose has a valid order for the Transpose fusing (ExplicitTransposeMatMulInputs pass call), tokenize it.
+            // Otherwise, skip the Transpose.
+            if (!is_input_transposed) {
+                if (is_valid_transpose(transpose, order)) {
+                    ordered_ops.insert(pos, transpose);
+                }
+                return;
+            }
+            auto transposed_order = order;
+            const auto rank = transposed_order.size();
+            if (rank < 2)
+                return;
+            std::swap(transposed_order[rank - 1], transposed_order[rank - 2]);
+            if (is_valid_transpose(transpose, transposed_order)) {
+                ordered_ops.insert(pos, transpose);
+            }
+        };
 
-        if (transpose1) {
-            // Between Transpose1 and MatMul0 will be the one Loop because of LoopFusing optimization.
-            // The Loop will have one Buffer with the same shape both on input and output.
-            // Need to check for precision to get if we need one more register for Buffer
-            if (matmul0->get_input_element_type(1).size() != transpose1->get_output_element_type(0).size()) {
-                buffer_count++;
-            }
-        }
+        auto get_transpose = [config](const std::shared_ptr<ov::Node>& node) -> std::shared_ptr<ov::opset1::Transpose> {
+            return config.mha_token_enable_transpose ? ov::as_type_ptr<ov::opset1::Transpose>(node)
+                                                     : nullptr;
+        };
 
-        const auto transpose0 = tokenize_transpose(matmul0->get_input_node_shared_ptr(0));
-        if (is_valid_transpose(transpose0, {0, 2, 1, 3})) {
-            ordered_ops.insert(ordered_ops.begin(), transpose0);
-        } else if (matmul0->get_transpose_a()) {
-            return false;
-        }
-
-        const auto transpose2 = tokenize_transpose(matmul1->get_input_node_shared_ptr(1));
-        if (is_valid_transpose(transpose2, {0, 2, 1, 3})) {
-            ordered_ops.push_back(transpose2);
-        }
+        const auto transpose1 = get_transpose(parent);
+        const auto transpose0 = get_transpose(matmul0->get_input_node_shared_ptr(0));
+        const auto transpose2 = get_transpose(matmul1->get_input_node_shared_ptr(1));
+        tokenize_transpose(transpose1, is_transposed_b_0, {0, 2, 3, 1}, ordered_ops.begin());
+        tokenize_transpose(transpose0, matmul0->get_transpose_a(), {0, 2, 1, 3}, ordered_ops.begin());
+        tokenize_transpose(transpose2, matmul1->get_transpose_b(), {0, 2, 1, 3}, ordered_ops.end());
         ordered_ops.push_back(matmul1);
 
         bool are_ops_after_matmul1 = false;
@@ -439,7 +425,7 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets(const SnippetsTokenization::Config& config) {
         //
         // Transpose3
         if (!are_ops_after_matmul1) {
-            auto transpose3 = tokenize_transpose(child);
+            auto transpose3 = get_transpose(child);
             if (is_valid_transpose(transpose3, {0, 2, 1, 3}) &&
                 transpose3->get_input_element_type(0) == matmul1_out_type) {  // To avoid Convert between MatMul1 and Transpose3
                 ordered_ops.push_back(transpose3);
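
The `tokenize_transpose` lambda above accepts a Transpose in one of two orders per input: the base order when the MatMul input is not transposed, or the base order with its last two axes swapped when it is (the case that ExplicitTransposeMatMulInputs later folds into the MatMul). A standalone restatement (plain C++; `expected_order` is a hypothetical helper):

    #include <cstdint>
    #include <utility>
    #include <vector>

    std::vector<int64_t> expected_order(std::vector<int64_t> base, bool is_input_transposed) {
        if (is_input_transposed && base.size() >= 2)
            std::swap(base[base.size() - 1], base[base.size() - 2]);
        return base;
    }

    int main() {
        // For MatMul0's second input: expected_order({0, 2, 3, 1}, false) == {0, 2, 3, 1},
        // while expected_order({0, 2, 3, 1}, true) == {0, 2, 1, 3}.
        const bool ok = expected_order({0, 2, 3, 1}, true) == std::vector<int64_t>{0, 2, 1, 3};
        return ok ? 0 : 1;
    }
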
diff --git a/src/common/snippets/tests/src/pass/mha_tokenization.cpp b/src/common/snippets/tests/src/pass/mha_tokenization.cpp
index 68956a2a626..86e211d5c1d 100644
--- a/src/common/snippets/tests/src/pass/mha_tokenization.cpp
+++ b/src/common/snippets/tests/src/pass/mha_tokenization.cpp
@@ -7,7 +7,7 @@
 #include
 #include "snippets/pass/tokenization.hpp"
 #include "snippets/pass/mha_tokenization.hpp"
-#include "snippets/pass/explicit_transpose_matmul_inputs.hpp"
+#include "snippets/pass/common_optimizations.hpp"
 
 namespace ov {
 namespace test {
@@ -15,9 +15,10 @@ namespace snippets {
 
 void TokenizeMHASnippetsTests::run() {
     ASSERT_TRUE(function);
-    std::string name;
     manager.register_pass<ov::snippets::pass::EnumerateNodes>();
     manager.register_pass<ov::snippets::pass::TokenizeMHASnippets>();
+    manager.register_pass<ov::snippets::pass::CommonOptimizations>();
+    disable_rt_info_check();
 }
 
 TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA) {
@@ -43,6 +44,31 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_with_int_Matmuls) {
     run();
 }
 
+TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_Transpose_extraction) {
+    const auto& f = MHATransposedInputFunction(std::vector<PartialShape>{{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 128, 12, 64}}, true);
+    function = f.getOriginal();
+    function_ref = f.getReference();
+    run();
+}
+
+TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_Transpose_extraction_and_unsupported_existing_transpose) {
+    const auto& f = MHATransposedInputFunction(std::vector<PartialShape>{{1, 128, 12, 64}, {1, 12, 64, 128}, {1, 128, 12, 64}}, true,
+                                               std::vector<int64_t>{0, 3, 1, 2});
+    function = f.getOriginal();
+    function_ref = f.getReference();
+    run();
+}
+
+TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_Transpose_fusion) {
+    const auto& f = MHATransposedInputFunction(std::vector<PartialShape>{{1, 128, 12, 64}, {1, 64, 128, 12}, {1, 128, 12, 64}}, false,
+                                               std::vector<int64_t>{0, 2, 1, 3});
+    function = f.getOriginal();
+    function_ref = f.getReference();
+    run();
+}
+
+
 }  // namespace snippets
 }  // namespace test
 }  // namespace ov
\ No newline at end of file
diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp
index c210c1bb6ae..6b3f525f719 100644
--- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp
+++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp
@@ -227,6 +227,21 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAFQ, MHAFQ,
                                  ::testing::Values(CPUTestUtils::cpuEmptyPluginConfig)),
                          MHA::getTestCaseName);
 
+const std::vector<std::vector<ov::PartialShape>> inputShapesTransposedB = {
+    {{1, 12, 12, 64}, {1, 12, 48, 64}, {1, 12, 48, 64}}
+};
+
+INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHATransposedB, MHATransposedB,
+                         ::testing::Combine(
+                                 ::testing::ValuesIn(inputShapesTransposedB),
+                                 ::testing::Values(std::vector<ov::element::Type>{}),
+                                 ::testing::Values(ov::element::f32),
+                                 ::testing::ValuesIn({true}),  // Need to support False for graph builder in tests
+                                 ::testing::Values(2),
+                                 ::testing::Values(1),
+                                 ::testing::Values(CommonTestUtils::DEVICE_CPU),
+                                 ::testing::Values(std::map<std::string, std::string>{})),
+                         MHA::getTestCaseName);
 
 }  // namespace
 }  // namespace snippets
diff --git a/src/tests/functional/plugin/shared/include/snippets/mha.hpp b/src/tests/functional/plugin/shared/include/snippets/mha.hpp
index dde9394869f..2f9b950c798 100644
--- a/src/tests/functional/plugin/shared/include/snippets/mha.hpp
+++ b/src/tests/functional/plugin/shared/include/snippets/mha.hpp
@@ -5,6 +5,7 @@
 #pragma once
 
 #include "shared_test_classes/base/snippets_test_utils.hpp"
+#include "ngraph_helpers/snippets_ngraph_functions/include/snippets_helpers.hpp"
 
 namespace ov {
 namespace test {
@@ -30,7 +31,7 @@ protected:
     void SetUp() override;
     void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override;
 
-    virtual void init_subgraph();
+    virtual std::shared_ptr<SnippetsFunctionBase> get_subgraph();
 
     bool m_with_mul = false;
     std::vector<ov::element::Type> m_input_types;
@@ -39,39 +40,42 @@ protected:
 class MHASelect : public MHA {
 protected:
     void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override;
-    void init_subgraph() override;
+    std::shared_ptr<SnippetsFunctionBase> get_subgraph() override;
 };
 
 class MHAWOTransposeOnInputs : public MHA {
 protected:
-    void init_subgraph() override;
+    std::shared_ptr<SnippetsFunctionBase> get_subgraph() override;
 };
 
 class MHAWOTranspose : public MHA {
 protected:
-    void init_subgraph() override;
+    std::shared_ptr<SnippetsFunctionBase> get_subgraph() override;
+};
+
+class MHAMulAdd : public MHA {
+    std::shared_ptr<SnippetsFunctionBase> get_subgraph() override;
+};
+
+class MHATransposedB : public MHA {
+    std::shared_ptr<SnippetsFunctionBase> get_subgraph() override;
 };
 
 class MHAINT8MatMul : public MHA {
 protected:
-    void init_subgraph() override;
+    std::shared_ptr<SnippetsFunctionBase> get_subgraph() override;
 };
 
 class MHAFQAfterMatMul : public MHA {
 protected:
-    void init_subgraph() override;
+    std::shared_ptr<SnippetsFunctionBase> get_subgraph() override;
 };
 
 class MHAFQ : public MHA {
 protected:
-    void init_subgraph() override;
+    std::shared_ptr<SnippetsFunctionBase> get_subgraph() override;
 };
 
-class MHAMulAdd : public MHA {
-    void init_subgraph() override;
-};
-
-
 }  // namespace snippets
 }  // namespace test
 }  // namespace ov
diff --git a/src/tests/functional/plugin/shared/src/snippets/mha.cpp b/src/tests/functional/plugin/shared/src/snippets/mha.cpp
index 9e8bd6c4d79..d5924f2cb58 100644
--- a/src/tests/functional/plugin/shared/src/snippets/mha.cpp
+++ b/src/tests/functional/plugin/shared/src/snippets/mha.cpp
@@ -51,7 +51,8 @@ void MHA::SetUp() {
     std::tie(inputShapes, m_input_types, prc, m_with_mul, ref_num_nodes, ref_num_subgraphs, targetDevice, additionalConfig) = this->GetParam();
     init_input_shapes(static_partial_shapes_to_test_representation(inputShapes));
 
-    init_subgraph();
+    const auto subgraph_model = get_subgraph();
+    function = subgraph_model->getOriginal();
 
     configuration.insert(additionalConfig.begin(), additionalConfig.end());
     if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) {
@@ -76,9 +77,8 @@ void MHA::generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes)
     }
 }
 
-void MHA::init_subgraph() {
-    auto f = ov::test::snippets::MHAFunction(inputDynamicShapes, m_input_types, m_with_mul);
-    function = f.getOriginal();
+std::shared_ptr<SnippetsFunctionBase> MHA::get_subgraph() {
+    return std::make_shared<ov::test::snippets::MHAFunction>(inputDynamicShapes, m_input_types, m_with_mul);
 }
 
 void MHASelect::generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) {
@@ -99,39 +99,36 @@ void MHASelect::generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes)
     }
 }
 
-void MHASelect::init_subgraph() {
-    auto f = ov::test::snippets::MHASelectFunction(inputDynamicShapes, m_input_types);
-    function = f.getOriginal();
+std::shared_ptr<SnippetsFunctionBase> MHASelect::get_subgraph() {
+    return std::make_shared<ov::test::snippets::MHASelectFunction>(inputDynamicShapes, m_input_types);
 }
 
-void MHAWOTransposeOnInputs::init_subgraph() {
-    auto f = ov::test::snippets::MHAWOTransposeOnInputsFunction(inputDynamicShapes);
-    function = f.getOriginal();
+std::shared_ptr<SnippetsFunctionBase> MHAWOTransposeOnInputs::get_subgraph() {
+    return std::make_shared<ov::test::snippets::MHAWOTransposeOnInputsFunction>(inputDynamicShapes);
 }
 
-void MHAWOTranspose::init_subgraph() {
-    auto f = ov::test::snippets::MHAWOTransposeFunction(inputDynamicShapes, m_input_types);
-    function = f.getOriginal();
+std::shared_ptr<SnippetsFunctionBase> MHAWOTranspose::get_subgraph() {
+    return std::make_shared<ov::test::snippets::MHAWOTransposeFunction>(inputDynamicShapes, m_input_types);
 }
 
-void MHAINT8MatMul::init_subgraph() {
-    auto f = ov::test::snippets::MHAINT8MatMulFunction(inputDynamicShapes);
-    function = f.getOriginal();
+std::shared_ptr<SnippetsFunctionBase> MHAINT8MatMul::get_subgraph() {
+    return std::make_shared<ov::test::snippets::MHAINT8MatMulFunction>(inputDynamicShapes);
 }
 
-void MHAFQAfterMatMul::init_subgraph() {
-    auto f = ov::test::snippets::MHAFQAfterMatMulFunction(inputDynamicShapes);
-    function = f.getOriginal();
+std::shared_ptr<SnippetsFunctionBase> MHAFQAfterMatMul::get_subgraph() {
+    return std::make_shared<ov::test::snippets::MHAFQAfterMatMulFunction>(inputDynamicShapes);
 }
 
-void MHAFQ::init_subgraph() {
-    auto f = ov::test::snippets::MHAFQFunction(inputDynamicShapes);
-    function = f.getOriginal();
+std::shared_ptr<SnippetsFunctionBase> MHAFQ::get_subgraph() {
+    return std::make_shared<ov::test::snippets::MHAFQFunction>(inputDynamicShapes);
 }
 
-void MHAMulAdd::init_subgraph() {
-    auto f = ov::test::snippets::MHAMulAddFunction(inputDynamicShapes);
-    function = f.getOriginal();
+std::shared_ptr<SnippetsFunctionBase> MHAMulAdd::get_subgraph() {
+    return std::make_shared<ov::test::snippets::MHAMulAddFunction>(inputDynamicShapes);
+}
+
+std::shared_ptr<SnippetsFunctionBase> MHATransposedB::get_subgraph() {
+    return std::make_shared<ov::test::snippets::MHATransposedInputFunction>(inputDynamicShapes, true);
 }
 
 TEST_P(MHA, CompareWithRefImpl) {
@@ -153,26 +150,37 @@ TEST_P(MHAWOTransposeOnInputs, CompareWithRefImpl) {
 }
 
 TEST_P(MHAWOTranspose, CompareWithRefImpl) {
+    SKIP_IF_CURRENT_TEST_IS_DISABLED()
     run();
     validateNumSubgraphs();
 }
 
 TEST_P(MHAMulAdd, CompareWithRefImpl) {
+    SKIP_IF_CURRENT_TEST_IS_DISABLED()
+    run();
+    validateNumSubgraphs();
+}
+
+TEST_P(MHATransposedB, CompareWithRefImpl) {
+    SKIP_IF_CURRENT_TEST_IS_DISABLED()
     run();
     validateNumSubgraphs();
 }
 
 TEST_P(MHAINT8MatMul, CompareWithRefImpl) {
+    SKIP_IF_CURRENT_TEST_IS_DISABLED()
     run();
     validateNumSubgraphs();
 }
 
 TEST_P(MHAFQAfterMatMul, CompareWithRefImpl) {
+    SKIP_IF_CURRENT_TEST_IS_DISABLED()
     run();
     validateNumSubgraphs();
 }
 
 TEST_P(MHAFQ, CompareWithRefImpl) {
+    SKIP_IF_CURRENT_TEST_IS_DISABLED()
     run();
     validateNumSubgraphs();
 }
diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_mha.hpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_mha.hpp
index 745d5e990f3..f3392902094 100644
--- a/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_mha.hpp
+++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_mha.hpp
@@ -289,6 +289,31 @@ protected:
     std::shared_ptr<ov::Model> initOriginal() const override;
 };
 
+/* Graph:
+ *        Transpose/Parameter
+ *            \        /
+ *      MatMul0 [transposed_b = true/false]
+ *               |
+ *            Softmax
+ *               \        /
+ *               MatMul1
+ *                  |
+ */
+class MHATransposedInputFunction : public SnippetsFunctionBase {
+public:
+    explicit MHATransposedInputFunction(const std::vector<PartialShape>& inputShapes, bool transposed_b = false,
+                                        std::vector<int64_t> order = {})
+        : SnippetsFunctionBase(inputShapes), m_transposed_b(transposed_b), m_order(order) {
+        NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes");
+    }
+
+protected:
+    std::shared_ptr<ov::Model> initOriginal() const override;
+    std::shared_ptr<ov::Model> initReference() const override;
+
+    bool m_transposed_b = false;
+    std::vector<int64_t> m_order = {};
+};
+
 }  // namespace snippets
 }  // namespace test
 }  // namespace ov
diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_mha.cpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_mha.cpp
index fdd9fd3a9c1..440e3607a2f 100644
--- a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_mha.cpp
+++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_mha.cpp
@@ -685,6 +685,72 @@ std::shared_ptr<ov::Model> MHAMulAddFunction::initOriginal() const {
     return std::make_shared<ngraph::Function>(results, ngraphParam, "mha");
 }
 
+std::shared_ptr<ov::Model> MHATransposedInputFunction::initOriginal() const {
+    const auto param0 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
+    const auto param1 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
+    const auto param2 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
+    ngraph::ParameterVector ngraphParam = {param0, param1, param2};
+
+    std::shared_ptr<ov::Node> matmul0_in1 = param1;
+    if (!m_order.empty()) {
+        const auto transposeConst = ngraph::builder::makeConstant(ngraph::element::i64, ov::Shape{m_order.size()}, m_order);
+        matmul0_in1 = std::make_shared<ngraph::opset1::Transpose>(param1, transposeConst);
+    }
+
+    const auto matMul0 = std::make_shared<ngraph::opset1::MatMul>(param0, matmul0_in1, false, m_transposed_b);
+    const auto softmax = std::make_shared<ngraph::opset8::Softmax>(matMul0, -1);
+    const auto matMul1 = std::make_shared<ngraph::opset1::MatMul>(softmax, param2);
+
+    ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(matMul1)};
+    return std::make_shared<ngraph::Function>(results, ngraphParam, "mha");
+}
+
+std::shared_ptr<ov::Model> MHATransposedInputFunction::initReference() const {
+    const auto data0 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
+    const auto data1 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
+    const auto data2 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
+    ngraph::ParameterVector ngraphParam = {data0, data1, data2};
+
+    bool is_supported = ((m_transposed_b && m_order == std::vector<int64_t>{0, 2, 1, 3}) ||
+                         (!m_transposed_b && m_order == std::vector<int64_t>{0, 2, 3, 1}));
+
+    std::shared_ptr<ov::Node> in1 = data1;
+    if (!m_order.empty() && !is_supported) {
+        const auto transposeConst = ngraph::builder::makeConstant(ngraph::element::i64, ov::Shape{m_order.size()}, m_order);
+        in1 = std::make_shared<ngraph::opset1::Transpose>(in1, transposeConst);
+    }
+    if (m_transposed_b) {
+        if (m_order != std::vector<int64_t>{0, 2, 1, 3}) {
+            const auto rank = input_shapes[1].size();
+            std::vector<int32_t> transpose_order(rank, 0);
+            std::iota(transpose_order.begin(), transpose_order.end(), 0);
+            std::swap(transpose_order[rank - 1], transpose_order[rank - 2]);
+            const auto transposeConst = ngraph::builder::makeConstant(ngraph::element::i32, ov::Shape{transpose_order.size()}, transpose_order);
+            in1 = std::make_shared<ngraph::opset1::Transpose>(in1, transposeConst);
+        }
+    }
+
+    const auto param0 = std::make_shared<ngraph::opset1::Parameter>(precision, data0->get_shape());
+    const auto param1 = std::make_shared<ngraph::opset1::Parameter>(precision, in1->get_shape());
+    const auto param2 = std::make_shared<ngraph::opset1::Parameter>(precision, data2->get_shape());
+
+    std::shared_ptr<ov::Node> matmul0_in1 = param1;
+    if (!m_order.empty() && is_supported) {
+        const auto transposeConst = ngraph::builder::makeConstant(ngraph::element::i32, ov::Shape{m_order.size()}, m_order);
+        matmul0_in1 = std::make_shared<ngraph::opset1::Transpose>(param1, transposeConst);
+    }
+
+    const auto matMul0 = std::make_shared<ngraph::opset1::MatMul>(param0, matmul0_in1);
+    const auto softmax = std::make_shared<ngraph::opset8::Softmax>(matMul0, -1);
+    const auto matMul1 = std::make_shared<ngraph::opset1::MatMul>(softmax, param2);
+
+    auto subgraph = std::make_shared<ov::snippets::op::Subgraph>(
+        ov::NodeVector{data0, in1, data2},
+        std::make_shared<ov::Model>(NodeVector{matMul1}, ov::ParameterVector{param0, param1, param2}));
+
+    ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(subgraph)};
+    return std::make_shared<ngraph::Function>(results, ngraphParam, "mha");
+}
+
 }  // namespace snippets
 }  // namespace test
 }  // namespace ov
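
Finally, a standalone sanity check (plain C++, no OpenVINO) of the shape identity the whole patch relies on: transposing with some order and then swapping the two innermost axes is the same as a single Transpose whose order has its last two entries swapped, which is why the MatMul flag can be absorbed into the order:

    #include <cassert>
    #include <utility>
    #include <vector>

    // out_dim[i] = in_dim[order[i]] for a Transpose with the given order.
    std::vector<int> permute(const std::vector<int>& order, const std::vector<int>& dims) {
        std::vector<int> out(order.size());
        for (size_t i = 0; i < order.size(); ++i)
            out[i] = dims[order[i]];
        return out;
    }

    int main() {
        const std::vector<int> dims = {1, 128, 12, 64};  // e.g. [batch, seq_len, heads, head_size]
        // Transpose{0, 2, 1, 3} followed by MatMul's implicit B^T
        // (a swap of the two innermost dims) ...
        auto lhs = permute({0, 2, 1, 3}, dims);          // {1, 12, 128, 64}
        std::swap(lhs[2], lhs[3]);                       // {1, 12, 64, 128}
        // ... equals a single Transpose{0, 2, 3, 1} with transposed_b = false.
        const auto rhs = permute({0, 2, 3, 1}, dims);    // {1, 12, 64, 128}
        assert(lhs == rhs);
        return 0;
    }
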