diff --git a/src/common/snippets/src/pass/mha_tokenization.cpp b/src/common/snippets/src/pass/mha_tokenization.cpp index 6898a423b5f..fbca7b47b62 100644 --- a/src/common/snippets/src/pass/mha_tokenization.cpp +++ b/src/common/snippets/src/pass/mha_tokenization.cpp @@ -6,6 +6,7 @@ #include "snippets/itt.hpp" +#include "snippets/utils.hpp" #include "snippets/pass/tokenization.hpp" #include "snippets/op/subgraph.hpp" @@ -18,8 +19,8 @@ namespace { auto is_supported_tensor(const ov::descriptor::Tensor& t) -> bool { // TODO: Add support of all supported by common tokenization element types // return ov::snippets::pass::TokenizeSnippets::supported_element_types.count(input.get_element_type()) != 0; - // Also only 4D is supported at the moment - return t.get_element_type() == ov::element::f32 && t.get_partial_shape().is_static() && t.get_shape().size() == 4; + return t.get_element_type() == ov::element::f32 && + t.get_partial_shape().is_static() && ov::snippets::utils::one_of(t.get_shape().size(), 3lu, 4lu); } // TODO: Add support of FQ, Reshape? 
diff --git a/src/plugins/intel_cpu/src/emitters/x64/jit_snippets_emitters.cpp b/src/plugins/intel_cpu/src/emitters/x64/jit_snippets_emitters.cpp index cf926fdd07c..6c4717699b4 100644 --- a/src/plugins/intel_cpu/src/emitters/x64/jit_snippets_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/x64/jit_snippets_emitters.cpp @@ -759,7 +759,6 @@ BrgemmEmitter::BrgemmEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl: m_K = A_shape[get_ordered_idx(A_layout, A_layout.size() - 1)]; m_M_blk = matmulOptimalM; m_M_tail = m_M % m_M_blk; - // B_shape[B_layout[3]] m_N = C_shape[get_ordered_idx(C_layout, C_layout.size() - 1)]; auto brg0Prc = InferenceEngine::details::convertPrecision(brgemm_node->get_input_element_type(0)); diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp index 8778c35ab1d..96baa1f9fb6 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp @@ -52,14 +52,25 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA, MHASelect, ::testing::Values(std::map{})), MHA::getTestCaseName); -const std::vector> inputShapesWOTranspose = { - {{1, 12, 197, 64}, {1, 12, 64, 197}, {1, 12, 197, 64}}, - {{1, 12, 12, 64}, {1, 12, 64, 48}, {1, 12, 48, 64}} -}; + +static std::vector> inputShapesWOTranspose(bool supports_3d = false) { + std::vector> shapes = { + {{1, 12, 197, 64}, {1, 12, 64, 197}, {1, 12, 197, 64}}, + {{1, 12, 12, 64}, {1, 12, 64, 48}, {1, 12, 48, 64}} + }; + if (supports_3d) { + std::vector> shapes_3d = { + {{12, 197, 64}, {12, 64, 197}, {12, 197, 64}}, + {{12, 128, 100}, {12, 100, 128}, {12, 128, 100}} + }; + shapes.insert(shapes.end(), shapes_3d.begin(), shapes_3d.end()); + } + return shapes; +} INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAWOTransposeOnInputs, MHAWOTransposeOnInputs, ::testing::Combine( - 
::testing::ValuesIn(inputShapesWOTranspose), + ::testing::ValuesIn(inputShapesWOTranspose()), ::testing::ValuesIn({true}), // Need to support False for graph builder in tests ::testing::Values(ov::element::f32), ::testing::Values(1), @@ -73,7 +84,7 @@ const std::map cpuBF16PluginConfig = { { InferenceEngi INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHABF16, MHAWOTranspose, ::testing::Combine( - ::testing::ValuesIn(inputShapesWOTranspose), + ::testing::ValuesIn(inputShapesWOTranspose(true)), ::testing::ValuesIn({true}), // Need to support False for graph builder in tests ::testing::Values(ov::element::bf16), ::testing::Values(3), @@ -82,6 +93,17 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHABF16, MHAWOTranspose, ::testing::Values(cpuBF16PluginConfig)), MHA::getTestCaseName); +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAWOTranspose, MHAWOTranspose, + ::testing::Combine( + ::testing::ValuesIn(inputShapesWOTranspose(true)), + ::testing::ValuesIn({true}), // Need to support False for graph builder in tests + ::testing::Values(ov::element::f32), + ::testing::Values(1), + ::testing::Values(1), + ::testing::Values(CommonTestUtils::DEVICE_CPU), + ::testing::Values(std::map{})), + MHA::getTestCaseName); + } // namespace } // namespace snippets diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_mha.cpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_mha.cpp index d904790bce2..0f64f9cf561 100644 --- a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_mha.cpp +++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_mha.cpp @@ -335,7 +335,7 @@ std::shared_ptr MHAWOTransposeOnInputsFunction::initOriginal() const const auto mulConst = ngraph::builder::makeConstant(precision, ov::Shape({1}), std::vector{1}, true); const auto mul = std::make_shared(param1, mulConst); const auto matMul0 = std::make_shared(param0, mul, transA, transB); - const auto softmax = std::make_shared(matMul0, 3); + const auto softmax = 
std::make_shared(matMul0, -1); const auto matMul1 = std::make_shared(softmax, param2, transA, transB); const auto transpose3 = std::make_shared(matMul1, transpose3Const); @@ -352,7 +352,7 @@ std::shared_ptr MHAWOTransposeFunction::initOriginal() const { float transA = false; float transB = false; const auto matMul0 = std::make_shared(param0, param1, transA, transB); - const auto softmax = std::make_shared(matMul0, 3); + const auto softmax = std::make_shared(matMul0, -1); const auto matMul1 = std::make_shared(softmax, param2, transA, transB); ngraph::ResultVector results{std::make_shared(matMul1)};