diff --git a/src/inference/dev_api/ie_system_conf.h b/src/inference/dev_api/ie_system_conf.h index 993a8362bb3..8c6cc6abb9d 100644 --- a/src/inference/dev_api/ie_system_conf.h +++ b/src/inference/dev_api/ie_system_conf.h @@ -97,6 +97,13 @@ INFERENCE_ENGINE_API_CPP(bool) with_cpu_x86_avx512f(); */ INFERENCE_ENGINE_API_CPP(bool) with_cpu_x86_avx512_core(); +/** + * @brief Checks whether CPU supports AVX 512 VNNI capability + * @ingroup ie_dev_api_system_conf + * @return `True` if AVX512F, AVX512BW, AVX512DQ, AVX512_VNNI instructions are available, `false` otherwise + */ +INFERENCE_ENGINE_API_CPP(bool) with_cpu_x86_avx512_core_vnni(); + /** * @brief Checks whether CPU supports BFloat16 capability * @ingroup ie_dev_api_system_conf diff --git a/src/inference/src/ie_system_conf.cpp b/src/inference/src/ie_system_conf.cpp index e39918c675a..4e6a42536ee 100644 --- a/src/inference/src/ie_system_conf.cpp +++ b/src/inference/src/ie_system_conf.cpp @@ -41,6 +41,10 @@ bool with_cpu_x86_avx512_core() { return get_cpu_info().has(Xbyak::util::Cpu::tAVX512F | Xbyak::util::Cpu::tAVX512DQ | Xbyak::util::Cpu::tAVX512BW); } +bool with_cpu_x86_avx512_core_vnni() { + return with_cpu_x86_avx512_core() && get_cpu_info().has(Xbyak::util::Cpu::tAVX512_VNNI); +} + bool with_cpu_x86_bfloat16() { return get_cpu_info().has(Xbyak::util::Cpu::tAVX512_BF16); } diff --git a/src/plugins/intel_cpu/src/cpu_types.cpp b/src/plugins/intel_cpu/src/cpu_types.cpp index e2f8d3a720c..d16e640c4f1 100644 --- a/src/plugins/intel_cpu/src/cpu_types.cpp +++ b/src/plugins/intel_cpu/src/cpu_types.cpp @@ -197,6 +197,7 @@ const InferenceEngine::details::caseless_unordered_map type_t { "Subgraph", Type::Subgraph}, { "PriorBox", Type::PriorBox}, { "PriorBoxClustered", Type::PriorBoxClustered}, + { "MHA", Type::MHA}, }; Type TypeFromName(const std::string& type) { @@ -388,6 +389,8 @@ std::string NameFromType(const Type type) { return "Reference"; case Type::Subgraph: return "Subgraph"; + case Type::MHA: + return "MHA"; default: return "Unknown"; } diff --git a/src/plugins/intel_cpu/src/cpu_types.h b/src/plugins/intel_cpu/src/cpu_types.h index a5a680969cc..9d1c531b03c 100644 --- a/src/plugins/intel_cpu/src/cpu_types.h +++ b/src/plugins/intel_cpu/src/cpu_types.h @@ -108,6 +108,7 @@ enum class Type { Subgraph, PriorBox, PriorBoxClustered, + MHA }; enum class Algorithm { diff --git a/src/plugins/intel_cpu/src/extension.cpp b/src/plugins/intel_cpu/src/extension.cpp index 7094a81243b..ee22bf442e3 100644 --- a/src/plugins/intel_cpu/src/extension.cpp +++ b/src/plugins/intel_cpu/src/extension.cpp @@ -7,6 +7,7 @@ #include "ngraph_transformations/op/leaky_relu.hpp" #include "ngraph_transformations/op/power_static.hpp" #include "ngraph_transformations/op/swish_cpu.hpp" +#include "ngraph_transformations/op/mha.hpp" #include "snippets_transformations/op/load_convert.hpp" #include "snippets_transformations/op/store_convert.hpp" @@ -44,6 +45,7 @@ std::map Extension::getOpSets() { NGRAPH_OP(LeakyReluNode, ov::intel_cpu) NGRAPH_OP(PowerStaticNode, ov::intel_cpu) NGRAPH_OP(SwishNode, ov::intel_cpu) + NGRAPH_OP(MHANode, ov::intel_cpu) NGRAPH_OP(LoadConvertSaturation, ov::intel_cpu) NGRAPH_OP(LoadConvertTruncation, ov::intel_cpu) NGRAPH_OP(StoreConvertSaturation, ov::intel_cpu) diff --git a/src/plugins/intel_cpu/src/ngraph_transformations/mha_fusion.cpp b/src/plugins/intel_cpu/src/ngraph_transformations/mha_fusion.cpp new file mode 100644 index 00000000000..86fd4f5cdab --- /dev/null +++ b/src/plugins/intel_cpu/src/ngraph_transformations/mha_fusion.cpp @@ -0,0 +1,644 @@
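For context, a minimal sketch of how a caller might gate an int8 code path on the new capability helper; it assumes the declaration sits in the InferenceEngine namespace alongside the other with_cpu_x86_* checks, and the dispatch function name is purely illustrative:

    #include "ie_system_conf.h"
    #include <cstdio>

    // Hypothetical dispatcher: take the VNNI path only when the new helper reports
    // AVX512F/BW/DQ plus AVX512_VNNI support on the host.
    void dispatch_int8_gemm() {
        if (InferenceEngine::with_cpu_x86_avx512_core_vnni())
            std::printf("using the AVX512 VNNI int8 path\n");
        else
            std::printf("falling back to the generic path\n");
    }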
+// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "mha_fusion.hpp" + +#include +#include +#include +#include +#include "op/mha.hpp" + +#include "itt.hpp" + +// TODO: draw pattern +ov::intel_cpu::MHAFloatFusion::MHAFloatFusion() { + MATCHER_SCOPE(MHAFloatFusion); + + auto in0 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape()); + auto in1 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape()); + auto in2 = ngraph::pattern::wrap_type(); + auto in3 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape()); + auto in4 = ngraph::pattern::wrap_type(); + auto in5 = ngraph::pattern::wrap_type(); + auto in6 = ngraph::pattern::wrap_type(); + auto in7 = ngraph::pattern::wrap_type(); + auto in8 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape()); + auto in9 = ngraph::pattern::wrap_type(); + auto in10 = ngraph::pattern::wrap_type(); + auto transpose0 = std::make_shared(in0, in4); + auto transpose1 = std::make_shared(in1, in5); + auto mul = std::make_shared(transpose1, in2); + auto matmul0 = std::make_shared(transpose0, mul); + auto add = std::make_shared(matmul0, in3); + auto reshape0 = std::make_shared(add, in6, true); + auto softmax = std::make_shared(reshape0); + auto reshape1 = std::make_shared(softmax, in7, true); + auto transpose2 = std::make_shared(in8, in9); + auto matmul1 = std::make_shared(reshape1, transpose2); + auto transpose3 = std::make_shared(matmul1, in10); + + ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) { + auto& pattern_to_output = m.get_pattern_value_map(); + auto transpose0_in = pattern_to_output.at(in0); + auto transpose1_in = pattern_to_output.at(in1); + auto mul_in1 = pattern_to_output.at(in2); + auto add_in1 = pattern_to_output.at(in3); + auto transpose2_in = pattern_to_output.at(in8); + + if (transpose0_in.get_shape() != transpose1_in.get_shape() || transpose0_in.get_shape() != transpose2_in.get_shape()) { + return false; + } + + if (transpose0_in.get_shape().size() != 4) { + return false; + } + + auto expected_add_shape = Shape({transpose0_in.get_shape()[0], 1, 1, transpose0_in.get_shape()[1]}); + if (add_in1.get_shape() != expected_add_shape) { + return false; + } + + if (!valid_transpose_order(pattern_to_output.at(in4).get_node_shared_ptr(), {0, 2, 1, 3})) return false; + if (!valid_transpose_order(pattern_to_output.at(in5).get_node_shared_ptr(), {0, 2, 3, 1})) return false; + if (!valid_transpose_order(pattern_to_output.at(in9).get_node_shared_ptr(), {0, 2, 1, 3})) return false; + if (!valid_transpose_order(pattern_to_output.at(in10).get_node_shared_ptr(), {0, 2, 1, 3})) return false; + + std::vector mul_scales; + if (auto mul_node = ngraph::as_type_ptr(pattern_to_output.at(mul).get_node_shared_ptr())) { + mul_scales = ngraph::as_type_ptr(mul_node->get_input_node_shared_ptr(1))->cast_vector(); + + auto expected_shape = ngraph::Shape({1, transpose0_in.get_shape()[2], 1, 1}); + if (mul_scales.size() != 1 && mul_node->get_input_shape(1) != expected_shape) { + return false; + } + } else { + return false; + } + + auto matmul0_node = ngraph::as_type_ptr(pattern_to_output.at(matmul0).get_node_shared_ptr()); + if (!matmul0_node) + return false; + if (matmul0_node->get_transpose_a() || matmul0_node->get_transpose_b()) + return false; + + auto reshape0_node = ngraph::as_type_ptr(pattern_to_output.at(reshape0).get_node_shared_ptr()); + if (!reshape0_node) + return false; + + if (auto reshape_pattern = 
ngraph::as_type_ptr(pattern_to_output.at(in6).get_node_shared_ptr())) { + if (reshape0_node->get_input_shape(0).size() != 4) { + return false; + } + + std::vector reshapeConstData = {static_cast(reshape0_node->get_input_shape(0)[0] * + reshape0_node->get_input_shape(0)[1] * + reshape0_node->get_input_shape(0)[2]), + -1}; + + if (reshape_pattern->cast_vector() != reshapeConstData) { + return false; + } + } else { + return false; + } + + if (auto reshape1_node = ngraph::as_type_ptr(pattern_to_output.at(reshape1).get_node_shared_ptr())) { + if (reshape0_node->get_input_shape(0) != reshape1_node->get_output_shape(0)) { + return false; + } + } else { + return false; + } + + auto softmax_node = ngraph::as_type_ptr(pattern_to_output.at(softmax).get_node_shared_ptr()); + if (!softmax_node) + return false; + if (softmax_node->get_axis() != 1) + return false; + + auto matmul1_node = ngraph::as_type_ptr(pattern_to_output.at(matmul1).get_node_shared_ptr()); + if (!matmul1_node) + return false; + if (matmul1_node->get_transpose_a() || matmul1_node->get_transpose_b()) + return false; + + bool is_mul_first = true; + auto transpose3_node = pattern_to_output.at(transpose3).get_node_shared_ptr(); + auto mha = std::make_shared(transpose0_in, transpose1_in, add_in1, transpose2_in, mul_scales, is_mul_first, + transpose3_node->get_output_element_type(0)); + mha->set_friendly_name(m.get_match_root()->get_friendly_name()); + ngraph::copy_runtime_info({pattern_to_output.at(transpose0).get_node_shared_ptr(), + pattern_to_output.at(transpose1).get_node_shared_ptr(), + pattern_to_output.at(mul).get_node_shared_ptr(), + pattern_to_output.at(matmul0).get_node_shared_ptr(), + pattern_to_output.at(add).get_node_shared_ptr(), + pattern_to_output.at(reshape0).get_node_shared_ptr(), + pattern_to_output.at(softmax).get_node_shared_ptr(), + pattern_to_output.at(reshape1).get_node_shared_ptr(), + pattern_to_output.at(transpose2).get_node_shared_ptr(), + pattern_to_output.at(matmul1).get_node_shared_ptr(), + pattern_to_output.at(transpose3).get_node_shared_ptr(), + }, + mha); + + if (transformation_callback(mha)) { + return false; + } + + ngraph::replace_node(m.get_match_root(), mha); + + return true; + }; + + auto m = std::make_shared(transpose3, matcher_name); + this->register_matcher(m, callback); +} + +ov::intel_cpu::MHAFloatFusion2::MHAFloatFusion2() { + MATCHER_SCOPE(MHAFloatFusion2); + + auto in0 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape()); + auto in1 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape()); + auto in3 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape()); + auto in4 = ngraph::pattern::wrap_type(); + auto in5 = ngraph::pattern::wrap_type(); + auto in6 = ngraph::pattern::wrap_type(); + auto in7 = ngraph::pattern::wrap_type(); + auto in8 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape()); + auto in9 = ngraph::pattern::wrap_type(); + auto in10 = ngraph::pattern::wrap_type(); + auto transpose0 = std::make_shared(in0, in4); + auto transpose1 = std::make_shared(in1, in5); + auto matmul0 = std::make_shared(transpose0, transpose1); + auto add = std::make_shared(matmul0, in3); + auto softmax = std::make_shared(add); + auto transpose2 = std::make_shared(in8, in9); + auto matmul1 = std::make_shared(softmax, transpose2); + auto transpose3 = std::make_shared(matmul1, in10); + + ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) { + auto& pattern_to_output = m.get_pattern_value_map(); + auto transpose0_in = pattern_to_output.at(in0); + 
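        // The transpose-order checks in these matchers encode the expected MHA layout: assuming the
        // inputs arrive as [batch, seq_len, num_heads, head_size] (an assumption; the pattern itself
        // only fixes rank 4), order {0, 2, 1, 3} rearranges the query/value paths to
        // [batch, num_heads, seq_len, head_size], while {0, 2, 3, 1} rearranges the key path to
        // [batch, num_heads, head_size, seq_len], so matmul0 yields [batch, num_heads, seq_len, seq_len]
        // attention scores and the trailing {0, 2, 1, 3} transpose restores the original layout.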
auto transpose1_in = pattern_to_output.at(in1); + auto add_in1 = pattern_to_output.at(in3); + auto transpose2_in = pattern_to_output.at(in8); + + if (transpose0_in.get_shape() != transpose1_in.get_shape() || transpose0_in.get_shape() != transpose2_in.get_shape()) { + return false; + } + + if (transpose0_in.get_shape().size() != 4) { + return false; + } + + auto expected_add_shape = Shape({transpose0_in.get_shape()[0], 1, 1, transpose0_in.get_shape()[1]}); + if (add_in1.get_shape() != expected_add_shape) { + return false; + } + + if (!valid_transpose_order(pattern_to_output.at(in4).get_node_shared_ptr(), {0, 2, 1, 3})) return false; + if (!valid_transpose_order(pattern_to_output.at(in5).get_node_shared_ptr(), {0, 2, 3, 1})) return false; + if (!valid_transpose_order(pattern_to_output.at(in9).get_node_shared_ptr(), {0, 2, 1, 3})) return false; + if (!valid_transpose_order(pattern_to_output.at(in10).get_node_shared_ptr(), {0, 2, 1, 3})) return false; + + auto matmul0_node = ngraph::as_type_ptr(pattern_to_output.at(matmul0).get_node_shared_ptr()); + if (!matmul0_node) + return false; + if (matmul0_node->get_transpose_a() || matmul0_node->get_transpose_b()) + return false; + + auto softmax_node = ngraph::as_type_ptr(pattern_to_output.at(softmax).get_node_shared_ptr()); + if (!softmax_node) + return false; + if (softmax_node->get_axis() != 3) + return false; + + auto matmul1_node = ngraph::as_type_ptr(pattern_to_output.at(matmul1).get_node_shared_ptr()); + if (!matmul1_node) + return false; + if (matmul1_node->get_transpose_a() || matmul1_node->get_transpose_b()) + return false; + + auto transpose3_node = pattern_to_output.at(transpose3).get_node_shared_ptr(); + auto mha = std::make_shared(transpose0_in, transpose1_in, add_in1, transpose2_in, std::vector(), false, + transpose3_node->get_output_element_type(0)); + mha->set_friendly_name(m.get_match_root()->get_friendly_name()); + ngraph::copy_runtime_info({pattern_to_output.at(transpose0).get_node_shared_ptr(), + pattern_to_output.at(transpose1).get_node_shared_ptr(), + pattern_to_output.at(matmul0).get_node_shared_ptr(), + pattern_to_output.at(add).get_node_shared_ptr(), + pattern_to_output.at(softmax).get_node_shared_ptr(), + pattern_to_output.at(transpose2).get_node_shared_ptr(), + pattern_to_output.at(matmul1).get_node_shared_ptr(), + pattern_to_output.at(transpose3).get_node_shared_ptr(), + }, + mha); + + if (transformation_callback(mha)) { + return false; + } + + ngraph::replace_node(m.get_match_root(), mha); + + return true; + }; + + auto m = std::make_shared(transpose3, matcher_name); + this->register_matcher(m, callback); +} + +static std::vector simplifyToScale(const std::shared_ptr& fq_node) { + auto levels = fq_node->get_levels(); + auto input_low = ngraph::as_type_ptr(fq_node->get_input_node_shared_ptr(1))->cast_vector(); + auto input_high = ngraph::as_type_ptr(fq_node->get_input_node_shared_ptr(2))->cast_vector(); + auto output_low = ngraph::as_type_ptr(fq_node->get_input_node_shared_ptr(3))->cast_vector(); + auto output_high = ngraph::as_type_ptr(fq_node->get_input_node_shared_ptr(4))->cast_vector(); + + std::vector cl, ch, isc, ish, osc, osh; + for (int i = 0; i < input_low.size(); i++) { + cl.push_back(input_low[i]); + } + for (int i = 0; i < input_high.size(); i++) { + ch.push_back(input_high[i]); + } + + for (int i = 0; i < std::max(input_low.size(), input_high.size()); i++) { + float il = input_low[input_low.size() == 1 ? 0 : i]; + float ih = input_high[input_high.size() == 1 ? 
0 : i]; + + isc.push_back((levels - 1) / (ih - il)); + ish.push_back(-il * (levels - 1) / (ih - il)); + } + + for (int i = 0; i < std::max(output_low.size(), output_high.size()); i++) { + float ol = output_low[output_low.size() == 1 ? 0 : i]; + float oh = output_high[output_high.size() == 1 ? 0 : i]; + + osc.push_back((oh - ol) / (levels - 1)); + osh.push_back(ol); + } + + std::vector outScale; + + if (fq_node->get_output_element_type(0) == ngraph::element::u8 && + std::all_of(cl.cbegin(), cl.cend(), [](float val) { return val == 0.0f; }) && + std::all_of(ish.cbegin(), ish.cend(), [](float val) { return val == 0.0f; }) && + std::all_of(osc.cbegin(), osc.cend(), [](float val) { return val == 1.0f; }) && + std::all_of(osh.cbegin(), osh.cend(), [](float val) { return val == 0.0f; })) { + outScale = isc; + } + + if (fq_node->get_output_element_type(0) == ngraph::element::i8 && + std::all_of(ish.cbegin(), ish.cend(), [](float val) { return std::abs(val - 128.f) < 0.0001f; }) && + std::all_of(osc.cbegin(), osc.cend(), [](float val) { return val == 1.f; }) && + std::all_of(osh.cbegin(), osh.cend(), [](float val) { return std::abs(val + 128.f) < 0.0001f; })) { + bool isCropAligned = true; + for (int i = 0; i < std::max(cl.size(), isc.size()); i++) { + if (std::abs(cl[cl.size() == 1 ? 0 : i] * isc[isc.size() == 1 ? 0 : i] + 128.f) > 0.0001f) { + isCropAligned = false; + } + } + + for (int i = 0; i < std::max(ch.size(), isc.size()); i++) { + if (std::abs(ch[ch.size() == 1 ? 0 : i] * isc[isc.size() == 1 ? 0 : i] - 127.f) > 0.0001f) { + isCropAligned = false; + } + } + + if (isCropAligned) { + outScale = isc; + } + } + + return outScale; +} + +// TODO: draw pattern +ov::intel_cpu::MHAQuantFusion::MHAQuantFusion() { + MATCHER_SCOPE(MHAQuantFusion); + + auto in0 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape()); + auto in1 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape()); + auto in2 = ngraph::pattern::wrap_type(); + auto in3 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape()); + auto in4 = ngraph::pattern::wrap_type(); + auto in5 = ngraph::pattern::wrap_type(); + auto in6 = ngraph::pattern::wrap_type(); + auto in7 = ngraph::pattern::wrap_type(); + auto in8 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape()); + auto in9 = ngraph::pattern::wrap_type(); + auto in10 = ngraph::pattern::wrap_type(); + auto transpose0 = std::make_shared(in0, in4); + auto transpose1 = std::make_shared(in1, in5); + auto matmul0 = std::make_shared(transpose0, transpose1); + auto fakeQuantize0 = ngraph::pattern::wrap_type({matmul0, + ngraph::pattern::wrap_type(), + ngraph::pattern::wrap_type(), + ngraph::pattern::wrap_type(), + ngraph::pattern::wrap_type()}); + auto add = std::make_shared(fakeQuantize0, in3); + auto mul = std::make_shared(add, in2); + auto reshape0 = std::make_shared(mul, in6, true); + auto softmax = std::make_shared(reshape0); + auto reshape1 = std::make_shared(softmax, in7, true); + auto fakeQuantize1 = ngraph::pattern::wrap_type({reshape1, + ngraph::pattern::wrap_type(), + ngraph::pattern::wrap_type(), + ngraph::pattern::wrap_type(), + ngraph::pattern::wrap_type()}); + auto transpose2 = std::make_shared(in8, in9); + auto matmul1 = std::make_shared(fakeQuantize1, transpose2); + auto fakeQuantize2 = ngraph::pattern::wrap_type({matmul1, + ngraph::pattern::wrap_type(), + ngraph::pattern::wrap_type(), + ngraph::pattern::wrap_type(), + ngraph::pattern::wrap_type()}); + auto transpose3 = std::make_shared(fakeQuantize2, in10); + + 
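    // A worked example of the simplifyToScale() reduction used by this pass: for a u8 FakeQuantize
    // with levels = 256, input_low = 0, input_high = 2.55, output_low = 0, output_high = 255, the
    // helper computes isc = (levels - 1) / (ih - il) = 255 / 2.55 = 100 with ish = 0, osc = 1 and
    // osh = 0, so the whole FakeQuantize collapses to a single multiply by 100 that can be folded
    // into the fused MHA; when the output scale/shift are not identity (or the i8 crop range is not
    // aligned to [-128, 127]), the helper returns an empty vector and the match is rejected.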
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) { + auto& pattern_to_output = m.get_pattern_value_map(); + auto transpose0_in = pattern_to_output.at(in0); + auto transpose1_in = pattern_to_output.at(in1); + auto add_in1 = pattern_to_output.at(in3); + auto transpose2_in = pattern_to_output.at(in8); + + if (transpose0_in.get_shape() != transpose1_in.get_shape() || transpose0_in.get_shape() != transpose2_in.get_shape()) { + return false; + } + + if (transpose0_in.get_shape().size() != 4) { + return false; + } + + auto expected_add_shape = Shape({transpose0_in.get_shape()[0], 1, 1, transpose0_in.get_shape()[1]}); + if (add_in1.get_shape() != expected_add_shape) { + return false; + } + + std::vector mul_scales; + if (auto mul_node = ngraph::as_type_ptr(pattern_to_output.at(mul).get_node_shared_ptr())) { + mul_scales = ngraph::as_type_ptr(mul_node->get_input_node_shared_ptr(1))->cast_vector(); + + auto expected_shape = ngraph::Shape({1, transpose0_in.get_shape()[2], 1, 1}); + if (mul_scales.size() != 1 && mul_node->get_input_shape(1) != expected_shape) { + return false; + } + } else { + return false; + } + + if (!valid_transpose_order(pattern_to_output.at(in4).get_node_shared_ptr(), {0, 2, 1, 3})) return false; + if (!valid_transpose_order(pattern_to_output.at(in5).get_node_shared_ptr(), {0, 2, 3, 1})) return false; + if (!valid_transpose_order(pattern_to_output.at(in9).get_node_shared_ptr(), {0, 2, 1, 3})) return false; + if (!valid_transpose_order(pattern_to_output.at(in10).get_node_shared_ptr(), {0, 2, 1, 3})) return false; + + auto matmul0_node = ngraph::as_type_ptr(pattern_to_output.at(matmul0).get_node_shared_ptr()); + if (!matmul0_node) + return false; + if (matmul0_node->get_transpose_a() || matmul0_node->get_transpose_b()) + return false; + + std::vector fq0_scale; + auto fq0_node = ngraph::as_type_ptr(pattern_to_output.at(fakeQuantize0).get_node_shared_ptr()); + if (fq0_node) { + fq0_scale = simplifyToScale(fq0_node); + if (!fq0_scale.size()) + return false; + } + + auto reshape0_node = ngraph::as_type_ptr(pattern_to_output.at(reshape0).get_node_shared_ptr()); + if (!reshape0_node) + return false; + + if (auto reshape_pattern = ngraph::as_type_ptr(pattern_to_output.at(in6).get_node_shared_ptr())) { + if (reshape0_node->get_input_shape(0).size() != 4) { + return false; + } + + std::vector reshapeConstData = {static_cast(reshape0_node->get_input_shape(0)[0] * + reshape0_node->get_input_shape(0)[1] * + reshape0_node->get_input_shape(0)[2]), + -1}; + + if (reshape_pattern->cast_vector() != reshapeConstData) { + return false; + } + } else { + return false; + } + + if (auto reshape1_node = ngraph::as_type_ptr(pattern_to_output.at(reshape1).get_node_shared_ptr())) { + if (reshape0_node->get_input_shape(0) != reshape1_node->get_output_shape(0)) { + return false; + } + } else { + return false; + } + + auto softmax_node = ngraph::as_type_ptr(pattern_to_output.at(softmax).get_node_shared_ptr()); + if (!softmax_node) + return false; + if (softmax_node->get_axis() != 1) + return false; + + std::vector fq1_scale; + auto fq1_node = ngraph::as_type_ptr(pattern_to_output.at(fakeQuantize1).get_node_shared_ptr()); + if (fq1_node) { + fq1_scale = simplifyToScale(fq1_node); + if (!fq1_scale.size()) + return false; + } else { + return false; + } + + auto matmul1_node = ngraph::as_type_ptr(pattern_to_output.at(matmul1).get_node_shared_ptr()); + if (!matmul1_node) + return false; + if (matmul1_node->get_transpose_a() || matmul1_node->get_transpose_b()) + return false; + + std::vector 
fq2_scale; + if (auto fq_node = ngraph::as_type_ptr(pattern_to_output.at(fakeQuantize2).get_node_shared_ptr())) { + fq2_scale = simplifyToScale(fq_node); + if (!fq2_scale.size()) + return false; + } + + bool is_mul_first = false; + auto transpose3_node = pattern_to_output.at(transpose3).get_node_shared_ptr(); + auto mha = std::make_shared(transpose0_in, transpose1_in, add_in1, transpose2_in, mul_scales, is_mul_first, + std::vector(), fq0_scale, fq1_scale, fq2_scale, + ngraph::element::undefined, + fq0_node ? fq0_node->get_output_element_type(0) : ngraph::element::undefined, + fq1_node->get_output_element_type(0), transpose3_node->get_output_element_type(0)); + mha->set_friendly_name(m.get_match_root()->get_friendly_name()); + ngraph::copy_runtime_info({pattern_to_output.at(transpose0).get_node_shared_ptr(), + pattern_to_output.at(transpose1).get_node_shared_ptr(), + pattern_to_output.at(matmul0).get_node_shared_ptr(), + pattern_to_output.at(fakeQuantize0).get_node_shared_ptr(), + pattern_to_output.at(add).get_node_shared_ptr(), + pattern_to_output.at(mul).get_node_shared_ptr(), + pattern_to_output.at(reshape0).get_node_shared_ptr(), + pattern_to_output.at(softmax).get_node_shared_ptr(), + pattern_to_output.at(reshape1).get_node_shared_ptr(), + pattern_to_output.at(fakeQuantize1).get_node_shared_ptr(), + pattern_to_output.at(transpose2).get_node_shared_ptr(), + pattern_to_output.at(matmul1).get_node_shared_ptr(), + pattern_to_output.at(fakeQuantize2).get_node_shared_ptr(), + pattern_to_output.at(transpose3).get_node_shared_ptr(), + }, + mha); + + if (transformation_callback(mha)) { + return false; + } + + ngraph::replace_node(m.get_match_root(), mha); + + return true; + }; + + auto m = std::make_shared(transpose3, matcher_name); + this->register_matcher(m, callback); +} + +// TODO: draw pattern +ov::intel_cpu::MHAQuantFusion2::MHAQuantFusion2() { + MATCHER_SCOPE(MHAQuantFusion2); + + auto in0 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape()); + auto in1 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape()); + auto in2 = ngraph::pattern::wrap_type(); + auto in3 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape()); + auto in4 = ngraph::pattern::wrap_type(); + auto in5 = ngraph::pattern::wrap_type(); + auto in8 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape()); + auto in9 = ngraph::pattern::wrap_type(); + auto in10 = ngraph::pattern::wrap_type(); + auto transpose0 = std::make_shared(in0, in4); + auto transpose1 = std::make_shared(in1, in5); + auto fakeQuantize0 = ngraph::pattern::wrap_type({transpose1, + ngraph::pattern::wrap_type(), + ngraph::pattern::wrap_type(), + ngraph::pattern::wrap_type(), + ngraph::pattern::wrap_type()}); + auto matmul0 = std::make_shared(transpose0, fakeQuantize0); + auto mul = std::make_shared(matmul0, in2); + auto add = std::make_shared(mul, in3); + auto softmax = std::make_shared(add); + auto transpose2 = std::make_shared(in8, in9); + auto matmul1 = std::make_shared(softmax, transpose2); + auto fakeQuantize1 = ngraph::pattern::wrap_type({matmul1, + ngraph::pattern::wrap_type(), + ngraph::pattern::wrap_type(), + ngraph::pattern::wrap_type(), + ngraph::pattern::wrap_type()}); + auto transpose3 = std::make_shared(fakeQuantize1, in10); + + ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) { + auto& pattern_to_output = m.get_pattern_value_map(); + auto transpose0_in = pattern_to_output.at(in0); + auto transpose1_in = pattern_to_output.at(in1); + auto add_in1 = pattern_to_output.at(in3); + 
auto transpose2_in = pattern_to_output.at(in8); + + if (transpose0_in.get_shape() != transpose1_in.get_shape() || transpose0_in.get_shape() != transpose2_in.get_shape()) { + return false; + } + + if (transpose0_in.get_shape().size() != 4) { + return false; + } + + auto expected_add_shape = Shape({transpose0_in.get_shape()[0], 1, 1, transpose0_in.get_shape()[1]}); + if (add_in1.get_shape() != expected_add_shape) { + return false; + } + + std::vector mul_scales; + if (auto mul_node = ngraph::as_type_ptr(pattern_to_output.at(mul).get_node_shared_ptr())) { + mul_scales = ngraph::as_type_ptr(mul_node->get_input_node_shared_ptr(1))->cast_vector(); + + auto expected_shape = ngraph::Shape({1, transpose0_in.get_shape()[2], 1, 1}); + if (mul_scales.size() != 1 && mul_node->get_input_shape(1) != expected_shape) { + return false; + } + } else { + return false; + } + + if (!valid_transpose_order(pattern_to_output.at(in4).get_node_shared_ptr(), {0, 2, 1, 3})) return false; + if (!valid_transpose_order(pattern_to_output.at(in5).get_node_shared_ptr(), {0, 2, 3, 1})) return false; + if (!valid_transpose_order(pattern_to_output.at(in9).get_node_shared_ptr(), {0, 2, 1, 3})) return false; + if (!valid_transpose_order(pattern_to_output.at(in10).get_node_shared_ptr(), {0, 2, 1, 3})) return false; + + auto matmul0_node = ngraph::as_type_ptr(pattern_to_output.at(matmul0).get_node_shared_ptr()); + if (!matmul0_node) + return false; + if (matmul0_node->get_transpose_a() || matmul0_node->get_transpose_b()) + return false; + + std::vector fq0_scale; + auto fq0_node = ngraph::as_type_ptr(pattern_to_output.at(fakeQuantize0).get_node_shared_ptr()); + if (fq0_node) { + fq0_scale = simplifyToScale(fq0_node); + if (!fq0_scale.size()) + return false; + } else { + return false; + } + + auto softmax_node = ngraph::as_type_ptr(pattern_to_output.at(softmax).get_node_shared_ptr()); + if (!softmax_node) + return false; + if (softmax_node->get_axis() != 3) + return false; + + std::vector fq1_scale; + if (auto fq_node = ngraph::as_type_ptr(pattern_to_output.at(fakeQuantize1).get_node_shared_ptr())) { + fq1_scale = simplifyToScale(fq_node); + if (!fq1_scale.size()) + return false; + } + + auto matmul1_node = ngraph::as_type_ptr(pattern_to_output.at(matmul1).get_node_shared_ptr()); + if (!matmul1_node) + return false; + if (matmul1_node->get_transpose_a() || matmul1_node->get_transpose_b()) + return false; + + bool is_mul_first = true; + auto transpose3_node = pattern_to_output.at(transpose3).get_node_shared_ptr(); + auto mha = std::make_shared(transpose0_in, transpose1_in, add_in1, transpose2_in, mul_scales, is_mul_first, + fq0_scale, std::vector(), std::vector(), fq1_scale, + fq0_node->get_output_element_type(0), ngraph::element::undefined, ngraph::element::undefined, + transpose3_node->get_output_element_type(0)); + mha->set_friendly_name(m.get_match_root()->get_friendly_name()); + ngraph::copy_runtime_info({pattern_to_output.at(transpose0).get_node_shared_ptr(), + pattern_to_output.at(transpose1).get_node_shared_ptr(), + pattern_to_output.at(fakeQuantize0).get_node_shared_ptr(), + pattern_to_output.at(matmul0).get_node_shared_ptr(), + pattern_to_output.at(mul).get_node_shared_ptr(), + pattern_to_output.at(add).get_node_shared_ptr(), + pattern_to_output.at(softmax).get_node_shared_ptr(), + pattern_to_output.at(transpose2).get_node_shared_ptr(), + pattern_to_output.at(matmul1).get_node_shared_ptr(), + pattern_to_output.at(fakeQuantize1).get_node_shared_ptr(), + pattern_to_output.at(transpose3).get_node_shared_ptr(), + }, + mha); + 
+ if (transformation_callback(mha)) { + return false; + } + + ngraph::replace_node(m.get_match_root(), mha); + + return true; + }; + + auto m = std::make_shared(transpose3, matcher_name); + this->register_matcher(m, callback); +} diff --git a/src/plugins/intel_cpu/src/ngraph_transformations/mha_fusion.hpp b/src/plugins/intel_cpu/src/ngraph_transformations/mha_fusion.hpp new file mode 100644 index 00000000000..18b8469d54e --- /dev/null +++ b/src/plugins/intel_cpu/src/ngraph_transformations/mha_fusion.hpp @@ -0,0 +1,64 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +namespace ov { +namespace intel_cpu { + +class MHAFusionBase : public ngraph::pass::MatcherPass { +protected: + bool valid_transpose_order(const std::shared_ptr& node, const std::vector& expected_order) { + if (auto transpose_pattern = ngraph::as_type_ptr(node)) { + if (transpose_pattern->cast_vector() != expected_order) { + return false; + } + } else { + return false; + } + + return true; + } +}; + +class MHAFloatFusion: public MHAFusionBase { +public: + OPENVINO_RTTI("MHAFloatFusion", "0"); + MHAFloatFusion(); +}; + +class MHAFloatFusion2: public MHAFusionBase { +public: + OPENVINO_RTTI("MHAFloatFusion2", "0"); + MHAFloatFusion2(); +}; + +class MHAQuantFusion: public MHAFusionBase { +public: + OPENVINO_RTTI("MHAQuantFusion", "0"); + MHAQuantFusion(); +}; + +class MHAQuantFusion2: public MHAFusionBase { +public: + OPENVINO_RTTI("MHAQuantFusion2", "0"); + MHAQuantFusion2(); +}; + +class MHAFusion : public ngraph::pass::GraphRewrite { +public: + OPENVINO_RTTI("MHAFusion", "0"); + MHAFusion() { + add_matcher(); + add_matcher(); + add_matcher(); + add_matcher(); + } +}; + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/ngraph_transformations/op/mha.cpp b/src/plugins/intel_cpu/src/ngraph_transformations/op/mha.cpp new file mode 100644 index 00000000000..9a1a2584a27 --- /dev/null +++ b/src/plugins/intel_cpu/src/ngraph_transformations/op/mha.cpp @@ -0,0 +1,108 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "mha.hpp" +#include "../itt.hpp" +#include +#include + +ov::intel_cpu::MHANode::MHANode(const ngraph::Output &in0, + const ngraph::Output &in1, + const ngraph::Output &in2, + const ngraph::Output &in3, + const std::vector &mul_scales, + bool is_mul_first, + const ngraph::element::Type output_type) + : Op({in0, in1, in2, in3}), m_output_type(output_type) { + this->mul_scales = mul_scales; + this->is_mul_first = is_mul_first; + this->fq0_output_type = ngraph::element::undefined; + this->fq1_output_type = ngraph::element::undefined; + this->fq2_output_type = ngraph::element::undefined; + validate_and_infer_types(); +} + +ov::intel_cpu::MHANode::MHANode(const ngraph::Output &in0, + const ngraph::Output &in1, + const ngraph::Output &in2, + const ngraph::Output &in3, + const std::vector &mul_scales, + bool is_mul_first, + const std::vector &fq_scales0, + const std::vector &fq_scales1, + const std::vector &fq_scales2, + const std::vector &fq_scales3, + const ngraph::element::Type fq0_output_type, + const ngraph::element::Type fq1_output_type, + const ngraph::element::Type fq2_output_type, + const ngraph::element::Type output_type) + : Op({in0, in1, in2, in3}), m_output_type(output_type) { + this->mul_scales = mul_scales; + this->is_mul_first = is_mul_first; + this->fq_scales0 = fq_scales0; + this->fq_scales1 = fq_scales1; + this->fq_scales2 = fq_scales2; + 
this->fq_scales3 = fq_scales3; + this->fq0_output_type = fq0_output_type; + this->fq1_output_type = fq1_output_type; + this->fq2_output_type = fq2_output_type; + validate_and_infer_types(); +} + +std::shared_ptr ov::intel_cpu::MHANode::clone_with_new_inputs(const ngraph::OutputVector& new_args) const { + INTERNAL_OP_SCOPE(MHANode_clone_with_new_inputs); + check_new_args_count(this, new_args); + return std::make_shared(new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), + mul_scales, is_mul_first, fq_scales0, fq_scales1, fq_scales2, fq_scales3, + fq0_output_type, fq1_output_type, fq2_output_type, m_output_type); +} + +void ov::intel_cpu::MHANode::validate_and_infer_types() { + INTERNAL_OP_SCOPE(MHANode_validate_and_infer_types); + + auto transpose = [](const ov::Shape& shape, const std::vector& order) -> ov::Shape { + std::vector new_shape(shape.size()); + for (int i = 0; i < shape.size(); i++) { + new_shape[i] = shape[order[i]]; + } + return new_shape; + }; + + const auto matmul0_shape0 = transpose(get_input_partial_shape(0).get_shape(), {0, 2, 1, 3}); + const auto matmul0_shape1 = transpose(get_input_partial_shape(1).get_shape(), {0, 2, 3, 1}); + + auto matmul0_in0 = std::make_shared(ngraph::element::f32, matmul0_shape0); + auto matmul0_in1 = std::make_shared(ngraph::element::f32, matmul0_shape1); + auto matmul0 = std::make_shared(matmul0_in0, matmul0_in1); + + std::vector matmul0_input_shapes = {matmul0_shape0, matmul0_shape1}; + std::vector matmul0_output_shapes = {ov::PartialShape{}}; + + shape_infer(matmul0.get(), matmul0_input_shapes, matmul0_output_shapes); + + const auto matmul1_shape0 = matmul0_output_shapes[0]; + const auto matmul1_shape1 = transpose(get_input_partial_shape(3).get_shape(), {0, 2, 1, 3}); + + auto matmul1_in0 = std::make_shared(ngraph::element::f32, matmul1_shape0); + auto matmul1_in1 = std::make_shared(ngraph::element::f32, matmul1_shape1); + auto matmul1 = std::make_shared(matmul1_in0, matmul1_in1); + + std::vector matmul1_input_shapes = {matmul1_shape0, matmul1_shape1}; + std::vector matmul1_output_shapes = {ov::PartialShape{}}; + + shape_infer(matmul1.get(), matmul1_input_shapes, matmul1_output_shapes); + + const auto output_shape = transpose(matmul1_output_shapes[0].get_shape(), {0, 2, 1, 3}); + + set_output_type( + 0, + m_output_type == ngraph::element::undefined || m_output_type == ngraph::element::dynamic ? 
get_input_element_type(0) : m_output_type, + output_shape); +} + +bool ov::intel_cpu::MHANode::visit_attributes(ngraph::AttributeVisitor &visitor) { + INTERNAL_OP_SCOPE(MHANode_visit_attributes); + visitor.on_attribute("out-type", m_output_type); + return true; +} diff --git a/src/plugins/intel_cpu/src/ngraph_transformations/op/mha.hpp b/src/plugins/intel_cpu/src/ngraph_transformations/op/mha.hpp new file mode 100644 index 00000000000..095ad8e5bc9 --- /dev/null +++ b/src/plugins/intel_cpu/src/ngraph_transformations/op/mha.hpp @@ -0,0 +1,94 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +namespace ov { +namespace intel_cpu { + +class MHANode : public ngraph::op::Op { +public: + OPENVINO_OP("MHA", "cpu_plugin_opset"); + + MHANode() = default; + + MHANode(const ngraph::Output &in0, + const ngraph::Output &in1, + const ngraph::Output &in2, + const ngraph::Output &in3, + const std::vector &mul_scales, + bool is_mul_first, + const ngraph::element::Type output_type); + + MHANode(const ngraph::Output &in0, + const ngraph::Output &in1, + const ngraph::Output &in2, + const ngraph::Output &in3, + const std::vector &mul_scales, + bool is_mul_first, + const std::vector &fq_scales0, + const std::vector &fq_scales1, + const std::vector &fq_scales2, + const std::vector &fq_scales3, + const ngraph::element::Type fq0_output_type, + const ngraph::element::Type fq1_output_type, + const ngraph::element::Type fq2_output_type, + const ngraph::element::Type output_type); + + void validate_and_infer_types() override; + + bool visit_attributes(ngraph::AttributeVisitor &visitor) override; + + std::shared_ptr clone_with_new_inputs(const ngraph::OutputVector &new_args) const override; + + ngraph::element::Type get_output_type() const { return m_output_type; } + + const std::vector& get_mul_scales() const { + return mul_scales; + } + + const std::vector& get_fq_scales0() const { + return fq_scales0; + } + const std::vector& get_fq_scales1() const { + return fq_scales1; + } + const std::vector& get_fq_scales2() const { + return fq_scales2; + } + const std::vector& get_fq_scales3() const { + return fq_scales3; + } + + bool get_is_mul_first() const { + return is_mul_first; + } + + ngraph::element::Type get_fq0_output_type() const { + return fq0_output_type; + } + ngraph::element::Type get_fq1_output_type() const { + return fq1_output_type; + } + ngraph::element::Type get_fq2_output_type() const { + return fq2_output_type; + } + +private: + ngraph::element::Type m_output_type; + std::vector mul_scales; + bool is_mul_first; + std::vector fq_scales0; + std::vector fq_scales1; + std::vector fq_scales2; + std::vector fq_scales3; + ngraph::element::Type fq0_output_type; + ngraph::element::Type fq1_output_type; + ngraph::element::Type fq2_output_type; +}; + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/mha.cpp b/src/plugins/intel_cpu/src/nodes/mha.cpp new file mode 100644 index 00000000000..11d49ecd6ff --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/mha.cpp @@ -0,0 +1,1404 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include + +#include "ie_parallel.hpp" +#include "mha.h" +#include +#include "common/cpu_memcpy.h" +#include +#include +#include "emitters/jit_dnnl_emitters.hpp" +#include "emitters/jit_load_store_emitters.hpp" +#include "common/cpu_convert.h" +#include "ngraph_transformations/op/mha.hpp" +#include "dnnl_extension_utils.h" +#include + 
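The jit_mul_add_softmax_kernel defined below fuses the optional multiply, the additive mask and a numerically stable softmax over one row of attention scores in three passes (mul/add/max, subtract-max/exp/reduce, normalize). A scalar reference of the same computation for the mul-first variant, not part of the node implementation and with illustrative names only:

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Reference for one row: out[i] = softmax(in0[i] * mul_scale + add_in1[i]).
    std::vector<float> mul_add_softmax_ref(const std::vector<float>& in0,
                                           const std::vector<float>& add_in1,
                                           float mul_scale) {
        std::vector<float> buf(in0.size());
        float max_val = 0.f;  // like the kernel, the running max starts at zero; softmax is shift-invariant
        for (std::size_t i = 0; i < in0.size(); i++) {   // pass 1: mul + add + running max
            buf[i] = in0[i] * mul_scale + add_in1[i];
            max_val = std::max(max_val, buf[i]);
        }
        float denom = 0.f;
        for (float& v : buf) {                           // pass 2: subtract max, exp, accumulate denominator
            v = std::exp(v - max_val);
            denom += v;
        }
        for (float& v : buf)                             // pass 3: normalize by the reduced sum
            v /= denom;
        return buf;
    }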
+using namespace InferenceEngine; +using namespace InferenceEngine::details; +using namespace dnnl::impl::cpu::x64; +using namespace dnnl::impl::cpu::x64::matmul; +using namespace Xbyak; + +#define THROW_ERROR IE_THROW() << getTypeStr() << " node with name '" << getName() << "' " + +namespace ov { +namespace intel_cpu { +namespace node { + +template +struct jit_mul_add_softmax_kernel : public jit_uni_mul_add_softmax_kernel, public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_mul_add_softmax_kernel) + + explicit jit_mul_add_softmax_kernel(const jit_mul_add_softmax_compile_params& jcp) : jit_uni_mul_add_softmax_kernel(jcp), jit_generator() { + exp_emitter = std::make_shared(this, isa, dnnl_eltwise_exp, 0.f, 0.f); + + vec_size = dnnl::impl::cpu::x64::cpu_isa_traits::vlen / sizeof(float); + } + virtual ~jit_mul_add_softmax_kernel() {} + + void create_ker() override { + jit_generator::create_kernel(); + ker_ = (decltype(ker_))jit_ker(); + } + +private: + using Vmm = typename dnnl::impl::utils::conditional3::type; + + void generate() override { + this->preamble(); + +#define GET_OFF(field) offsetof(jit_mul_add_softmax_call_args, field) + mov(reg_in0, ptr[reg_params + GET_OFF(p_in0)]); + mov(reg_add_in1, ptr[reg_params + GET_OFF(p_add_in1)]); + mov(reg_out, ptr[reg_params + GET_OFF(p_out)]); + mov(reg_buffer, ptr[reg_params + GET_OFF(p_buffer)]); + + Xbyak::Label mul_add_max_loop_label; + Xbyak::Label mul_add_max_end_label; + Xbyak::Label sub_exp_reduce_loop_label; + Xbyak::Label sub_exp_reduce_end_label; + Xbyak::Label mul_loop_label; + Xbyak::Label mul_end_label; + + size_t tail_size = jcp_.work_amount % vec_size; + + mov(reg_buffer_aux, reg_buffer); + mov(reg_work_amount, jcp_.work_amount); + mov(reg_work_amount_aux, reg_work_amount); + uni_vpxor(get_vmm_max(0), get_vmm_max(0), get_vmm_max(0)); + + // mul1 input is const and always float + if (jcp_.with_mul_scales) { + mov(reg_mul_in1, ptr[reg_params + GET_OFF(p_mul_in1)]); + uni_vmovss(Xmm(get_vmm_in(1).getIdx()), ptr[reg_mul_in1]); + uni_vbroadcastss(get_vmm_in(1), Xmm(get_vmm_in(1).getIdx())); + } + + if (jcp_.with_scales0) { + mov(reg_scales, ptr[reg_params + GET_OFF(p_scales0)]); + + mov(reg_tmp, dnnl::impl::float2int(-128.0f)); + vmovq(xmm_tmp, reg_tmp); + vbroadcastss(vmm_crop_low, xmm_tmp); + + mov(reg_tmp, dnnl::impl::float2int(127.0f)); + vmovq(xmm_tmp, reg_tmp); + vbroadcastss(vmm_crop_high, xmm_tmp); + } + + if (jcp_.with_scales0 && jcp_.broadcast_scales0) { + uni_vmovss(Xmm(vmm_scales.getIdx()), ptr[reg_scales]); + uni_vbroadcastss(vmm_scales, Xmm(vmm_scales.getIdx())); + } + + L(mul_add_max_loop_label); + { + cmp(reg_work_amount_aux, vec_size); + jl(mul_add_max_end_label, T_NEAR); + + mul_add_max(vec_size); + + sub(reg_work_amount_aux, vec_size); + + jmp(mul_add_max_loop_label, T_NEAR); + } + L(mul_add_max_end_label); + if (tail_size) { + mul_add_max(tail_size); + } + + sub(rsp, sizeof(float) * vec_size); + uni_vmovups(ptr[rsp], get_vmm_max(0)); + uni_vpxor(get_vmm_max(0), get_vmm_max(0), get_vmm_max(0)); + for (size_t i = 0; i < vec_size; i++) { + mov(reg_tmp_32, ptr[rsp + i * sizeof(float)]); + vmovq(xmm_tmp, reg_tmp); + uni_vmaxps(get_xmm_max(0), get_xmm_max(0), xmm_tmp); + } + uni_vbroadcastss(get_vmm_max(0), get_xmm_max(0)); + add(rsp, sizeof(float) * vec_size); + + uni_vpxor(get_vmm_denom(0), get_vmm_denom(0), get_vmm_denom(0)); + mov(reg_work_amount_aux, reg_work_amount); + mov(reg_buffer_aux, reg_buffer); + L(sub_exp_reduce_loop_label); + { + cmp(reg_work_amount_aux, vec_size); + jl(sub_exp_reduce_end_label, 
T_NEAR); + + sub_exp_reduce(vec_size); + + sub(reg_work_amount_aux, vec_size); + + jmp(sub_exp_reduce_loop_label, T_NEAR); + } + L(sub_exp_reduce_end_label); + if (tail_size) { + sub_exp_reduce(tail_size); + } + + sub(rsp, sizeof(float) * vec_size); + uni_vmovups(ptr[rsp], get_vmm_denom(0)); + uni_vpxor(get_vmm_aux(0), get_vmm_aux(0), get_vmm_aux(0)); + for (size_t i = 0; i < vec_size; i++) { + mov(reg_tmp_32, ptr[rsp + i * sizeof(float)]); + vmovq(xmm_tmp, reg_tmp); + uni_vaddps(get_xmm_aux(0), get_xmm_aux(0), xmm_tmp); + } + vbroadcastss(get_vmm_aux(0), get_xmm_aux(0)); + add(rsp, sizeof(float) * vec_size); + + mov(reg_tmp, dnnl::impl::float2int(1.0f)); + vmovq(xmm_tmp, reg_tmp); + vbroadcastss(get_vmm_denom(0), xmm_tmp); + uni_vdivps(get_vmm_denom(0), get_vmm_denom(0), get_vmm_aux(0)); + + if (jcp_.with_scales1) + mov(reg_scales, ptr[reg_params + GET_OFF(p_scales1)]); + + if (jcp_.with_scales1 && jcp_.broadcast_scales1) { + uni_vmovss(Xmm(vmm_scales.getIdx()), ptr[reg_scales]); + uni_vbroadcastss(vmm_scales, Xmm(vmm_scales.getIdx())); + } + + mov(reg_work_amount_aux, reg_work_amount); + L(mul_loop_label); + { + cmp(reg_work_amount_aux, vec_size); + jl(mul_end_label, T_NEAR); + + mul_loop(vec_size); + + sub(reg_work_amount_aux, vec_size); + + jmp(mul_loop_label, T_NEAR); + } + L(mul_end_label); + if (tail_size) { + mul_loop(tail_size); + } + + this->postamble(); + + for (const auto& emitter : emitters) { + if (emitter.second) + emitter.second->emit_data(); + } + + exp_emitter->emit_data(); + } + + void mul_add_max(size_t step) { + bool is_tail = step < vec_size; + + load(get_vmm_in(0), reg_in0, jcp_.src_prc, step, is_tail); + load(get_vmm_in(2), reg_add_in1, Precision::FP32, step, is_tail); + + if (jcp_.with_scales0) { + if (!jcp_.broadcast_scales0) { + load(vmm_scales, reg_scales, Precision::FP32, step, is_tail); + add(reg_scales, sizeof(float) * step); + } + uni_vmulps(get_vmm_in(0), get_vmm_in(0), vmm_scales); + uni_vmaxps(get_vmm_in(0), get_vmm_in(0), vmm_crop_low); + uni_vminps(get_vmm_in(0), get_vmm_in(0), vmm_crop_high); + } + + if (jcp_.with_mul_scales) { + if (jcp_.is_mul_first) { + uni_vmulps(get_vmm_in(0), get_vmm_in(0), get_vmm_in(1)); + uni_vaddps(get_vmm_in(0), get_vmm_in(0), get_vmm_in(2)); + } else { + uni_vaddps(get_vmm_in(0), get_vmm_in(0), get_vmm_in(2)); + uni_vmulps(get_vmm_in(0), get_vmm_in(0), get_vmm_in(1)); + } + } else { + uni_vaddps(get_vmm_in(0), get_vmm_in(0), get_vmm_in(2)); + } + + uni_vmaxps(get_vmm_max(0), get_vmm_max(0), get_vmm_in(0)); + + store(reg_buffer_aux, get_vmm_in(0), Precision::FP32, step); + + if (!is_tail) { + add(reg_in0, jcp_.src_prc.size() * step); + add(reg_add_in1, sizeof(float) * step); + add(reg_buffer_aux, sizeof(float) * step); + } + } + + void sub_exp_reduce(size_t step) { + bool is_tail = step < vec_size; + + load(get_vmm_in(0), reg_buffer_aux, Precision::FP32, step, is_tail); + + uni_vsubps(get_vmm_in(0), get_vmm_in(0), get_vmm_max(0)); + + auto vmm_exp_idx = static_cast(get_vmm_in(0).getIdx()); + exp_emitter->emit_code({vmm_exp_idx}, {vmm_exp_idx}, pool_aux_vmm_idxs, pool_aux_gpr_idxs); + + uni_vaddps(get_vmm_denom(0), get_vmm_denom(0), get_vmm_in(0)); + + store(reg_buffer_aux, get_vmm_in(0), Precision::FP32, step); + + if (!is_tail) { + add(reg_buffer_aux, sizeof(float) * step); + } + } + + void mul_loop(size_t step) { + bool is_tail = step < vec_size; + + load(get_vmm_in(0), reg_buffer, Precision::FP32, step, is_tail); + + uni_vmulps(get_vmm_in(0), get_vmm_in(0), get_vmm_denom(0)); + + if (jcp_.src_prc == Precision::I32) { + if 
(jcp_.with_scales1) { + if (!jcp_.broadcast_scales1) { + load(vmm_scales, reg_scales, Precision::FP32, step, is_tail); + add(reg_scales, sizeof(float) * step); + } + uni_vmulps(get_vmm_in(0), get_vmm_in(0), vmm_scales); + } + } + + store(reg_out, get_vmm_in(0), jcp_.dst_prc, step); + + if (!is_tail) { + add(reg_buffer, sizeof(float) * step); + add(reg_out, jcp_.dst_prc.size() * step); + } +#undef GET_OFF + } + + inline void load(const Vmm& vmm_dst, const Xbyak::Reg64& reg_src, Precision src_prc, const int& elt_num, bool fill) { + const auto seed = load_emitter_params(src_prc, Precision::FP32, elt_num, fill, "float_min").hash(); + if (!emitters[seed]) { + emitters[seed].reset(new jit_load_emitter(this, isa, src_prc, Precision::FP32, elt_num, Precision::FP32, fill, "float_min")); + } + + emitters[seed]->emit_code({static_cast(reg_src.getIdx()), 0}, {static_cast(vmm_dst.getIdx())}, + pool_aux_vmm_idxs, pool_aux_gpr_idxs); + } + inline void store(const Xbyak::Reg64& reg_dst, const Vmm& vmm_src, Precision dst_prc, const int& elt_num) { + const auto seed = store_emitter_params(Precision::FP32, dst_prc, elt_num).hash(); + if (!emitters[seed]) { + emitters[seed].reset(new jit_store_emitter(this, isa, Precision::FP32, dst_prc, elt_num)); + } + + emitters[seed]->emit_code({static_cast(vmm_src.getIdx()), 0}, {static_cast(reg_dst.getIdx())}, + pool_aux_vmm_idxs, pool_aux_gpr_idxs); + } + + size_t unroll_factor = 3; + size_t vec_size; + + Vmm get_vmm_in(int idx) { + return Vmm(1 + 0 * unroll_factor + idx); + } + + Vmm get_vmm_aux(int idx) { + return Vmm(1 + 1 * unroll_factor + idx); + } + Xmm get_xmm_aux(int idx) { + return Xmm(1 + 1 * unroll_factor + idx); + } + + Vmm get_vmm_max(int idx) { + return Vmm(1 + 2 * unroll_factor + idx); + } + Xmm get_xmm_max(int idx) { + return Xmm(1 + 2 * unroll_factor + idx); + } + + + Vmm get_vmm_denom(int idx) { + return Vmm(1 + 3 * unroll_factor + idx); + } + + Xmm xmm_tmp = Xmm(0); + + Vmm vmm_scales = Vmm(0); + Vmm vmm_crop_low = Vmm(14); + Vmm vmm_crop_high = Vmm(15); + + Reg64 reg_in0 = r8; + Reg64 reg_mul_in1 = r9; + Reg64 reg_add_in1 = r10; + Reg64 reg_out = r11; + Reg64 reg_scales = r12; + Reg64 reg_work_amount = r13; + Reg64 reg_work_amount_aux = r14; + Reg64 reg_buffer = r15; + Reg64 reg_buffer_aux = rax; + Reg64 reg_tmp = rbx; + Reg32 reg_tmp_32 = Reg32(rbx.getIdx()); + Reg64 reg_max = rdx; + Reg32 reg_max_32 = Reg32(rdx.getIdx()); + Reg64 reg_params = abi_param1; + + const std::vector pool_aux_gpr_idxs = { static_cast(rsi.getIdx()), static_cast(rbp.getIdx()) }; + const std::vector pool_aux_vmm_idxs = { 12, 13, 14, 15 }; + + std::unordered_map> emitters; + + std::shared_ptr exp_emitter = nullptr; + std::unique_ptr load_emitter = nullptr; + std::unique_ptr store_emitter = nullptr; +}; + +template +struct jit_convert_reorder_kernel : public jit_uni_convert_reorder_kernel, public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_convert_reorder_kernel) + + explicit jit_convert_reorder_kernel(const jit_convert_reorder_compile_params& jcp) : jit_uni_convert_reorder_kernel(jcp), jit_generator() { + vec_size = dnnl::impl::cpu::x64::cpu_isa_traits::vlen / sizeof(float); + } + virtual ~jit_convert_reorder_kernel() {} + + void create_ker() override { + jit_generator::create_kernel(); + ker_ = (decltype(ker_))jit_ker(); + } + +private: + using Vmm = typename dnnl::impl::utils::conditional3::type; + + void generate() override { + this->preamble(); + +#define GET_OFF(field) offsetof(jit_convert_reorder_call_args, field) + mov(reg_in, ptr[reg_params + GET_OFF(p_in)]); 
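        // The code generated here walks outter_work_amount rows of inner_work_amount elements each:
        // every row is loaded in src_prc, converted to f32, optionally multiplied by the (broadcast
        // or per-element) dequantization scales, and stored back in dst_prc, with independent
        // src_stride / dst_stride offsets applied between rows.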
+ mov(reg_out, ptr[reg_params + GET_OFF(p_out)]); + mov(reg_outter_work_amount, ptr[reg_params + GET_OFF(outter_work_amount)]); + + if (jcp_.with_scales) { + mov(reg_scales, ptr[reg_params + GET_OFF(p_scales)]); + } + + Xbyak::Label convert_reorder_inner_loop_label; + Xbyak::Label convert_reorder_inner_end_label; + Xbyak::Label convert_reorder_outter_loop_label; + Xbyak::Label convert_reorder_outter_end_label; + + if (jcp_.with_scales && jcp_.broadcast_scales) { + uni_vmovss(Xmm(vmm_scales.getIdx()), ptr[reg_scales]); + uni_vbroadcastss(vmm_scales, Xmm(vmm_scales.getIdx())); + } + + L(convert_reorder_outter_loop_label); + { + cmp(reg_outter_work_amount, 1); + jl(convert_reorder_outter_end_label, T_NEAR); + + size_t tail_size = jcp_.inner_work_amount % vec_size; + mov(reg_inner_work_amount, jcp_.inner_work_amount); + mov(reg_in_aux, reg_in); + mov(reg_out_aux, reg_out); + if (jcp_.with_scales && !jcp_.broadcast_scales) { + mov(reg_scales, ptr[reg_params + GET_OFF(p_scales)]); + } + + L(convert_reorder_inner_loop_label); + { + cmp(reg_inner_work_amount, vec_size); + jl(convert_reorder_inner_end_label, T_NEAR); + + convert_reorder(vec_size); + + sub(reg_inner_work_amount, vec_size); + + jmp(convert_reorder_inner_loop_label, T_NEAR); + } + L(convert_reorder_inner_end_label); + if (tail_size) { + convert_reorder(tail_size); + } + + dec(reg_outter_work_amount); + add(reg_in, jcp_.src_prc.size() * jcp_.src_stride); + add(reg_out, jcp_.dst_prc.size() * jcp_.dst_stride); + + jmp(convert_reorder_outter_loop_label, T_NEAR); + } + L(convert_reorder_outter_end_label); + + this->postamble(); + + for (const auto& emitter : emitters) { + if (emitter.second) + emitter.second->emit_data(); + } + } + + void convert_reorder(size_t step) { + bool is_tail = step < vec_size; + + load(vmm_in, reg_in_aux, jcp_.src_prc, step, is_tail); + + if (jcp_.with_scales) { + if (!jcp_.broadcast_scales) { + load(vmm_scales, reg_scales, Precision::FP32, step, is_tail); + add(reg_scales, sizeof(float) * step); + } + uni_vmulps(vmm_in, vmm_in, vmm_scales); + } + + store(reg_out_aux, vmm_in, jcp_.dst_prc, step); + + if (!is_tail) { + add(reg_in_aux, jcp_.src_prc.size() * step); + add(reg_out_aux, jcp_.dst_prc.size() * step); + } + } +#undef GET_OFF + + inline void load(const Vmm& vmm_dst, const Xbyak::Reg64& reg_src, Precision src_prc, const int& elt_num, bool fill) { + const auto seed = load_emitter_params(src_prc, Precision::FP32, elt_num, fill, "float_min").hash(); + if (!emitters[seed]) { + emitters[seed].reset(new jit_load_emitter(this, isa, src_prc, Precision::FP32, elt_num, Precision::FP32, fill, "float_min")); + } + + emitters[seed]->emit_code({static_cast(reg_src.getIdx()), 0}, {static_cast(vmm_dst.getIdx())}, + pool_aux_vmm_idxs, pool_aux_gpr_idxs); + } + inline void store(const Xbyak::Reg64& reg_dst, const Vmm& vmm_src, Precision dst_prc, const int& elt_num) { + const auto seed = store_emitter_params(Precision::FP32, dst_prc, elt_num).hash(); + if (!emitters[seed]) { + emitters[seed].reset(new jit_store_emitter(this, isa, Precision::FP32, dst_prc, elt_num)); + } + + emitters[seed]->emit_code({static_cast(vmm_src.getIdx()), 0}, {static_cast(reg_dst.getIdx())}, + pool_aux_vmm_idxs, pool_aux_gpr_idxs); + } + + size_t vec_size; + + Xmm xmm_tmp = Xmm(2); + Vmm vmm_scales = Vmm(0); + Vmm vmm_in = Vmm(1); + + Reg64 reg_in = r8; + Reg64 reg_in_aux = r9; + Reg64 reg_out = r10; + Reg64 reg_out_aux = r11; + Reg64 reg_scales = r12; + Reg64 reg_inner_work_amount = r14; + Reg64 reg_outter_work_amount = r15; + Reg64 reg_params = 
abi_param1; + + const std::vector pool_aux_gpr_idxs = { static_cast(rsi.getIdx()), static_cast(rbp.getIdx()) }; + const std::vector pool_aux_vmm_idxs = { static_cast(xmm_tmp.getIdx()) }; + + std::unordered_map> emitters; +}; + +template +struct jit_convert_transpose_kernel : public jit_uni_convert_transpose_kernel, public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_convert_transpose_kernel) + + explicit jit_convert_transpose_kernel(const jit_convert_transpose_compile_params& jcp) : jit_uni_convert_transpose_kernel(jcp), jit_generator() { + interm_prc = jcp_.with_scales ? Precision(Precision::FP32) : jcp_.src_prc; + vec_size = dnnl::impl::cpu::x64::cpu_isa_traits::vlen / interm_prc.size(); + } + virtual ~jit_convert_transpose_kernel() {} + + void create_ker() override { + jit_generator::create_kernel(); + ker_ = (decltype(ker_))jit_ker(); + } + +private: + using Vmm = typename dnnl::impl::utils::conditional3::type; + + void generate() override { + this->preamble(); + +#define GET_OFF(field) offsetof(jit_convert_transpose_call_args, field) + mov(reg_in, ptr[reg_params + GET_OFF(p_in)]); + mov(reg_out, ptr[reg_params + GET_OFF(p_out)]); + if (jcp_.with_scales) { + mov(reg_scales, ptr[reg_params + GET_OFF(p_scales)]); + } + + Xbyak::Label convert_transpose_inner_loop_label; + Xbyak::Label convert_transpose_inner_end_label; + Xbyak::Label convert_transpose_outter_loop_label; + Xbyak::Label convert_transpose_outter_end_label; + + if (jcp_.with_scales && jcp_.broadcast_scales) { + uni_vmovss(Xmm(vmm_scales.getIdx()), ptr[reg_scales]); + uni_vbroadcastss(vmm_scales, Xmm(vmm_scales.getIdx())); + } + + mov(reg_outter_work_amount, jcp_.outter_work_amount); + L(convert_transpose_outter_loop_label); + { + cmp(reg_outter_work_amount, 1); + jl(convert_transpose_outter_end_label, T_NEAR); + + size_t tail_size = jcp_.inner_work_amount % vec_size; + mov(reg_inner_work_amount, jcp_.inner_work_amount); + mov(reg_in_aux, reg_in); + mov(reg_out_aux, reg_out); + if (jcp_.with_scales && !jcp_.broadcast_scales) { + mov(reg_scales, ptr[reg_params + GET_OFF(p_scales)]); + } + + L(convert_transpose_inner_loop_label); + { + cmp(reg_inner_work_amount, vec_size); + jl(convert_transpose_inner_end_label, T_NEAR); + + convert_transpose(vec_size); + + sub(reg_inner_work_amount, vec_size); + + jmp(convert_transpose_inner_loop_label, T_NEAR); + } + L(convert_transpose_inner_end_label); + if (tail_size) { + convert_transpose(tail_size); + } + + dec(reg_outter_work_amount); + add(reg_in, jcp_.src_prc.size() * jcp_.outter_src_stride); + add(reg_out, jcp_.dst_prc.size() * jcp_.outter_dst_stride); + + jmp(convert_transpose_outter_loop_label, T_NEAR); + } + L(convert_transpose_outter_end_label); + + this->postamble(); + + for (const auto& emitter : emitters) { + if (emitter.second) + emitter.second->emit_data(); + } + } + + void convert_transpose(size_t step) { + bool is_tail = step < vec_size; + + sub(rsp, jcp_.src_prc.size() * vec_size); + for (size_t i = 0; i < step; i++) { + if (jcp_.src_prc.size() == 4) { + mov(reg_tmp_32, ptr[reg_in_aux + i * jcp_.inner_src_stride * jcp_.src_prc.size()]); + mov(ptr[rsp + i * jcp_.src_prc.size()], reg_tmp_32); + } else if (jcp_.src_prc.size() == 2) { + mov(reg_tmp_16, ptr[reg_in_aux + i * jcp_.inner_src_stride * jcp_.src_prc.size()]); + mov(ptr[rsp + i * jcp_.src_prc.size()], reg_tmp_16); + } else if (jcp_.src_prc.size() == 1) { + mov(reg_tmp_8, ptr[reg_in_aux + i * jcp_.inner_src_stride * jcp_.src_prc.size()]); + mov(ptr[rsp + i * jcp_.src_prc.size()], reg_tmp_8); + } + } + 
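        // The scalar loop above gathers `step` strided source elements onto the stack scratch area
        // (4/2/1-byte moves depending on src_prc) so the load() below can read them back contiguously
        // as a single vector and convert them to interm_prc; this avoids needing gather instructions
        // for the strided transpose read.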
load(vmm_in, rsp, jcp_.src_prc, interm_prc, vec_size, false); + add(rsp, jcp_.src_prc.size() * vec_size); + + if (jcp_.with_scales) { + if (!jcp_.broadcast_scales) { + load(vmm_scales, reg_scales, Precision::FP32, Precision::FP32, step, false); + add(reg_scales, sizeof(float) * step); + } + uni_vmulps(vmm_in, vmm_in, vmm_scales); + } + + store(reg_out_aux, vmm_in, interm_prc, jcp_.dst_prc, step); + + if (!is_tail) { + add(reg_in_aux, jcp_.src_prc.size() * step * jcp_.inner_src_stride); + add(reg_out_aux, jcp_.dst_prc.size() * step); + } + } +#undef GET_OFF + inline void load(const Vmm& vmm_dst, const Xbyak::Reg64& reg_src, Precision src_prc, Precision dst_prc, const int& elt_num, bool fill) { + const auto seed = load_emitter_params(src_prc, dst_prc, elt_num, fill, "float_min").hash(); + if (!emitters[seed]) { + emitters[seed].reset(new jit_load_emitter(this, isa, src_prc, dst_prc, elt_num, Precision::FP32, fill, "float_min")); + } + + emitters[seed]->emit_code({static_cast(reg_src.getIdx()), 0}, {static_cast(vmm_dst.getIdx())}, + pool_aux_vmm_idxs, pool_aux_gpr_idxs); + } + inline void store(const Xbyak::Reg64& reg_dst, const Vmm& vmm_src, Precision src_prc, Precision dst_prc, const int& elt_num) { + const auto seed = store_emitter_params(src_prc, dst_prc, elt_num).hash(); + if (!emitters[seed]) { + emitters[seed].reset(new jit_store_emitter(this, isa, src_prc, dst_prc, elt_num)); + } + + emitters[seed]->emit_code({static_cast(vmm_src.getIdx()), 0}, {static_cast(reg_dst.getIdx())}, + pool_aux_vmm_idxs, pool_aux_gpr_idxs); + } + + size_t vec_size; + Precision interm_prc; + + Xmm xmm_tmp = Xmm(2); + Vmm vmm_scales = Vmm(0); + Vmm vmm_in = Vmm(1); + + Reg64 reg_in = r8; + Reg64 reg_in_aux = r9; + Reg64 reg_out = r10; + Reg64 reg_out_aux = r11; + Reg64 reg_scales = r12; + Reg8 reg_tmp_8 = Reg8(r13.getIdx()); + Reg16 reg_tmp_16 = Reg16(r13.getIdx()); + Reg32 reg_tmp_32 = Reg32(r13.getIdx()); + Reg64 reg_inner_work_amount = r14; + Reg64 reg_outter_work_amount = r15; + Reg64 reg_params = abi_param1; + + const std::vector pool_aux_gpr_idxs = { static_cast(rsi.getIdx()), static_cast(rbp.getIdx()) }; + const std::vector pool_aux_vmm_idxs = { static_cast(xmm_tmp.getIdx()) }; + + std::unordered_map> emitters; +}; + +bool MHA::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { + try { + const auto mha = std::dynamic_pointer_cast(op); + if (!mha) { + errorMessage = "Only MHA from CPU internal opset is supported"; + return false; + } + + if (isDynamicNgraphNode(op)) { + errorMessage = "Doesn't support op with dynamic shapes"; + return false; + } + + bool supportedPrecisions = true; + if (!(mha->get_input_element_type(0) == element::i8 && + mha->get_input_element_type(1) == element::f32 && + mha->get_input_element_type(3) == element::f32)) { + if (!one_of(mha->get_input_element_type(0), element::f32, element::bf16, element::i8)) { + supportedPrecisions = false; + } + + if (mha->get_input_element_type(0) != mha->get_input_element_type(1) || + mha->get_input_element_type(0) != mha->get_input_element_type(3)) { + supportedPrecisions = false; + } + } else { + if (mha->get_fq0_output_type() != mha->get_input_element_type(0)) + supportedPrecisions = false; + } + + if (!mha->get_fq_scales1().empty() && mha->get_fq1_output_type() != element::i8) { + supportedPrecisions = false; + } + + if (mha->get_input_element_type(3) == element::i8) { + if (!one_of(mha->get_fq2_output_type(), element::u8, element::i8)) { + supportedPrecisions = false; + } + } + + if (!supportedPrecisions) { + 
errorMessage = "Doesn't support provided input precisions"; + return false; + } + + if (!one_of(mha->get_output_element_type(0), element::f32, element::bf16, element::i8, element::u8)) { + errorMessage = "Doesn't support provided output precision"; + return false; + } + + if (mha->get_input_element_type(0) == element::f32 && !mayiuse(avx512_core)) { + errorMessage = "Doesn't support f32 execution precision on targets w/o avx512_core support"; + return false; + } + + if (mha->get_input_element_type(0) == element::bf16 && !mayiuse(avx512_core_bf16)) { + errorMessage = "Doesn't support bf16 execution precision on targets w/o avx512_core_bf16 support"; + return false; + } + + if (mha->get_input_element_type(0) == element::i8 && !mayiuse(avx512_core_vnni)) { + errorMessage = "Doesn't support i8 execution precision on targets w/o avx512_core_vnni support"; + return false; + } + + if (mha->get_input_shape(0).size() != 4) { + errorMessage = "Doesn't support inputs with rank != 4"; + return false; + } + } catch (...) { + return false; + } + + return true; +} + +MHA::MHA(const std::shared_ptr& op, const dnnl::engine& eng, + WeightsSharing::Ptr &cache) : Node(op, eng, cache) { + std::string errorMessage; + if (!isSupportedOperation(op, errorMessage)) { + IE_THROW(NotImplemented) << errorMessage; + } + + const auto mha = std::dynamic_pointer_cast(op); + mulScales = mha->get_mul_scales(); + isMulFirst = mha->get_is_mul_first(); + fqScales0 = mha->get_fq_scales0(); + fqScales1 = mha->get_fq_scales1(); + fqScales2 = mha->get_fq_scales2(); + fqScales3 = mha->get_fq_scales3(); + fqPrc2 = details::convertPrecision(mha->get_fq2_output_type()); +} + +void MHA::initSupportedPrimitiveDescriptors() { + if (!supportedPrimitiveDescriptors.empty()) + return; + + for (auto idx : {0, 1, 2, 3}) { + inputPrecisions.push_back(getOriginalInputPrecisionAtPort(idx)); + if (!one_of(inputPrecisions[idx], Precision::FP32, Precision::BF16, Precision::I8)) + THROW_ERROR << "doesn't support " << inputPrecisions[idx].name() << " precision on " << idx << " input port"; + } + + if ((inputPrecisions[0] != inputPrecisions[1]) && + !(inputPrecisions[0] == Precision::I8 && inputPrecisions[1] == Precision::FP32 && !fqScales0.empty())) { + inputPrecisions[0] = inputPrecisions[1] = Precision::FP32; + } + + inputPrecisions[2] = Precision::FP32; + + if (inputPrecisions[3] == Precision::I8 && fqScales2.empty()) + inputPrecisions[3] = Precision::FP32; + + + if (!one_of(getOriginalOutputPrecisionAtPort(0), Precision::FP32, Precision::BF16, Precision::I8, Precision::U8)) + THROW_ERROR << "doesn't support " << getOriginalOutputPrecisionAtPort(0).name() << " precision on output port"; + + addSupportedPrimDesc({{LayoutType::ncsp, inputPrecisions[0]}, + {LayoutType::ncsp, inputPrecisions[1]}, + {LayoutType::ncsp, Precision::FP32}, + {LayoutType::ncsp, inputPrecisions[3]}}, + {{LayoutType::ncsp, getOriginalOutputPrecisionAtPort(0)}}, + ref_any, + isDynamicNode()); +} + +void MHA::init_brgemm(brgemmCtx& ctx, std::unique_ptr& brgKernel, bool use_amx) { + brgemm_t brgDesc; + brgemm_strides_t strides {static_cast(ctx.M * ctx.K), static_cast(ctx.K * ctx.N)}; + + auto isa = use_amx ? isa_any + : ctx.dt_in0 == dnnl_data_type_t::dnnl_bf16 ? 
avx512_core_bf16 : avx512_core_vnni; + auto status = brgemm_desc_init(&brgDesc, isa, brgemm_strd, ctx.dt_in0, ctx.dt_in1, + false, false, brgemm_row_major, 1.f, ctx.beta, ctx.LDA, ctx.LDB, ctx.LDC, ctx.M, ctx.N, ctx.K, &strides); + if (status != dnnl_success) { + THROW_ERROR << "cannot be executed due to invalid brgconv params"; + } + + ctx.is_with_amx = use_amx; + status = brgemm_init_tiles(brgDesc, ctx.palette); + if (use_amx) { + amx_tile_configure(ctx.palette); + } + + ctx.is_with_comp = ctx.dt_in0 == dnnl_data_type_t::dnnl_s8 && !ctx.is_with_amx; + + brgemm_kernel_t* brgKernel_ = nullptr; + status = brgemm_kernel_create(&brgKernel_, brgDesc); + if (status != dnnl_success) { + THROW_ERROR << "cannot be executed due to invalid brgconv params"; + } + brgKernel.reset(brgKernel_); +} + +void MHA::init_brgemm_copy_a(std::unique_ptr& brgCopyKernel, size_t K, size_t K_blk, size_t K_tail, + size_t LDA, dnnl_data_type_t dt_in0) { + brgemm_matmul_conf_t brgCopyKernelConf; + brgCopyKernelConf.src_tag = dnnl_abcd; + brgCopyKernelConf.K = K; + brgCopyKernelConf.K_tail = K_tail; + brgCopyKernelConf.K_blk = K_blk; + brgCopyKernelConf.use_buffer_a_tail_only = false; + brgCopyKernelConf.LDA = false; + brgCopyKernelConf.has_zero_point_b = false; + brgCopyKernelConf.s8s8_compensation_required = false; + brgCopyKernelConf.wei_zp_type = dnnl::impl::cpu::x64::none; + brgCopyKernelConf.src_zp_type = dnnl::impl::cpu::x64::none; + brgCopyKernelConf.src_dt = dt_in0; + brgCopyKernelConf.a_dt_sz = DnnlExtensionUtils::sizeOfDataType(static_cast(dt_in0)); + brgCopyKernelConf.transposed_A = false; + + create_brgemm_matmul_copy_a(brgCopyKernel, &brgCopyKernelConf); +} + +void MHA::init_brgemm_copy_b(std::unique_ptr& brgCopyKernel, size_t N, size_t N_blk, size_t N_tail, size_t LDB, size_t K, + bool is_with_amx, dnnl_data_type_t dt_in0, dnnl_data_type_t dt_in1) { + brgemm_matmul_conf_t brgCopyKernelConf; + brgCopyKernelConf.src_dt = dt_in0; + brgCopyKernelConf.wei_dt = dt_in1; + brgCopyKernelConf.wei_n_blk = N_blk; + brgCopyKernelConf.wei_tag = dnnl_abcd; + brgCopyKernelConf.copy_B_wei_stride = 0; + brgCopyKernelConf.LDB = LDB; + brgCopyKernelConf.N = N; + brgCopyKernelConf.N_tail = N_tail; + brgCopyKernelConf.N_blk = N_blk; + brgCopyKernelConf.K = K; + brgCopyKernelConf.K_blk = K; + brgCopyKernelConf.N_chunk_elems = brgCopyKernelConf.N_blk; + + if (is_with_amx) { + brgCopyKernelConf.isa = dt_in0 == dnnl_data_type_t::dnnl_bf16 ? avx512_core_bf16_amx_bf16 : avx512_core_bf16_amx_int8; + brgCopyKernelConf.s8s8_compensation_required = false; + } else { + brgCopyKernelConf.isa = dt_in0 == dnnl_data_type_t::dnnl_bf16 ? 
avx512_core_bf16 : avx512_core_vnni; + brgCopyKernelConf.s8s8_compensation_required = dt_in0 == dnnl_data_type_t::dnnl_s8; + } + + brgCopyKernelConf.has_zero_point_a = false; + brgCopyKernelConf.has_zero_point_b = false; + brgCopyKernelConf.src_zp_type = dnnl::impl::cpu::x64::none; + + create_brgemm_matmul_copy_b(brgCopyKernel, &brgCopyKernelConf); +} + +void MHA::prepareParams() { + auto transpose = [](const std::vector& vec, const std::vector& order) -> std::vector { + std::vector new_vec(vec.size()); + for (int i = 0; i < vec.size(); i++) { + new_vec[i] = vec[order[i]]; + } + return new_vec; + }; + + const auto memDescTranspose0In0 = getParentEdgeAt(0)->getMemoryPtr()->GetDescWithType(); + const auto memDescTranspose1In0 = getParentEdgeAt(1)->getMemoryPtr()->GetDescWithType(); + const auto memDescAddIn1 = getParentEdgeAt(2)->getMemoryPtr()->GetDescWithType(); + const auto memDescTranspose2In0 = getParentEdgeAt(3)->getMemoryPtr()->GetDescWithType(); + const auto memDescOut = getChildEdgeAt(0)->getMemoryPtr()->GetDescWithType(); + + dimsTranspose0In0 = memDescTranspose0In0->getBlockDims(); + dimsTranspose1In0 = memDescTranspose1In0->getBlockDims(); + dimsAddIn1 = memDescAddIn1->getBlockDims(); + dimsTranspose2In0 = memDescTranspose2In0->getBlockDims(); + dimsOut = memDescOut->getBlockDims(); + + strTranspose0In0 = memDescTranspose0In0->getStrides(); + strTranspose1In0 = memDescTranspose1In0->getStrides(); + strAddIn1 = memDescAddIn1->getStrides(); + strTranspose2In0 = memDescTranspose2In0->getStrides(); + strOut = memDescOut->getStrides(); + + std::vector orderTranspose0 = {0, 2, 1, 3}; + dimsMatMul0In0 = transpose(dimsTranspose0In0, orderTranspose0); + + std::vector orderTranspose1 = {0, 2, 3, 1}; + dimsMatMul0In1 = transpose(dimsTranspose1In0, orderTranspose1); + + dimsMatMul0Out = {dimsMatMul0In0[0], dimsMatMul0In0[1], dimsMatMul0In0[2], dimsMatMul0In1[3]}; + + std::vector orderTranspose2 = {0, 2, 1, 3}; + dimsMatMul1In1 = transpose(dimsTranspose2In0, orderTranspose2); + + bool isAMXSupported = mayiuse(avx512_core_bf16_amx_int8) || mayiuse(avx512_core_bf16_amx_bf16); + + size_t numThreads = parallel_get_max_threads(); + + size_t matmulOptimalM = 32; + + batch0 = dimsMatMul0Out[0]; + batch1 = dimsMatMul0Out[1]; + + M = dimsMatMul0In0[2]; + M_blk = matmulOptimalM; + M_tail = M % M_blk; + + N0 = dimsMatMul0In1[3]; + K0 = dimsMatMul0In0[3]; + + auto brg0Prc = inputPrecisions[0]; + brg0VnniFactor = 4 / brg0Prc.size(); + bool brg0WithAMX = isAMXSupported && brg0Prc != Precision::FP32 && (K0 % brg0VnniFactor == 0) && (N0 % brg0VnniFactor == 0); + + N0_blk = brg0Prc == Precision::FP32 ? N0 : + brg0Prc == Precision::BF16 ? 32 : 64; + N0_tail = N0 % N0_blk; + K0_blk = brg0WithAMX ? brg0Prc == Precision::BF16 ? 32 : 64 + : K0; + K0_tail = K0 % K0_blk; + + accPrecision0 = brg0Prc == Precision::I8 ? Precision::I32 : Precision::FP32; + + size_t brg0BaseIdx = -1; + for (size_t m = 0; m < 2; m++) { + for (size_t k = 0; k < 2; k++) { + for (size_t n = 0; n < 2; n++) { + auto& brgemmCtx = brgCtxs0[getBrgIdx(m, k, n)]; + + auto M_ = m ? M_tail + : M < M_blk ? 0 : M_blk; + auto N_ = n ? N0_tail : N0 - N0_tail; + auto K_ = k ? K0_tail : K0 - K0_tail; + auto beta = k && brgCtxs0[getBrgIdx(m, 0, n)].K != 0 ? 
1.0f : 0.0f; + + brgemmCtx.M = M_; + brgemmCtx.N = N_; + brgemmCtx.K = K_; + brgemmCtx.LDA = batch1 * K0; + brgemmCtx.LDB = rnd_up(N0, N0_blk); + brgemmCtx.LDC = N0; + brgemmCtx.dt_in0 = static_cast(DnnlExtensionUtils::IEPrecisionToDataType(brg0Prc)); + brgemmCtx.dt_in1 = static_cast(DnnlExtensionUtils::IEPrecisionToDataType(brg0Prc)); + brgemmCtx.beta = beta; + + // don't create brgemm kernels for empty tiles + if (M_ != 0 && K_ != 0 && N_ != 0) { + if (brg0BaseIdx == -1) + brg0BaseIdx = getBrgIdx(m, k, n); + init_brgemm(brgemmCtx, brgKernels0[getBrgIdx(m, k, n)], brg0WithAMX); + } + } + } + } + + auto& brgemmCtx0 = brgCtxs0[brg0BaseIdx]; + + // TODO: matrix A copy should be performed to enable AMX matmuls for arbitrary shapes + // if (brgemmCtx0.is_with_amx && K0_tail) { + // init_brgemm_copy_a(brgCopyAKernel0, K0, K0_blk, K0_tail, brgemmCtx0.LDA, brgemmCtx0.dt_in0); + // } + + if (brgemmCtx0.is_with_amx || brg0Prc == Precision::I8 || brg0Prc == Precision::BF16) { + init_brgemm_copy_b(brgCopyBKernel0, N0, N0_blk, N0_tail, brgemmCtx0.LDB, brgemmCtx0.K, + brgemmCtx0.is_with_amx, brgemmCtx0.dt_in0, brgemmCtx0.dt_in1); + } + + dimsMatMul1Out = {dimsMatMul0Out[0], dimsMatMul0Out[1], dimsMatMul0Out[2], dimsMatMul1In1[3]}; + + N1 = dimsMatMul1Out[3]; + K1 = dimsMatMul0Out[3]; + + auto brg1PrcIn0 = !fqScales2.empty() ? fqPrc2 : inputPrecisions[3]; + auto brg1PrcIn1 = inputPrecisions[3]; + brg1VnniFactor = 4 / brg1PrcIn0.size(); + bool brg1WithAMX = isAMXSupported && brg1PrcIn0 != Precision::FP32 && (K1 % brg1VnniFactor == 0) && (N1 % brg1VnniFactor == 0); + + N1_blk = brg1PrcIn1 == Precision::FP32 ? N1 : + brg1PrcIn1 == Precision::BF16 ? 32 : 64; + N1_tail = N1 % N1_blk; + K1_blk = brg1WithAMX ? brg1PrcIn0 == Precision::BF16 ? 32 : 64 + : K1; + K1_tail = K1 % K1_blk; + + accPrecision1 = one_of(brg1PrcIn0, Precision::U8, Precision::I8) ? Precision::I32 : Precision::FP32; + + size_t brg1BaseIdx = -1; + for (size_t m = 0; m < 2; m++) { + for (size_t k = 0; k < 2; k++) { + for (size_t n = 0; n < 2; n++) { + auto& brgemmCtx = brgCtxs1[getBrgIdx(m, k, n)]; + + auto M_ = m ? M_tail + : M < M_blk ? 0 : M_blk; + auto N_ = n ? N1_tail : N1 - N1_tail; + auto K_ = k ? K1_tail : K1 - K1_tail; + + auto beta = k && brgCtxs1[getBrgIdx(m, 0, n)].K != 0 ? 1.0f : 0.0f; + brgemmCtx.M = M_; + brgemmCtx.N = N_; + brgemmCtx.K = K_; + brgemmCtx.LDA = K1; + brgemmCtx.LDB = brg1PrcIn1 == Precision::FP32 ? batch1 * N1 : rnd_up(N1, N1_blk); + brgemmCtx.LDC = accPrecision1 == getOriginalOutputPrecisionAtPort(0) ? 
batch1 * N1 : N1; + brgemmCtx.dt_in0 = static_cast(DnnlExtensionUtils::IEPrecisionToDataType(brg1PrcIn0)); + brgemmCtx.dt_in1 = static_cast(DnnlExtensionUtils::IEPrecisionToDataType(brg1PrcIn1)); + brgemmCtx.beta = beta; + + // don't create brgemm kernels for empty tiles + if (M_ != 0 && K_ != 0 && N_ != 0) { + if (brg1BaseIdx == -1) + brg1BaseIdx = getBrgIdx(m, k, n); + + init_brgemm(brgemmCtx, brgKernels1[getBrgIdx(m, k, n)], brg1WithAMX); + } + } + } + } + + auto& brgemmCtx1 = brgCtxs1[brg1BaseIdx]; + if (brgemmCtx1.is_with_amx || brg1PrcIn1 == Precision::I8 || brg1PrcIn1 == Precision::BF16) { + init_brgemm_copy_b(brgCopyBKernel1, batch1 * N1, N1_blk, N1_tail, brgemmCtx1.LDB, brgemmCtx1.K, + brgemmCtx1.is_with_amx, brgemmCtx1.dt_in0, brgemmCtx1.dt_in1); + } + + bufferMatMul0In0Size = M_blk * rnd_up(K0, K0_blk) * brg0Prc.size(); + bufferMatMul0In1Size = rnd_up(K0, brg0VnniFactor) * rnd_up(N0, N0_blk) * brg0Prc.size(); + bufferMatMul0OutSize = brgemmCtx0.M * N0 * accPrecision0.size(); + bufferMatMul1In1Size = rnd_up(K1, brg1VnniFactor) * rnd_up(N1, N1_blk) * std::max(brg0Prc.size(), brg1PrcIn1.size()); + bufferMatMul1OutSize = brgemmCtx1.M * N1 * accPrecision1.size(); + bufferCompensation0Size = rnd_up(N0, N0_blk); + bufferCompensation1Size = rnd_up(N1, N1_blk); + + if (brgCopyAKernel0) { + bufferMatMul0In0.resize(numThreads * bufferMatMul0In0Size); + } + bufferMatMul0In1.resize(numThreads * bufferMatMul0In1Size); + bufferMatMul0Out.resize(numThreads * bufferMatMul0OutSize); + bufferMatMul1In1.resize(numThreads * bufferMatMul1In1Size); + bufferMatMul1Out.resize(numThreads * bufferMatMul1OutSize); + if (brgemmCtx0.is_with_comp) { + bufferCompensation0.resize(numThreads * bufferCompensation0Size); + } + if (brgemmCtx1.is_with_comp) { + bufferCompensation1.resize(numThreads * bufferCompensation1Size); + } + + if (brgemmCtx0.is_with_amx || brgemmCtx1.is_with_amx) { + wsp.resize(numThreads * wsp_size_per_thread); + } + + { + jit_mul_add_softmax_compile_params jcp; + jcp.src_prc = accPrecision0; + jcp.dst_prc = brg1PrcIn0; + jcp.work_amount = N0; + jcp.with_mul_scales = !mulScales.empty(); + jcp.is_mul_first = isMulFirst; + jcp.with_scales0 = !fqScales1.empty(); + jcp.broadcast_scales0 = fqScales1.size() == 1; + jcp.with_scales1 = !fqScales2.empty(); + jcp.broadcast_scales1 = fqScales2.size() == 1; + + if (mayiuse(cpu_isa_t::avx512_core)) { + mulAddSoftmaxKernel.reset(new jit_mul_add_softmax_kernel(jcp)); + } else if (mayiuse(cpu_isa_t::avx2)) { + mulAddSoftmaxKernel.reset(new jit_mul_add_softmax_kernel(jcp)); + } else if (mayiuse(cpu_isa_t::sse41)) { + mulAddSoftmaxKernel.reset(new jit_mul_add_softmax_kernel(jcp)); + } else { + THROW_ERROR << "cannot create jit eltwise kernel"; + } + } + + if (accPrecision1 != getOriginalOutputPrecisionAtPort(0)) { + jit_convert_reorder_compile_params jcp; + jcp.src_prc = accPrecision1; + jcp.dst_prc = getOriginalOutputPrecisionAtPort(0); + jcp.inner_work_amount = N1; + jcp.with_scales = !fqScales3.empty(); + jcp.broadcast_scales = fqScales3.size() == 1; + jcp.src_stride = N1; + jcp.dst_stride = batch1 * N1; + + if (mayiuse(cpu_isa_t::avx512_core)) { + convertReorderKernel.reset(new jit_convert_reorder_kernel(jcp)); + } else if (mayiuse(cpu_isa_t::avx2)) { + convertReorderKernel.reset(new jit_convert_reorder_kernel(jcp)); + } else if (mayiuse(cpu_isa_t::sse41)) { + convertReorderKernel.reset(new jit_convert_reorder_kernel(jcp)); + } else { + THROW_ERROR << "cannot create jit eltwise kernel"; + } + } + + if (!fqScales0.empty() || inputPrecisions[1] != brg0Prc) 
{ + jit_convert_transpose_compile_params jcp; + jcp.src_prc = inputPrecisions[1]; + jcp.dst_prc = brg0Prc; + jcp.inner_work_amount = N0; + jcp.outter_work_amount = K0; + jcp.with_scales = !fqScales0.empty(); + jcp.broadcast_scales = fqScales0.size() == 1; + jcp.inner_src_stride = strTranspose1In0[1]; + jcp.outter_src_stride = strTranspose1In0[3]; + jcp.outter_dst_stride = N0; + + if (mayiuse(cpu_isa_t::avx512_core)) { + convertTransposeKernel.reset(new jit_convert_transpose_kernel(jcp)); + } else if (mayiuse(cpu_isa_t::avx2)) { + convertTransposeKernel.reset(new jit_convert_transpose_kernel(jcp)); + } else if (mayiuse(cpu_isa_t::sse41)) { + convertTransposeKernel.reset(new jit_convert_transpose_kernel(jcp)); + } else { + THROW_ERROR << "cannot create jit eltwise kernel"; + } + } + + if (mulAddSoftmaxKernel) + mulAddSoftmaxKernel->create_ker(); + + if (convertReorderKernel) + convertReorderKernel->create_ker(); + + if (convertTransposeKernel) + convertTransposeKernel->create_ker(); + + const auto& selectedPD = getSelectedPrimitiveDescriptor(); + if (brgemmCtx0.is_with_amx || brgemmCtx1.is_with_amx) { + selectedPD->setImplementationType(jit_avx512_amx); + } else { + if (mayiuse(cpu_isa_t::avx512_core)) { + selectedPD->setImplementationType(jit_avx512); + } else if (mayiuse(cpu_isa_t::avx2)) { + selectedPD->setImplementationType(jit_avx2); + } else if (mayiuse(cpu_isa_t::sse41)) { + selectedPD->setImplementationType(jit_sse42); + } + } +} + +template +static void reorder2D(const srcT* pin, dstT* pout, const std::vector& dimsOut, + const std::vector& stridesOut, const std::vector& stridesIn) { + for (int i0 = 0; i0 < dimsOut[0]; i0++) { + for (int i1 = 0; i1 < dimsOut[1]; i1++) { + pout[i0 * stridesOut[0] + i1 * stridesOut[1]] = static_cast(pin[i0 * stridesIn[0] + i1 * stridesIn[1]]); + } + } +} + +void MHA::callBrgemm(brgemmCtx& ctx, std::unique_ptr& brgKernel, const void* pin0, const void* pin1, void* pout, void* wsp) { + if (ctx.is_with_amx) + amx_tile_configure(ctx.palette); + if (ctx.is_with_comp) { + brgemm_post_ops_data_t post_ops_data; + brgemm_kernel_execute_postops(brgKernel.get(), 1, pin0, pin1, nullptr, pout, pout, post_ops_data, wsp); + } else { + brgemm_kernel_execute(brgKernel.get(), 1, pin0, pin1, nullptr, pout, wsp); + } +} + +template +void MHA::mhaImpl() { + const uint8_t* pTranspose0In0 = reinterpret_cast(getParentEdgeAt(0)->getMemoryPtr()->GetPtr()); + const uint8_t* pTranspose1In0 = reinterpret_cast(getParentEdgeAt(1)->getMemoryPtr()->GetPtr()); + const float* pAddIn1 = reinterpret_cast(getParentEdgeAt(2)->getMemoryPtr()->GetPtr()); + const uint8_t* pTranspose2In0 = reinterpret_cast(getParentEdgeAt(3)->getMemoryPtr()->GetPtr()); + uint8_t* pout = reinterpret_cast(getChildEdgeAt(0)->getMemoryPtr()->GetPtr()); + + auto outPrcSize = getOriginalOutputPrecisionAtPort(0).size(); + + parallel_for2d(dimsMatMul0Out[0], dimsMatMul0Out[1], [&](size_t i0, size_t i1) { + size_t threadNum = parallel_get_thread_num(); + + auto pTranspose0In0_aux = pTranspose0In0 + (i0 * strTranspose0In0[0] + i1 * strTranspose0In0[2]) * inputPrecisions[0].size(); // order 0213 + auto pTranspose1In0_aux = pTranspose1In0 + (i0 * strTranspose1In0[0] + i1 * strTranspose1In0[2]) * inputPrecisions[1].size(); // order 0231 + + auto pAddIn1_aux = pAddIn1 + i0 * strAddIn1[0]; // order 0231 + + auto bufferMatMul0In1_local = reinterpret_cast(bufferMatMul0In1.data() + threadNum * bufferMatMul0In1Size); + auto bufferMatMul0Out_local = reinterpret_cast(bufferMatMul0Out.data() + threadNum * bufferMatMul0OutSize); + 
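+ // Each worker thread operates on its own slice of the shared scratch vectors: the
+ // buffers were sized for parallel_get_max_threads() threads in prepareParams() and are
+ // indexed here by threadNum times the per-thread buffer size, so iterations of the
+ // parallel_for2d body never share intermediate storage.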
auto bufferMatMul1In1_local = reinterpret_cast(bufferMatMul1In1.data() + threadNum * bufferMatMul1In1Size); + auto bufferMatMul1Out_local = reinterpret_cast(bufferMatMul1Out.data() + threadNum * bufferMatMul1OutSize); + + auto pTranspose1Out_aux = brgCopyBKernel0 ? bufferMatMul1In1_local + : bufferMatMul0In1_local; + auto pTranspose2In0_aux = pTranspose2In0 + (i0 * strTranspose2In0[0] + i1 * strTranspose2In0[2]) * inputPrecisions[3].size(); // order 0213 + + if (convertTransposeKernel) { + jit_convert_transpose_call_args call_args; + call_args.p_in = pTranspose1In0_aux; + call_args.p_out = pTranspose1Out_aux; + call_args.p_scales = fqScales0.data(); + + (*convertTransposeKernel)(&call_args); + } else { + reorder2D(reinterpret_cast(pTranspose1In0_aux), reinterpret_cast(pTranspose1Out_aux), {K0, N0}, {N0, 1}, + {strTranspose1In0[3], strTranspose1In0[1]}); + } + + auto bufferCompensation0_aux = !bufferCompensation0.empty() + ? bufferCompensation0.data() + threadNum * bufferCompensation0Size + : nullptr; + auto bufferCompensation1_aux = !bufferCompensation1.empty() + ? bufferCompensation1.data() + threadNum * bufferCompensation1Size + : nullptr; + + auto wsp_local = !wsp.empty() ? wsp.data() + threadNum * wsp_size_per_thread : nullptr; + + auto pMatMul0In1 = reinterpret_cast(pTranspose1Out_aux); + if (brgCopyBKernel0) { + for (size_t nb = 0; nb < div_up(N0, N0_blk); nb++) { + auto pCopyKernel0In = pMatMul0In1 + nb * N0_blk * inputPrecisions[0].size(); + auto pCopyKernel0Out = bufferMatMul0In1_local + nb * N0_blk * brg0VnniFactor * inputPrecisions[0].size(); + + auto ctx = jit_brgemm_matmul_copy_b_t::ctx_t(); + + const bool is_N_tail = (N0 - nb * N0_blk < N0_blk); + ctx.current_N_blk = is_N_tail ? N0_tail : N0_blk; + ctx.src = pCopyKernel0In; + ctx.tr_src = pCopyKernel0Out; + ctx.compensation_ptr = bufferCompensation0_aux + nb * N0_blk; + ctx.zp_a_compensation_ptr = nullptr; + ctx.zp_a_neg_value_ptr = nullptr; + ctx.current_K_start = 0; + ctx.current_K_iters = K0; + + (*brgCopyBKernel0)(&ctx); + } + + pMatMul0In1 = bufferMatMul0In1_local; + } + + auto pMatMul1In1 = pTranspose2In0_aux; + if (brgCopyBKernel1) { + for (size_t nb = 0; nb < div_up(N1, N1_blk); nb++) { + auto pCopyKernel1In = pMatMul1In1 + nb * N1_blk * inputPrecisions[3].size(); + auto pCopyKernel1Out = reinterpret_cast(bufferMatMul1In1_local) + nb * N1_blk * brg1VnniFactor * inputPrecisions[3].size(); + + auto ctx = jit_brgemm_matmul_copy_b_t::ctx_t(); + + const bool is_N_tail = (N1 - nb * N1_blk < N1_blk); + ctx.current_N_blk = is_N_tail ? N1_tail : N1_blk; + ctx.src = pCopyKernel1In; + ctx.tr_src = pCopyKernel1Out; + ctx.compensation_ptr = bufferCompensation1_aux + nb * N1_blk; + ctx.zp_a_compensation_ptr = nullptr; + ctx.zp_a_neg_value_ptr = nullptr; + ctx.current_K_start = 0; + ctx.current_K_iters = K1; + + (*brgCopyBKernel1)(&ctx); + } + + pMatMul1In1 = reinterpret_cast(bufferMatMul1In1_local); + } + + for (size_t mb = 0; mb < div_up(M, M_blk); mb++) { + const bool is_M_tail = (M - mb * M_blk < M_blk); + auto cur_M_blk = is_M_tail ? 
M_tail : M_blk; + + auto pMatMul0In0 = pTranspose0In0_aux + (mb * M_blk * batch1 * K0) * inputPrecisions[0].size(); + + // TODO: matrix A copy should be performed to enable AMX matmuls for arbitrary shapes + // if (brgCopyAKernel0) { + // auto bufferMatMul0In0_local = reinterpret_cast(bufferMatMul0In0.data() + threadNum * bufferMatMul0In0Size); + + // auto pCopyKernel0In = pMatMul0In0; + // auto pCopyKernel0Out = reinterpret_cast(bufferMatMul0In0_local); + + // auto ctx = jit_brgemm_matmul_copy_a_t::ctx_t(); + + // ctx.current_M_blk = cur_M_blk; + // ctx.zp_b_compensation_buffer_ptr = nullptr; + // ctx.zp_a_compensation_result_ptr = nullptr; + // ctx.zp_b_neg_value_ptr = nullptr; + // ctx.zp_ab_comp_ptr = nullptr; + // ctx.src = pCopyKernel0In; + // ctx.tr_src = pCopyKernel0Out; + // ctx.current_K_start = 0; + // ctx.current_K_blk = K0; + + // (*brgCopyAKernel0)(&ctx); + + // pMatMul0In0 = reinterpret_cast(bufferMatMul0In0_local); + // } + + auto pMatMul0Out = bufferMatMul0Out_local; + + size_t brgIdx0 = getBrgIdx(0, 0, 0); + size_t K0_step0 = brgCtxs0[brgIdx0].K; + size_t K0_step1 = brgCtxs0[brgIdx0].K * brgCtxs0[brgIdx0].LDB; + size_t N0_step0 = brgCtxs0[brgIdx0].N * brg0VnniFactor; + size_t N0_step1 = brgCtxs0[brgIdx0].N; + for (size_t n = 0; n < 2; n++) { + for (size_t k = 0; k < 2; k++) { + size_t mIdx = is_M_tail ? 1 : 0; + auto& brgemmCtx = brgCtxs0[getBrgIdx(mIdx, k, n)]; + + auto wsp = brgemmCtx.is_with_comp + ? reinterpret_cast(bufferCompensation0_aux + n * N0_step1) + : reinterpret_cast(wsp_local); + + if (brgemmCtx.K != 0 && brgemmCtx.N != 0) { + callBrgemm(brgemmCtx, brgKernels0[getBrgIdx(mIdx, k, n)], + pMatMul0In0 + (k * K0_step0) * inputPrecisions[0].size(), pMatMul0In1 + (k * K0_step1 + n * N0_step0) * inputPrecisions[0].size(), + pMatMul0Out + (n * N0_step1) * accPrecision0.size(), wsp); + } + } + } + + auto pMulIn1 = reinterpret_cast(mulScales.empty() ? nullptr : mulScales.data()); + for (size_t m = 0; m < cur_M_blk; m++) { + jit_mul_add_softmax_call_args call_args; + call_args.p_in0 = pMatMul0Out + m * N0 * accPrecision0.size(); + call_args.p_mul_in1 = mulScales.size() > 1 ? pMulIn1 + i1 : pMulIn1; + call_args.p_add_in1 = pAddIn1_aux; + call_args.p_out = pMatMul0Out + m * N0 * inputPrecisions[3].size(); + call_args.p_buffer = pMatMul0Out + m * N0 * accPrecision0.size(); + call_args.p_scales0 = fqScales1.data(); + call_args.p_scales1 = fqScales2.data(); + + (*mulAddSoftmaxKernel)(&call_args); + } + + auto pMatMul1In0 = bufferMatMul0Out_local; + auto pOut_aux = pout + (i0 * strOut[0] + i1 * strOut[2]) * outPrcSize; + + auto pMatMul1Out = getOriginalOutputPrecisionAtPort(0) == Precision::FP32 + ? pOut_aux + (mb * M_blk * batch1 * N1) * outPrcSize + : bufferMatMul1Out_local; + + size_t brgIdx1 = getBrgIdx(0, 0, 0); + size_t K1_step0 = brgCtxs1[brgIdx1].K; + size_t K1_step1 = brgCtxs1[brgIdx1].K * brgCtxs1[brgIdx1].LDB; + size_t N1_step0 = brgCtxs1[brgIdx1].N * brg1VnniFactor; + size_t N1_step1 = brgCtxs1[brgIdx1].N; + for (size_t n = 0; n < 2; n++) { + for (size_t k = 0; k < 2; k++) { + size_t mIdx = is_M_tail ? 1 : 0; + auto& brgemmCtx = brgCtxs1[getBrgIdx(mIdx, k, n)]; + + auto wsp = brgemmCtx.is_with_comp + ? 
reinterpret_cast(bufferCompensation1_aux + n * N1_step1) + : reinterpret_cast(wsp_local); + + if (brgemmCtx.K != 0 && brgemmCtx.N != 0) { + callBrgemm(brgemmCtx, brgKernels1[getBrgIdx(mIdx, k, n)], + pMatMul1In0 + (k * K1_step0) * inputPrecisions[3].size(), pMatMul1In1 + (k * K1_step1 + n * N1_step0) * inputPrecisions[3].size(), + pMatMul1Out + (n * N1_step1) * accPrecision1.size(), wsp); + } + } + } + + if (convertReorderKernel) { + jit_convert_reorder_call_args call_args; + call_args.p_in = pMatMul1Out; + call_args.p_out = pOut_aux + (mb * M_blk * batch1 * N1) * outPrcSize; + call_args.p_scales = fqScales3.data(); + call_args.outter_work_amount = cur_M_blk; + + (*convertReorderKernel)(&call_args); + } + } + }); +} + +void MHA::execute(dnnl::stream strm) { + if (inputPrecisions[1] == Precision::FP32) { + mhaImpl(); + } else if (inputPrecisions[1] == Precision::BF16) { + mhaImpl(); + } else if (inputPrecisions[1] == Precision::I8) { + mhaImpl(); + } else { + THROW_ERROR << "doesn't support provided input precisions"; + } +} + +void MHA::executeDynamicImpl(dnnl::stream strm) { + execute(strm); +} + +bool MHA::created() const { + return getType() == Type::MHA; +} + +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/mha.h b/src/plugins/intel_cpu/src/nodes/mha.h new file mode 100644 index 00000000000..b442728f777 --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/mha.h @@ -0,0 +1,243 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace ov { +namespace intel_cpu { +namespace node { + +struct jit_mul_add_softmax_compile_params { + InferenceEngine::Precision src_prc; + InferenceEngine::Precision dst_prc; + size_t work_amount; + bool with_mul_scales; + bool is_mul_first; + bool with_scales0; + bool broadcast_scales0; + bool with_scales1; + bool broadcast_scales1; +}; + +struct jit_mul_add_softmax_call_args { + const void *p_in0; + const void *p_mul_in1; + const void *p_add_in1; + void *p_out; + void *p_buffer; + const void *p_scales0; + const void *p_scales1; +}; + +struct jit_uni_mul_add_softmax_kernel { + void (*ker_)(const jit_mul_add_softmax_call_args*); + + void operator()(const jit_mul_add_softmax_call_args* call_args) { + assert(ker_); + ker_(call_args); + } + + explicit jit_uni_mul_add_softmax_kernel(const jit_mul_add_softmax_compile_params& jcp) : ker_(nullptr), jcp_(jcp) {} + virtual ~jit_uni_mul_add_softmax_kernel() {} + + virtual void create_ker() = 0; + + jit_mul_add_softmax_compile_params jcp_; +}; + +struct jit_convert_reorder_compile_params { + InferenceEngine::Precision src_prc; + InferenceEngine::Precision dst_prc; + size_t inner_work_amount; + bool with_scales; + bool broadcast_scales; + size_t src_stride; + size_t dst_stride; +}; + +struct jit_convert_reorder_call_args { + const void *p_in; + void *p_out; + const void *p_scales; + size_t outter_work_amount; +}; + +struct jit_uni_convert_reorder_kernel { + void (*ker_)(const jit_convert_reorder_call_args*); + + void operator()(const jit_convert_reorder_call_args* call_args) { + assert(ker_); + ker_(call_args); + } + + explicit jit_uni_convert_reorder_kernel(const jit_convert_reorder_compile_params& jcp) : ker_(nullptr), jcp_(jcp) {} + virtual ~jit_uni_convert_reorder_kernel() {} + + virtual void create_ker() = 0; + + jit_convert_reorder_compile_params jcp_; +}; + +struct jit_convert_transpose_compile_params { + 
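+ // Compile-time parameters of the convert+transpose helper kernel: it walks the strided
+ // second input (inner_work_amount x outter_work_amount elements), optionally applies the
+ // fused FakeQuantize scales, converts src_prc to dst_prc and stores a dense block that
+ // feeds the first brgemm (see MHA::prepareParams() and MHA::mhaImpl() above).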
InferenceEngine::Precision src_prc; + InferenceEngine::Precision dst_prc; + size_t inner_work_amount; + size_t outter_work_amount; + bool with_scales; + bool broadcast_scales; + size_t inner_src_stride; + size_t outter_src_stride; + size_t outter_dst_stride; +}; + +struct jit_convert_transpose_call_args { + const void *p_in; + void *p_out; + const void *p_scales; +}; + +struct jit_uni_convert_transpose_kernel { + void (*ker_)(const jit_convert_transpose_call_args*); + + void operator()(const jit_convert_transpose_call_args* call_args) { + assert(ker_); + ker_(call_args); + } + + explicit jit_uni_convert_transpose_kernel(const jit_convert_transpose_compile_params& jcp) : ker_(nullptr), jcp_(jcp) {} + virtual ~jit_uni_convert_transpose_kernel() {} + + virtual void create_ker() = 0; + + jit_convert_transpose_compile_params jcp_; +}; + +#define MHA_BRGEMM_KERNELS_NUM 8 + +class MHA : public Node { +public: + MHA(const std::shared_ptr& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache); + + void getSupportedDescriptors() override {}; + void initSupportedPrimitiveDescriptors() override; + void execute(dnnl::stream strm) override; + bool created() const override; + + static bool isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept; + +protected: + void executeDynamicImpl(dnnl::stream strm) override; + void prepareParams() override; + +private: + struct brgemmCtx { + size_t M, N, K, LDA, LDB, LDC; + dnnl_data_type_t dt_in0, dt_in1; + char palette[64]; + bool is_with_amx; + bool is_with_comp; + float beta; + }; + + template + void mhaImpl(); + + void init_brgemm(brgemmCtx& ctx, std::unique_ptr& brgKernel, bool use_amx); + void init_brgemm_copy_a(std::unique_ptr& brgCopyKernel, + size_t K, size_t K_blk, size_t K_tail, size_t LDA, dnnl_data_type_t dt_in0); + void init_brgemm_copy_b(std::unique_ptr& brgCopyKernel, + size_t N, size_t N_blk, size_t N_tail, size_t LDB, size_t K, bool is_with_amx, dnnl_data_type_t dt_in0, dnnl_data_type_t dt_in1); + + void callBrgemm(brgemmCtx& ctx, std::unique_ptr& brgKernel, + const void* pin0, const void* pin1, void* pout, void* wsp); + + size_t getBrgIdx(size_t mIdx, size_t kIdx, size_t nIdx) { + return mIdx * 4 + kIdx * 2 + nIdx; + } + + std::vector inputPrecisions; + InferenceEngine::Precision accPrecision0; + InferenceEngine::Precision accPrecision1; + + VectorDims dimsTranspose0In0; + VectorDims dimsTranspose1In0; + VectorDims dimsMulIn1; + VectorDims dimsAddIn1; + VectorDims dimsTranspose2In0; + VectorDims dimsOut; + + VectorDims strTranspose0In0; + VectorDims strTranspose1In0; + VectorDims strMulIn1; + VectorDims strAddIn1; + VectorDims strTranspose2In0; + VectorDims strOut; + + VectorDims dimsMatMul0In0; + VectorDims dimsMatMul0In1; + VectorDims dimsMatMul0Out; + VectorDims dimsMatMul1In1; + VectorDims dimsMatMul1Out; + + size_t batch0, batch1; + size_t M, M_blk, M_tail; + size_t K0, K0_blk, K0_tail, N0, N0_blk, N0_tail; + size_t K1, K1_blk, K1_tail, N1, N1_blk, N1_tail; + + size_t bufferMatMul0In0Size; + size_t bufferMatMul0In1Size; + size_t bufferMatMul0OutSize; + size_t bufferMatMul1In1Size; + size_t bufferMatMul1OutSize; + size_t bufferCompensation0Size; + size_t bufferCompensation1Size; + size_t wsp_size_per_thread = 4 * 1024; + + std::vector bufferMatMul0In0; + std::vector bufferMatMul0In1; + std::vector bufferMatMul0Out; + std::vector bufferMatMul1In1; + std::vector bufferMatMul1Out; + std::vector bufferCompensation0; + std::vector bufferCompensation1; + std::vector wsp; + + bool isMulFirst; + 
InferenceEngine::Precision fqPrc2;
+
+ std::vector mulScales;
+ std::vector fqScales0;
+ std::vector fqScales1;
+ std::vector fqScales2;
+ std::vector fqScales3;
+
+ size_t brg0VnniFactor;
+ brgemmCtx brgCtxs0[MHA_BRGEMM_KERNELS_NUM];
+ std::unique_ptr brgKernels0[MHA_BRGEMM_KERNELS_NUM];
+ std::unique_ptr brgCopyAKernel0;
+ std::unique_ptr brgCopyBKernel0;
+
+ size_t brg1VnniFactor;
+ brgemmCtx brgCtxs1[MHA_BRGEMM_KERNELS_NUM];
+ std::unique_ptr brgKernels1[MHA_BRGEMM_KERNELS_NUM];
+ std::unique_ptr brgCopyBKernel1;
+
+ std::unique_ptr mulAddSoftmaxKernel;
+ std::unique_ptr convertReorderKernel;
+ std::unique_ptr convertTransposeKernel;
+};
+
+} // namespace node
+} // namespace intel_cpu
+} // namespace ov
diff --git a/src/plugins/intel_cpu/src/nodes_factory.cpp b/src/plugins/intel_cpu/src/nodes_factory.cpp
index 2f41d3b935d..2f498cc4166 100644
--- a/src/plugins/intel_cpu/src/nodes_factory.cpp
+++ b/src/plugins/intel_cpu/src/nodes_factory.cpp
@@ -88,6 +88,7 @@
#include "nodes/priorbox.h"
#include "nodes/priorbox_clustered.h"
#include "nodes/eye.h"
+#include "nodes/mha.h"
namespace ov {
namespace intel_cpu {
@@ -188,6 +189,7 @@ Node::NodesFactory::NodesFactory()
INTEL_CPU_NODE(PriorBox, Type::PriorBox);
INTEL_CPU_NODE(PriorBoxClustered, Type::PriorBoxClustered);
INTEL_CPU_NODE(Eye, Type::Eye);
+ INTEL_CPU_NODE(MHA, Type::MHA);
}
#undef INTEL_CPU_NODE
diff --git a/src/plugins/intel_cpu/src/plugin.cpp b/src/plugins/intel_cpu/src/plugin.cpp
index 993a3ccc552..b548b693bf7 100644
--- a/src/plugins/intel_cpu/src/plugin.cpp
+++ b/src/plugins/intel_cpu/src/plugin.cpp
@@ -87,6 +87,7 @@
#include
#include
#include "transformations/op_conversions/eye_decomposition.hpp"
+#include "ngraph_transformations/mha_fusion.hpp"
#include
#include
@@ -118,6 +119,7 @@
#include "nodes/mvn.h"
#include "nodes/fake_quantize.h"
#include "nodes/normalize.h"
+#include "nodes/mha.h"
#include "ngraph_transformations/convert_to_cpu_specific_opset.hpp"
#include "ngraph_transformations/move_eltwise_up_data_movement.hpp"
#include "transformations/smart_reshape/smart_reshape.hpp"
@@ -249,7 +251,7 @@ Engine::~Engine() {
executorManager()->clear("CPUCallbackExecutor");
}
-static void TransformationUpToCPUSpecificOpSet(std::shared_ptr nGraphFunc, const bool _enableLPT,
+static void TransformationUpToCPUSpecificOpSet(std::shared_ptr nGraphFunc, const bool _enableLPT, const bool _enableBF16,
const bool _enableSnippets, const bool isLegacyApi) {
ngraph::pass::Manager manager;
manager.set_per_pass_validation(false);
@@ -600,6 +602,27 @@ static void TransformationUpToCPUSpecificOpSet(std::shared_ptr
});
postLPTPassManager.register_pass();
+
+ // Snippets may break MHA patterns, so the fusion has to be performed first
+ postLPTPassManager.register_pass();
+ postLPTPassManager.get_pass_config()->set_callback([_enableBF16](const std::shared_ptr& n) -> bool {
+ std::string errorMessage;
+
+ if (!node::MHA::isSupportedOperation(n, errorMessage))
+ return true;
+
+ // The implementation calls the AMX BF16 brgemm only for tensors with K and N aligned on 2; otherwise it falls back to the vector implementation
+ // The vector madd BF16 instruction on SPR has reduced performance at the HW level, which results in overall perf degradation
+ size_t bf16Factor = 2;
+ if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16_amx_bf16) &&
+ (n->get_input_element_type(0) == element::bf16 || (n->get_input_element_type(0) == element::f32 && _enableBF16)) &&
+ (n->get_input_shape(0)[3] % bf16Factor != 0 || n->get_input_shape(1)[1] % bf16Factor != 0 || n->get_input_shape(3)[3]
% bf16Factor != 0)) { + return true; + } + + return false; + }); postLPTPassManager.run_passes(nGraphFunc); if (!useLpt && _enableSnippets && dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2)) { @@ -631,9 +654,9 @@ static void TransformationUpToCPUSpecificOpSet(std::shared_ptr } } -static void Transformation(CNNNetwork& clonedNetwork, const bool _enableLPT, const bool _enableSnippets, const bool isLegacyApi) { +static void Transformation(CNNNetwork& clonedNetwork, const bool _enableLPT, const bool _enableBF16, const bool _enableSnippets, const bool isLegacyApi) { auto nGraphFunc = clonedNetwork.getFunction(); - TransformationUpToCPUSpecificOpSet(nGraphFunc, _enableLPT, _enableSnippets, isLegacyApi); + TransformationUpToCPUSpecificOpSet(nGraphFunc, _enableLPT, _enableBF16, _enableSnippets, isLegacyApi); ConvertToCPUSpecificOpset(nGraphFunc); } @@ -774,7 +797,7 @@ Engine::LoadExeNetworkImpl(const InferenceEngine::CNNNetwork &network, const std || engConfig.enableDynamicBatch; const bool enableSnippets = !(enableModelCache || enableDynamicBatch || enableBF16); auto nGraphFunc = clonedNetwork.getFunction(); - TransformationUpToCPUSpecificOpSet(nGraphFunc, enableLPT, enableSnippets, isLegacyAPI()); + TransformationUpToCPUSpecificOpSet(nGraphFunc, enableLPT, enableBF16, enableSnippets, isLegacyAPI()); // need to check that all outputs have static shapes // checking that all inputs have static shapes is performed in the common part @@ -1022,7 +1045,7 @@ QueryNetworkResult Engine::QueryNetwork(const CNNNetwork& network, const std::ma || Config::LPTransformsMode::On == engConfig.lpTransformsMode /* or already enabled */; const bool enableSnippets = !(conf.cache_dir.empty() || conf.enableDynamicBatch || (conf.enforceBF16 && dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core))); - Transformation(clonedNetwork, enableLPT, enableSnippets, isLegacyAPI()); + Transformation(clonedNetwork, enableLPT, conf.enforceBF16, enableSnippets, isLegacyAPI()); auto ops = clonnedFunction->get_ordered_ops(); //Mark removed nodes as supported diff --git a/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/mha.cpp b/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/mha.cpp new file mode 100644 index 00000000000..ba151dcaef7 --- /dev/null +++ b/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/mha.cpp @@ -0,0 +1,549 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include +#include +#include +#include "common_test_utils/common_utils.hpp" +#include +#include "functional_test_utils/skip_tests_config.hpp" +#include "test_utils/cpu_test_utils.hpp" + +using namespace CPUTestUtils; +using namespace ov::test; +using namespace ngraph::helpers; + +namespace CPUSubgraphTestsDefinitions { + +typedef std::tuple< + std::vector, // Input shapes + std::vector, // Input precisions + std::vector, // MatMul input #0 precisions + size_t, // pattern type # + std::string // Device name +> MHATuple; + +static std::shared_ptr initMHASubgraph0(std::vector& inputDynamicShapes, std::vector& inputPrecisions) { + ngraph::ParameterVector ngraphParam; + + auto transpose0Param = std::make_shared(inputPrecisions[0], inputDynamicShapes[0]); + ngraphParam.push_back(transpose0Param); + + auto transpose1Param = std::make_shared(inputPrecisions[1], inputDynamicShapes[1]); + ngraphParam.push_back(transpose1Param); + + auto addParam = std::make_shared(inputPrecisions[2], inputDynamicShapes[2]); + ngraphParam.push_back(addParam); + 
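+ // The parameters above feed the reference MHA pattern assembled below: two Transposes
+ // (the second followed by a Mul with a constant) -> MatMul -> Add -> Reshape -> Softmax ->
+ // Reshape -> MatMul with a transposed fourth input -> Transpose. The plugin is expected
+ // to fuse this whole chain into a single MHA node (checked via CheckNumberOfNodesWithType
+ // in the test body).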
+ auto transpose2Param = std::make_shared(inputPrecisions[3], inputDynamicShapes[3]); + ngraphParam.push_back(transpose2Param); + + std::vector constantShapes; + constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()})); + constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()})); + constantShapes.push_back(ov::Shape({1, inputDynamicShapes[0].get_shape()[2], 1, 1})); + constantShapes.push_back(ov::Shape({2})); + constantShapes.push_back(ov::Shape({4})); + constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()})); + constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()})); + + std::vector transpose0ConstData = {0, 2, 1, 3}; + auto transpose0Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[0], transpose0ConstData); + + std::vector transpose1ConstData = {0, 2, 3, 1}; + auto transpose1Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[1], transpose1ConstData); + + std::vector mulConstData(ngraph::shape_size(constantShapes[2])); + auto mulConst = ngraph::builder::makeConstant(inputPrecisions[0], constantShapes[2], mulConstData, true); + + std::vector reshape0ConstData = {static_cast(inputDynamicShapes[0].get_shape()[0] * + inputDynamicShapes[0].get_shape()[1] * inputDynamicShapes[0].get_shape()[2]), + -1}; + auto reshape0Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[3], reshape0ConstData); + + std::vector reshape1ConstData = {static_cast(inputDynamicShapes[0].get_shape()[0]), + static_cast(inputDynamicShapes[0].get_shape()[2]), + static_cast(inputDynamicShapes[0].get_shape()[1]), + static_cast(inputDynamicShapes[0].get_shape()[1])}; + auto reshape1Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[4], reshape1ConstData); + + std::vector transpose2ConstData = {0, 2, 1, 3}; + auto transpose2Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[5], transpose2ConstData); + + std::vector transpose3ConstData = {0, 2, 1, 3}; + auto transpose3Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[6], transpose3ConstData); + + float transA = false; + float transB = false; + const auto transpose0 = std::make_shared(transpose0Param, transpose0Const); + const auto transpose1 = std::make_shared(transpose1Param, transpose1Const); + const auto mul = std::make_shared(transpose1, mulConst); + const auto matMul0 = std::make_shared(transpose0, mul, transA, transB); + const auto add = std::make_shared(matMul0, addParam); + const auto reshape0 = std::make_shared(add, reshape0Const, true); + const auto softMax = std::make_shared(reshape0, 1); + const auto reshape1 = std::make_shared(softMax, reshape1Const, true); + const auto transpose2 = std::make_shared(transpose2Param, transpose2Const); + const auto matMul1 = std::make_shared(reshape1, transpose2, transA, transB); + const auto transpose3 = std::make_shared(matMul1, transpose3Const); + + ngraph::ResultVector results{std::make_shared(transpose3)}; + return std::make_shared(results, ngraphParam, "mha"); +} + +static std::shared_ptr initMHASubgraph1(std::vector& inputDynamicShapes, std::vector& inputPrecisions) { + ngraph::ParameterVector ngraphParam; + + auto transpose0Param = std::make_shared(inputPrecisions[0], inputDynamicShapes[0]); + ngraphParam.push_back(transpose0Param); + + auto transpose1Param = std::make_shared(inputPrecisions[1], inputDynamicShapes[1]); + ngraphParam.push_back(transpose1Param); + + auto addParam = std::make_shared(inputPrecisions[2], 
inputDynamicShapes[2]); + ngraphParam.push_back(addParam); + + auto transpose2Param = std::make_shared(inputPrecisions[3], inputDynamicShapes[3]); + ngraphParam.push_back(transpose2Param); + + std::vector constantShapes; + constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()})); + constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()})); + constantShapes.push_back(ov::Shape({1, inputDynamicShapes[0].get_shape()[2], 1, 1})); + constantShapes.push_back(ov::Shape({2})); + constantShapes.push_back(ov::Shape({4})); + constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()})); + constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()})); + + std::vector transpose0ConstData = {0, 2, 1, 3}; + auto transpose0Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[0], transpose0ConstData); + + std::vector transpose1ConstData = {0, 2, 3, 1}; + auto transpose1Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[1], transpose1ConstData); + + std::vector transpose2ConstData = {0, 2, 1, 3}; + auto transpose2Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[0], transpose2ConstData); + + std::vector transpose3ConstData = {0, 2, 1, 3}; + auto transpose3Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[1], transpose3ConstData); + + float transA = false; + float transB = false; + const auto transpose0 = std::make_shared(transpose0Param, transpose0Const); + const auto transpose1 = std::make_shared(transpose1Param, transpose1Const); + const auto matMul0 = std::make_shared(transpose0, transpose1, transA, transB); + const auto add = std::make_shared(matMul0, addParam); + const auto softMax = std::make_shared(add, 3); + const auto transpose2 = std::make_shared(transpose2Param, transpose2Const); + const auto matMul1 = std::make_shared(softMax, transpose2, transA, transB); + const auto transpose3 = std::make_shared(matMul1, transpose3Const); + + ngraph::ResultVector results{std::make_shared(transpose3)}; + return std::make_shared(results, ngraphParam, "mha"); +} + +class MHATest : public testing::WithParamInterface, + virtual public SubgraphBaseTest, public CPUTestsBase { +public: + static std::string getTestCaseName(const testing::TestParamInfo &obj) { + std::vector inputShapes; + std::vector inputPrecisions; + std::vector matMulIn0Precisions; + size_t patternType; + std::string targetName; + std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, targetName) = obj.param; + std::ostringstream results; + + results << "IS=("; + for (const auto& shape : inputShapes) { + results << CommonTestUtils::partialShape2str({shape.first}) << "_"; + } + results << ")_TS=("; + for (const auto& shape : inputShapes) { + for (const auto& item : shape.second) { + results << CommonTestUtils::vec2str(item) << "_"; + } + } + for (int i = 0; i < inputPrecisions.size(); i++) { + results << "InPRC" << std::to_string(i) << "=" << inputPrecisions[i] << "_"; + } + results << "patternType=" << patternType; + results << "targetDevice=" << targetName; + + return results.str(); + } + + void generate_inputs(const std::vector& targetInputStaticShapes) override { + inputs.clear(); + const auto& funcInputs = function->inputs(); + for (int i = 0; i < funcInputs.size(); ++i) { + const auto& funcInput = funcInputs[i]; + ov::Tensor tensor; + tensor = ov::test::utils::create_and_fill_tensor_normal_distribution(funcInput.get_element_type(), targetInputStaticShapes[i], 1.0f, 
0.5f); + inputs.insert({funcInput.get_node_shared_ptr(), tensor}); + } + } + +protected: + void SetUp() override { + std::vector inputShapes; + std::vector inputPrecisions; + std::vector matMulIn0Precisions; + size_t patternType; + std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, targetDevice) = this->GetParam(); + + init_input_shapes(inputShapes); + + if (patternType == 0) { + function = initMHASubgraph0(inputDynamicShapes, inputPrecisions); + } else if (patternType == 1) { + function = initMHASubgraph1(inputDynamicShapes, inputPrecisions); + } else { + FAIL() << "Unsupported MHA pattern type"; + } + + // TODO: try better input data initialization to avoid threshold adjustment + // TODO: support different precisions on inputs + if (inputPrecisions[0] == ElementType::bf16) { + abs_threshold = 0.1f; + rel_threshold = 10.f; + + configuration.insert({{ InferenceEngine::PluginConfigParams::KEY_ENFORCE_BF16, InferenceEngine::PluginConfigParams::YES }}); + } + } +}; + +TEST_P(MHATest, CompareWithRefs) { + SKIP_IF_CURRENT_TEST_IS_DISABLED() + + std::vector inputShapes; + std::vector inputPrecisions; + std::vector matMulIn0Precisions; + size_t patternType; + std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, targetDevice) = this->GetParam(); + + if (inputPrecisions[0] == ElementType::bf16 && !InferenceEngine::with_cpu_x86_bfloat16()) + GTEST_SKIP(); + + if (!InferenceEngine::with_cpu_x86_avx512_core()) + GTEST_SKIP(); + + run(); + CheckNumberOfNodesWithType(compiledModel, "MHA", 1); +} + +namespace { + +std::vector> inputShapes = { + {{2, 8, 16, 64}, {2, 8, 16, 64}, {2, 1, 1, 8}, {2, 8, 16, 64}}, + {{1, 384, 16, 64}, {1, 384, 16, 64}, {1, 1, 1, 384}, {1, 384, 16, 64}}, + {{2, 64, 16, 80}, {2, 64, 16, 80}, {2, 1, 1, 64}, {2, 64, 16, 80}}, + {{3, 96, 16, 64}, {3, 96, 16, 64}, {3, 1, 1, 96}, {3, 96, 16, 64}}, + {{2, 192, 16, 160}, {2, 192, 16, 160}, {2, 1, 1, 192}, {2, 192, 16, 160}}, + {{2, 4, 16, 8}, {2, 4, 16, 8}, {2, 1, 1, 4}, {2, 4, 16, 8}}, + {{1, 204, 13, 212}, {1, 204, 13, 212}, {1, 1, 1, 204}, {1, 204, 13, 212}}, +}; + +std::vector> inputPrecisions = { + { ElementType::f32, ElementType::f32, ElementType::f32, ElementType::f32 }, + { ElementType::bf16, ElementType::bf16, ElementType::bf16, ElementType::bf16 }, +}; + +std::vector> matMulIn0Precisions = { + {}, +}; + +std::vector patternTypes = { + 0, 1 +}; + +INSTANTIATE_TEST_SUITE_P(smoke_MHA, MHATest, + ::testing::Combine( + ::testing::ValuesIn(static_shapes_to_test_representation(inputShapes)), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(matMulIn0Precisions), + ::testing::ValuesIn(patternTypes), + ::testing::Values(CommonTestUtils::DEVICE_CPU)), + MHATest::getTestCaseName); + +} // namespace + +static std::shared_ptr initMHAQuantSubgraph0(std::vector& inputDynamicShapes, std::vector& inputPrecisions, + std::vector& matMulIn0Precisions) { + ngraph::ParameterVector ngraphParam; + + auto transpose0Param = std::make_shared(inputPrecisions[0], inputDynamicShapes[0]); + ngraphParam.push_back(transpose0Param); + + auto transpose1Param = std::make_shared(inputPrecisions[1], inputDynamicShapes[1]); + ngraphParam.push_back(transpose1Param); + + auto addParam = std::make_shared(inputPrecisions[2], inputDynamicShapes[2]); + ngraphParam.push_back(addParam); + + auto transpose2Param = std::make_shared(inputPrecisions[3], inputDynamicShapes[3]); + ngraphParam.push_back(transpose2Param); + + std::vector constantShapes; + 
constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()})); + constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()})); + constantShapes.push_back(ov::Shape({2})); + constantShapes.push_back(ov::Shape({4})); + constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()})); + constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()})); + + std::vector transpose0ConstData = {0, 2, 1, 3}; + auto transpose0Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[0], transpose0ConstData); + + std::vector transpose1ConstData = {0, 2, 3, 1}; + auto transpose1Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[1], transpose1ConstData); + + std::vector reshape0ConstData = {static_cast(inputDynamicShapes[0].get_shape()[0] * + inputDynamicShapes[0].get_shape()[1] * inputDynamicShapes[0].get_shape()[2]), + -1}; + auto reshape0Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[2], reshape0ConstData); + + std::vector reshape1ConstData = {static_cast(inputDynamicShapes[0].get_shape()[0]), + static_cast(inputDynamicShapes[0].get_shape()[2]), + static_cast(inputDynamicShapes[0].get_shape()[1]), + static_cast(inputDynamicShapes[0].get_shape()[1])}; + auto reshape1Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[3], reshape1ConstData); + + std::vector transpose2ConstData = {0, 2, 1, 3}; + auto transpose2Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[4], transpose2ConstData); + + std::vector transpose3ConstData = {0, 2, 1, 3}; + auto transpose3Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[5], transpose3ConstData); + + float transA = false; + float transB = false; + + std::shared_ptr fakeQuantize0; + if (matMulIn0Precisions[0] == ElementType::u8) + fakeQuantize0 = ngraph::builder::makeFakeQuantize(transpose0Param, inputPrecisions[0], 256, {}, {0.0f}, {2.55f}, {0.0f}, {2.55f}); + else + fakeQuantize0 = ngraph::builder::makeFakeQuantize(transpose0Param, inputPrecisions[0], 256, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f}); + + const auto fakeQuantize1 = ngraph::builder::makeFakeQuantize(transpose1Param, inputPrecisions[1], 256, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f}); + const auto fakeQuantize2 = ngraph::builder::makeFakeQuantize(transpose2Param, inputPrecisions[3], 256, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f}); + + std::shared_ptr fakeQuantize4; + + const auto transpose0 = std::make_shared(fakeQuantize0, transpose0Const); + const auto transpose1 = std::make_shared(fakeQuantize1, transpose1Const); + const auto matMul0 = std::make_shared(transpose0, transpose1, transA, transB); + const auto fakeQuantize3 = ngraph::builder::makeFakeQuantize(matMul0, inputPrecisions[0], 256, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f}); + const auto add = std::make_shared(fakeQuantize3, addParam); + const auto reshape0 = std::make_shared(add, reshape0Const, true); + const auto softMax = std::make_shared(reshape0, 1); + const auto reshape1 = std::make_shared(softMax, reshape1Const, true); + if (matMulIn0Precisions[1] == ElementType::u8) + fakeQuantize4 = ngraph::builder::makeFakeQuantize(reshape1, inputPrecisions[0], 256, {}, {0.0f}, {0.255f}, {0.0f}, {0.255f}); + else + fakeQuantize4 = ngraph::builder::makeFakeQuantize(reshape1, inputPrecisions[0], 256, {}, {-0.128f}, {0.127f}, {-0.128f}, {0.127f}); + const auto transpose2 = std::make_shared(fakeQuantize2, transpose2Const); + const auto matMul1 = std::make_shared(fakeQuantize4, 
transpose2, transA, transB); + const auto fakeQuantize5 = ngraph::builder::makeFakeQuantize(matMul1, inputPrecisions[0], 256, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f}); + const auto transpose3 = std::make_shared(fakeQuantize5, transpose3Const); + + ngraph::ResultVector results{std::make_shared(transpose3)}; + return std::make_shared(results, ngraphParam, "mha"); +} + +static std::shared_ptr initMHAQuantSubgraph1(std::vector& inputDynamicShapes, std::vector& inputPrecisions, + std::vector& matMulIn0Precisions) { + ngraph::ParameterVector ngraphParam; + + auto transpose0Param = std::make_shared(inputPrecisions[0], inputDynamicShapes[0]); + ngraphParam.push_back(transpose0Param); + + auto transpose1Param = std::make_shared(inputPrecisions[1], inputDynamicShapes[1]); + ngraphParam.push_back(transpose1Param); + + auto addParam = std::make_shared(inputPrecisions[2], inputDynamicShapes[2]); + ngraphParam.push_back(addParam); + + auto transpose2Param = std::make_shared(inputPrecisions[3], inputDynamicShapes[3]); + ngraphParam.push_back(transpose2Param); + + std::vector constantShapes; + constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()})); + constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()})); + constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()})); + constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()})); + constantShapes.push_back(ov::Shape({1})); + + std::vector transpose0ConstData = {0, 2, 1, 3}; + auto transpose0Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[0], transpose0ConstData); + + std::vector transpose1ConstData = {0, 2, 3, 1}; + auto transpose1Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[1], transpose1ConstData); + + std::vector transpose2ConstData = {0, 2, 1, 3}; + auto transpose2Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[2], transpose2ConstData); + + std::vector transpose3ConstData = {0, 2, 1, 3}; + auto transpose3Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[3], transpose3ConstData); + + std::vector mulConstData(ngraph::shape_size(constantShapes[4])); + auto mulConst = ngraph::builder::makeConstant(inputPrecisions[0], constantShapes[4], mulConstData, true); + + float transA = false; + float transB = false; + + std::shared_ptr fakeQuantize0; + if (matMulIn0Precisions[0] == ElementType::u8) + fakeQuantize0 = ngraph::builder::makeFakeQuantize(transpose0Param, inputPrecisions[0], 256, {}, {0.0f}, {2.55f}, {0.0f}, {2.55f}); + else + fakeQuantize0 = ngraph::builder::makeFakeQuantize(transpose0Param, inputPrecisions[0], 256, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f}); + + const auto transpose0 = std::make_shared(fakeQuantize0, transpose0Const); + const auto transpose1 = std::make_shared(transpose1Param, transpose1Const); + const auto fakeQuantize1 = ngraph::builder::makeFakeQuantize(transpose1, inputPrecisions[1], 256, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f}); + const auto matMul0 = std::make_shared(transpose0, fakeQuantize1, transA, transB); + const auto mul = std::make_shared(addParam, mulConst); + const auto add = std::make_shared(matMul0, mul); + const auto softMax = std::make_shared(add, 3); + const auto transpose2 = std::make_shared(transpose2Param, transpose2Const); + const auto matMul1 = std::make_shared(softMax, transpose2, transA, transB); + const auto fakeQuantize2 = ngraph::builder::makeFakeQuantize(matMul1, inputPrecisions[0], 256, {}, {0.0f}, {2.55f}, {0.0f}, {2.55f}); 
+ const auto transpose3 = std::make_shared(fakeQuantize2, transpose3Const); + + ngraph::ResultVector results{std::make_shared(transpose3)}; + return std::make_shared(results, ngraphParam, "mha"); +} + +class MHAQuantTest : public testing::WithParamInterface, + virtual public SubgraphBaseTest, public CPUTestsBase { +public: + static std::string getTestCaseName(const testing::TestParamInfo &obj) { + std::vector inputShapes; + std::vector inputPrecisions; + std::vector matMulIn0Precisions; + size_t patternType; + std::string targetName; + std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, targetName) = obj.param; + std::ostringstream results; + + results << "IS=("; + for (const auto& shape : inputShapes) { + results << CommonTestUtils::partialShape2str({shape.first}) << "_"; + } + results << ")_TS=("; + for (const auto& shape : inputShapes) { + for (const auto& item : shape.second) { + results << CommonTestUtils::vec2str(item) << "_"; + } + } + for (int i = 0; i < inputPrecisions.size(); i++) { + results << "InPRC" << std::to_string(i) << "=" << inputPrecisions[i] << "_"; + } + for (int i = 0; i < matMulIn0Precisions.size(); i++) { + results << "MatMulIn0PRC" << std::to_string(i) << "=" << matMulIn0Precisions[i] << "_"; + } + results << "patternType=" << patternType; + results << "targetDevice=" << targetName; + + return results.str(); + } + + void generate_inputs(const std::vector& targetInputStaticShapes) override { + inputs.clear(); + const auto& funcInputs = function->inputs(); + for (int i = 0; i < funcInputs.size(); ++i) { + const auto& funcInput = funcInputs[i]; + ov::Tensor tensor; + if (funcInput.get_element_type().is_real()) + tensor = ov::test::utils::create_and_fill_tensor_normal_distribution(funcInput.get_element_type(), targetInputStaticShapes[i], 0.0f, 1.5f); + else + tensor = ov::test::utils::create_and_fill_tensor(funcInput.get_element_type(), targetInputStaticShapes[i], 255, 0, 1); + + + inputs.insert({funcInput.get_node_shared_ptr(), tensor}); + } + } + +protected: + void SetUp() override { + abs_threshold = 0.1f; + + std::vector inputShapes; + std::vector inputPrecisions; + std::vector matMulIn0Precisions; + size_t patternType; + std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, targetDevice) = this->GetParam(); + + init_input_shapes(inputShapes); + + if (patternType == 0) { + function = initMHAQuantSubgraph0(inputDynamicShapes, inputPrecisions, matMulIn0Precisions); + } else if (patternType == 1) { + function = initMHAQuantSubgraph1(inputDynamicShapes, inputPrecisions, matMulIn0Precisions); + } else { + FAIL() << "Unsupported MHA pattern type"; + } + } +}; + +TEST_P(MHAQuantTest, CompareWithRefs) { + SKIP_IF_CURRENT_TEST_IS_DISABLED() + + std::vector inputShapes; + std::vector inputPrecisions; + std::vector matMulIn0Precisions; + size_t patternType; + std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, targetDevice) = this->GetParam(); + + if (inputPrecisions[0] == ElementType::bf16 && !InferenceEngine::with_cpu_x86_bfloat16()) + GTEST_SKIP(); + + if (!InferenceEngine::with_cpu_x86_avx512_core_vnni()) + GTEST_SKIP(); + + run(); + CheckNumberOfNodesWithType(compiledModel, "MHA", 1); +} + +namespace { + +std::vector> inputShapesQuant = { + {{2, 7, 16, 9}, {2, 7, 16, 9}, {2, 1, 1, 7}, {2, 7, 16, 9}}, + {{2, 8, 16, 64}, {2, 8, 16, 64}, {2, 1, 1, 8}, {2, 8, 16, 64}}, + {{1, 384, 16, 64}, {1, 384, 16, 64}, {1, 1, 1, 384}, {1, 384, 16, 64}}, + {{2, 64, 16, 80}, {2, 64, 16, 80}, {2, 1, 1, 64}, {2, 64, 16, 80}}, + 
{{3, 96, 16, 64}, {3, 96, 16, 64}, {3, 1, 1, 96}, {3, 96, 16, 64}}, + {{2, 192, 16, 160}, {2, 192, 16, 160}, {2, 1, 1, 192}, {2, 192, 16, 160}}, + {{2, 4, 16, 8}, {2, 4, 16, 8}, {2, 1, 1, 4}, {2, 4, 16, 8}}, + {{1, 204, 13, 212}, {1, 204, 13, 212}, {1, 1, 1, 204}, {1, 204, 13, 212}}, + {{1, 207, 13, 211}, {1, 207, 13, 211}, {1, 1, 1, 207}, {1, 207, 13, 211}}, +}; + +std::vector> inputPrecisionsQuant = { + { ElementType::f32, ElementType::f32, ElementType::f32, ElementType::f32 }, +}; + +std::vector> matMulIn0PrecisionsQuant = { + { ElementType::i8, ElementType::i8 }, + { ElementType::i8, ElementType::u8 }, +}; + +std::vector patternTypesQuant = { + 0, 1 +}; + +INSTANTIATE_TEST_SUITE_P(smoke_MHAQuant, MHAQuantTest, + ::testing::Combine( + ::testing::ValuesIn(static_shapes_to_test_representation(inputShapesQuant)), + ::testing::ValuesIn(inputPrecisionsQuant), + ::testing::ValuesIn(matMulIn0PrecisionsQuant), + ::testing::ValuesIn(patternTypesQuant), + ::testing::Values(CommonTestUtils::DEVICE_CPU)), + MHAQuantTest::getTestCaseName); + +} // namespace +} // namespace CPUSubgraphTestsDefinitions
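The brgemm kernel selection in MHA::prepareParams() and MHA::mhaImpl() above splits every GEMM dimension into a main block and a tail block and instantiates one kernel per non-empty (m, k, n) cell of the resulting 2x2x2 grid, addressed through getBrgIdx(). The standalone sketch below (not part of the patch; all sizes and names are illustrative) reproduces only that indexing scheme to make the loop structure easier to follow.

#include <cstddef>
#include <cstdio>

// Flattens the 2x2x2 (main/tail) kernel grid into indices 0..7, as in MHA::getBrgIdx().
static size_t getBrgIdx(size_t mIdx, size_t kIdx, size_t nIdx) {
    return mIdx * 4 + kIdx * 2 + nIdx;
}

int main() {
    const size_t M = 70, M_blk = 32;   // e.g. rows of the first matmul vs. the optimal M block
    const size_t N = 80, N_blk = 64;
    const size_t K = 64, K_blk = 64;

    const size_t M_tail = M % M_blk, N_tail = N % N_blk, K_tail = K % K_blk;

    for (size_t m = 0; m < 2; m++) {
        for (size_t k = 0; k < 2; k++) {
            for (size_t n = 0; n < 2; n++) {
                const size_t M_ = m ? M_tail : (M < M_blk ? 0 : M_blk);
                const size_t N_ = n ? N_tail : N - N_tail;
                const size_t K_ = k ? K_tail : K - K_tail;
                // Kernels are only created for non-empty tiles, mirroring the checks above.
                if (M_ != 0 && N_ != 0 && K_ != 0)
                    std::printf("kernel %zu: M=%zu N=%zu K=%zu\n", getBrgIdx(m, k, n), M_, N_, K_);
            }
        }
    }
    return 0;
}

In the node itself the tail-K kernels additionally run with beta = 1.0 (when the main K block is non-empty) so their partial products accumulate on top of the main-K result.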