[Snippets] Added support of 3D MHA (#17817)
This commit is contained in:
parent
29f06692d6
commit
2ec9fe915c
@ -6,6 +6,7 @@
|
||||
|
||||
|
||||
#include "snippets/itt.hpp"
|
||||
#include "snippets/utils.hpp"
|
||||
#include "snippets/pass/tokenization.hpp"
|
||||
#include "snippets/op/subgraph.hpp"
|
||||
|
||||
@ -18,8 +19,8 @@ namespace {
|
||||
// Reports whether tensor `t` passes the checks below: f32 element type, static partial
// shape, and an accepted rank.
// NOTE(review): this span reads like a rendered diff in which both the removed and the
// added `return` are present. Taken literally, the first `return` (rank == 4 only) wins
// and the second one (rank 3 or 4) is unreachable - confirm against the real file.
auto is_supported_tensor(const ov::descriptor::Tensor& t) -> bool {
|
||||
// TODO: Add support for all element types supported by common tokenization
|
||||
// return ov::snippets::pass::TokenizeSnippets::supported_element_types.count(input.get_element_type()) != 0;
|
||||
// Also only 4D is supported at the moment
// (this remark matches the rank == 4 check directly below, not the rank 3-or-4 check after it)
|
||||
return t.get_element_type() == ov::element::f32 && t.get_partial_shape().is_static() && t.get_shape().size() == 4;
|
||||
// Variant that also admits rank 3; presumably one_of(x, a, b) is true when x equals a or b - TODO confirm.
return t.get_element_type() == ngraph::element::f32 &&
|
||||
t.get_partial_shape().is_static() && ov::snippets::utils::one_of(t.get_shape().size(), 3lu, 4lu);
|
||||
}
|
||||
|
||||
// TODO: Add support of FQ, Reshape?
|
||||
|
@ -759,7 +759,6 @@ BrgemmEmitter::BrgemmEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl:
|
||||
m_K = A_shape[get_ordered_idx(A_layout, A_layout.size() - 1)];
|
||||
m_M_blk = matmulOptimalM;
|
||||
m_M_tail = m_M % m_M_blk;
|
||||
// B_shape[B_layout[3]]
|
||||
m_N = C_shape[get_ordered_idx(C_layout, C_layout.size() - 1)];
|
||||
|
||||
auto brg0Prc = InferenceEngine::details::convertPrecision(brgemm_node->get_input_element_type(0));
|
||||
|
@ -52,14 +52,25 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA, MHASelect,
|
||||
::testing::Values(std::map<std::string, std::string>{})),
|
||||
MHA::getTestCaseName);
|
||||
|
||||
// Constant list of input-shape sets; each inner list holds three rank-4 shapes.
// NOTE(review): a function with the same name is defined right after this declaration, so
// both cannot coexist in one scope. This looks like the pre-change form shown by a
// rendered diff - confirm against the real file.
const std::vector<std::vector<ov::PartialShape>> inputShapesWOTranspose = {
|
||||
{{1, 12, 197, 64}, {1, 12, 64, 197}, {1, 12, 197, 64}},
|
||||
{{1, 12, 12, 64}, {1, 12, 64, 48}, {1, 12, 48, 64}}
|
||||
};
|
||||
|
||||
// Builds the list of input-shape sets (three shapes per set).
// The two rank-4 sets are always returned; when `supports_3d` is true, two rank-3 sets
// are appended after them. `supports_3d` defaults to false, so a call without arguments
// yields only the rank-4 sets.
static std::vector<std::vector<ov::PartialShape>> inputShapesWOTranspose(bool supports_3d = false) {
|
||||
// Baseline rank-4 cases, returned regardless of `supports_3d`.
std::vector<std::vector<ov::PartialShape>> shapes = {
|
||||
{{1, 12, 197, 64}, {1, 12, 64, 197}, {1, 12, 197, 64}},
|
||||
{{1, 12, 12, 64}, {1, 12, 64, 48}, {1, 12, 48, 64}}
|
||||
};
|
||||
if (supports_3d) {
|
||||
// Rank-3 cases, added only for callers that opt in.
std::vector<std::vector<ov::PartialShape>> shapes_3d = {
|
||||
{{12, 197, 64}, {12, 64, 197}, {12, 197, 64}},
|
||||
{{12, 128, 100}, {12, 100, 128}, {12, 128, 100}}
|
||||
};
|
||||
// Append after the rank-4 cases so their order in the result is unchanged.
shapes.insert(shapes.end(), shapes_3d.begin(), shapes_3d.end());
|
||||
}
|
||||
return shapes;
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAWOTransposeOnInputs, MHAWOTransposeOnInputs,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(inputShapesWOTranspose),
|
||||
::testing::ValuesIn(inputShapesWOTranspose()),
|
||||
::testing::ValuesIn({true}), // Need to support False for graph builder in tests
|
||||
::testing::Values(ov::element::f32),
|
||||
::testing::Values(1),
|
||||
@ -73,7 +84,7 @@ const std::map<std::string, std::string> cpuBF16PluginConfig = { { InferenceEngi
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHABF16, MHAWOTranspose,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(inputShapesWOTranspose),
|
||||
::testing::ValuesIn(inputShapesWOTranspose(true)),
|
||||
::testing::ValuesIn({true}), // Need to support False for graph builder in tests
|
||||
::testing::Values(ov::element::bf16),
|
||||
::testing::Values(3),
|
||||
@ -82,6 +93,17 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHABF16, MHAWOTranspose,
|
||||
::testing::Values(cpuBF16PluginConfig)),
|
||||
MHA::getTestCaseName);
|
||||
|
||||
// Instantiates the MHAWOTranspose suite under the name smoke_Snippets_MHAWOTranspose.
// Shapes come from inputShapesWOTranspose(true), so the rank-3 sets are included.
// NOTE(review): the meaning of the two `Values(1)` parameters is not visible in this
// span - confirm against the suite's parameter tuple.
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAWOTranspose, MHAWOTranspose,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(inputShapesWOTranspose(true)),
|
||||
::testing::ValuesIn({true}), // Need to support False for graph builder in tests
|
||||
// Element type used by this instantiation.
::testing::Values(ov::element::f32),
|
||||
::testing::Values(1),
|
||||
::testing::Values(1),
|
||||
// Target device: CPU.
::testing::Values(CommonTestUtils::DEVICE_CPU),
|
||||
// Empty string-to-string map; the BF16 instantiation passes cpuBF16PluginConfig in this position.
::testing::Values(std::map<std::string, std::string>{})),
|
||||
MHA::getTestCaseName);
|
||||
|
||||
|
||||
} // namespace
|
||||
} // namespace snippets
|
||||
|
@ -335,7 +335,7 @@ std::shared_ptr<ov::Model> MHAWOTransposeOnInputsFunction::initOriginal() const
|
||||
const auto mulConst = ngraph::builder::makeConstant(precision, ov::Shape({1}), std::vector<float>{1}, true);
|
||||
const auto mul = std::make_shared<ngraph::opset3::Multiply>(param1, mulConst);
|
||||
const auto matMul0 = std::make_shared<ngraph::opset3::MatMul>(param0, mul, transA, transB);
|
||||
const auto softmax = std::make_shared<ngraph::opset1::Softmax>(matMul0, 3);
|
||||
const auto softmax = std::make_shared<ngraph::opset8::Softmax>(matMul0, -1);
|
||||
const auto matMul1 = std::make_shared<ngraph::opset3::MatMul>(softmax, param2, transA, transB);
|
||||
const auto transpose3 = std::make_shared<ov::op::v1::Transpose>(matMul1, transpose3Const);
|
||||
|
||||
@ -352,7 +352,7 @@ std::shared_ptr<ov::Model> MHAWOTransposeFunction::initOriginal() const {
|
||||
float transA = false;
|
||||
float transB = false;
|
||||
const auto matMul0 = std::make_shared<ngraph::opset3::MatMul>(param0, param1, transA, transB);
|
||||
const auto softmax = std::make_shared<ngraph::opset1::Softmax>(matMul0, 3);
|
||||
const auto softmax = std::make_shared<ngraph::opset8::Softmax>(matMul0, -1);
|
||||
const auto matMul1 = std::make_shared<ngraph::opset3::MatMul>(softmax, param2, transA, transB);
|
||||
|
||||
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(matMul1)};
|
||||
|
Loading…
Reference in New Issue
Block a user