[Snippets] Added support of 3D MHA (#17817)

This commit is contained in:
Alexandra Sidorova 2023-06-06 16:55:01 +04:00 committed by GitHub
parent 29f06692d6
commit 2ec9fe915c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 33 additions and 11 deletions

View File

@ -6,6 +6,7 @@
#include "snippets/itt.hpp"
#include "snippets/utils.hpp"
#include "snippets/pass/tokenization.hpp"
#include "snippets/op/subgraph.hpp"
@ -18,8 +19,8 @@ namespace {
// Checks whether a tensor is tokenizable by the MHA snippets path:
// static f32 tensors of rank 3 or 4 only.
auto is_supported_tensor(const ov::descriptor::Tensor& t) -> bool {
    // TODO: Add support of all supported by common tokenization element types
    // return ov::snippets::pass::TokenizeSnippets::supported_element_types.count(input.get_element_type()) != 0;
    // NOTE(review): the pre-3D-support variant of this return statement was left above the new one
    // in the original text; it made the 3D branch unreachable, so only the 3D/4D check is kept here.
    return t.get_element_type() == ngraph::element::f32 &&
           t.get_partial_shape().is_static() && ov::snippets::utils::one_of(t.get_shape().size(), 3lu, 4lu);
}
// TODO: Add support of FQ, Reshape?

View File

@ -759,7 +759,6 @@ BrgemmEmitter::BrgemmEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl:
m_K = A_shape[get_ordered_idx(A_layout, A_layout.size() - 1)];
m_M_blk = matmulOptimalM;
m_M_tail = m_M % m_M_blk;
// B_shape[B_layout[3]]
m_N = C_shape[get_ordered_idx(C_layout, C_layout.size() - 1)];
auto brg0Prc = InferenceEngine::details::convertPrecision(brgemm_node->get_input_element_type(0));

View File

@ -52,14 +52,25 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA, MHASelect,
::testing::Values(std::map<std::string, std::string>{})),
MHA::getTestCaseName);
// Input shape sets for MHA-without-Transpose tests.
// The 4D cases are always returned; the rank-3 cases (enabled by 3D MHA
// support) are appended only when the test suite opts in via supports_3d.
// NOTE(review): the pre-change `const std::vector` with the same name was a
// leftover duplicate definition and has been removed — only the function form remains.
static std::vector<std::vector<ov::PartialShape>> inputShapesWOTranspose(bool supports_3d = false) {
    std::vector<std::vector<ov::PartialShape>> shapes = {
        {{1, 12, 197, 64}, {1, 12, 64, 197}, {1, 12, 197, 64}},
        {{1, 12, 12, 64}, {1, 12, 64, 48}, {1, 12, 48, 64}}
    };
    if (supports_3d) {
        // 3D variants: same matmul-compatible dimensions, no batch axis.
        shapes.push_back({{12, 197, 64}, {12, 64, 197}, {12, 197, 64}});
        shapes.push_back({{12, 128, 100}, {12, 100, 128}, {12, 128, 100}});
    }
    return shapes;
}
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAWOTransposeOnInputs, MHAWOTransposeOnInputs,
::testing::Combine(
::testing::ValuesIn(inputShapesWOTranspose),
::testing::ValuesIn(inputShapesWOTranspose()),
::testing::ValuesIn({true}), // Need to support False for graph builder in tests
::testing::Values(ov::element::f32),
::testing::Values(1),
@ -73,7 +84,7 @@ const std::map<std::string, std::string> cpuBF16PluginConfig = { { InferenceEngi
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHABF16, MHAWOTranspose,
::testing::Combine(
::testing::ValuesIn(inputShapesWOTranspose),
::testing::ValuesIn(inputShapesWOTranspose(true)),
::testing::ValuesIn({true}), // Need to support False for graph builder in tests
::testing::Values(ov::element::bf16),
::testing::Values(3),
@ -82,6 +93,17 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHABF16, MHAWOTranspose,
::testing::Values(cpuBF16PluginConfig)),
MHA::getTestCaseName);
// FP32 MHA without Transpose ops on CPU; includes the 3D shape sets
// (inputShapesWOTranspose(true)). Expects 1 subgraph and 1 fused node.
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAWOTranspose, MHAWOTranspose,
::testing::Combine(
::testing::ValuesIn(inputShapesWOTranspose(true)),
::testing::ValuesIn({true}), // Need to support False for graph builder in tests
::testing::Values(ov::element::f32),
::testing::Values(1),
::testing::Values(1),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(std::map<std::string, std::string>{})),
MHA::getTestCaseName);
} // namespace
} // namespace snippets

View File

@ -335,7 +335,7 @@ std::shared_ptr<ov::Model> MHAWOTransposeOnInputsFunction::initOriginal() const
const auto mulConst = ngraph::builder::makeConstant(precision, ov::Shape({1}), std::vector<float>{1}, true);
const auto mul = std::make_shared<ngraph::opset3::Multiply>(param1, mulConst);
const auto matMul0 = std::make_shared<ngraph::opset3::MatMul>(param0, mul, transA, transB);
const auto softmax = std::make_shared<ngraph::opset1::Softmax>(matMul0, 3);
const auto softmax = std::make_shared<ngraph::opset8::Softmax>(matMul0, -1);
const auto matMul1 = std::make_shared<ngraph::opset3::MatMul>(softmax, param2, transA, transB);
const auto transpose3 = std::make_shared<ov::op::v1::Transpose>(matMul1, transpose3Const);
@ -352,7 +352,7 @@ std::shared_ptr<ov::Model> MHAWOTransposeFunction::initOriginal() const {
float transA = false;
float transB = false;
const auto matMul0 = std::make_shared<ngraph::opset3::MatMul>(param0, param1, transA, transB);
const auto softmax = std::make_shared<ngraph::opset1::Softmax>(matMul0, 3);
const auto softmax = std::make_shared<ngraph::opset8::Softmax>(matMul0, -1);
const auto matMul1 = std::make_shared<ngraph::opset3::MatMul>(softmax, param2, transA, transB);
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(matMul1)};