[CPU] Fixed enforcebf16 condition for transformation pipeline (#17157)
* [CPU] Fixed enforcebf16 condition for transformation pipeline
* [Snippets][CPU][Tests] Added test with bf16
This commit is contained in:
committed by
GitHub
parent
ca92eb96ad
commit
a032d67cc7
@@ -382,15 +382,18 @@ StreamCfg Engine::GetNumStreams(InferenceEngine::IStreamsExecutor::ThreadBinding
|
||||
}
|
||||
|
||||
static bool shouldEnforceBF16(const std::map<std::string, std::string>& modelConfig, const Config& engineConfig) {
|
||||
const auto& enforceBF16 = modelConfig.find(InferenceEngine::PluginConfigParams::KEY_ENFORCE_BF16);
|
||||
if (enforceBF16 == modelConfig.end()) { // not set for the model
|
||||
return engineConfig.enforceBF16 && dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core); // use value from engine
|
||||
}
|
||||
|
||||
if (enforceBF16->second == PluginConfigParams::YES) {
|
||||
return dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core);
|
||||
} else {
|
||||
// For BF16 execution, the machine should have AVX512 at least
|
||||
if (!dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core))
|
||||
return false;
|
||||
|
||||
const auto& enforceBF16 = modelConfig.find(InferenceEngine::PluginConfigParams::KEY_ENFORCE_BF16);
|
||||
const auto& inferPrec = modelConfig.find(ov::hint::inference_precision.name());
|
||||
if (enforceBF16 != modelConfig.end()) {
|
||||
return enforceBF16->second == PluginConfigParams::YES;
|
||||
} else if (inferPrec != modelConfig.end()) {
|
||||
return inferPrec->second == "bf16";
|
||||
} else {
|
||||
return engineConfig.enforceBF16;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
#include "snippets/mha.hpp"
|
||||
#include "common_test_utils/test_constants.hpp"
|
||||
#include "ie_plugin_config.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace test {
|
||||
@@ -24,9 +25,11 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA, MHA,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(inputShapes),
|
||||
::testing::ValuesIn({false, true}),
|
||||
::testing::Values(ov::element::f32),
|
||||
::testing::Values(1),
|
||||
::testing::Values(1),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU)),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU),
|
||||
::testing::Values(std::map<std::string, std::string>{})),
|
||||
MHA::getTestCaseName);
|
||||
|
||||
const std::vector<std::vector<ov::PartialShape>> inputShapeSelect = {
|
||||
@@ -42,9 +45,11 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA, MHASelect,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(inputShapeSelect),
|
||||
::testing::Values(false), // Need to support True for graph builder in tests
|
||||
::testing::Values(ov::element::f32),
|
||||
::testing::Values(2), // Less + MHA
|
||||
::testing::Values(2),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU)),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU),
|
||||
::testing::Values(std::map<std::string, std::string>{})),
|
||||
MHA::getTestCaseName);
|
||||
|
||||
const std::vector<std::vector<ov::PartialShape>> inputShapesWOTranspose = {
|
||||
@@ -55,9 +60,25 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAWOTransposeOnInputs, MHAWOTransposeOn
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(inputShapesWOTranspose),
|
||||
::testing::ValuesIn({true}), // Need to support False for graph builder in tests
|
||||
::testing::Values(ov::element::f32),
|
||||
::testing::Values(1),
|
||||
::testing::Values(1),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU)),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU),
|
||||
::testing::Values(std::map<std::string, std::string>{})),
|
||||
MHA::getTestCaseName);
|
||||
|
||||
const std::map<std::string, std::string> cpuBF16PluginConfig = { { InferenceEngine::PluginConfigParams::KEY_ENFORCE_BF16,
|
||||
InferenceEngine::PluginConfigParams::YES } };
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHABF16, MHAWOTranspose,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(inputShapesWOTranspose),
|
||||
::testing::ValuesIn({true}), // Need to support False for graph builder in tests
|
||||
::testing::Values(ov::element::bf16),
|
||||
::testing::Values(3),
|
||||
::testing::Values(0), // CPU plugin doesn't support MHA pattern via Snippets on bf16
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU),
|
||||
::testing::Values(cpuBF16PluginConfig)),
|
||||
MHA::getTestCaseName);
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user