[CPU] Fixed enforceBF16 condition for transformation pipeline (#17157)

* [CPU] Fixed enforceBF16 condition for transformation pipeline

* [Snippets][CPU][Tests] Added test with bf16
This commit is contained in:
Alexandra Sidorova
2023-04-26 16:13:01 +04:00
committed by GitHub
parent ca92eb96ad
commit a032d67cc7
6 changed files with 139 additions and 23 deletions

View File

@@ -382,15 +382,18 @@ StreamCfg Engine::GetNumStreams(InferenceEngine::IStreamsExecutor::ThreadBinding
}
static bool shouldEnforceBF16(const std::map<std::string, std::string>& modelConfig, const Config& engineConfig) {
const auto& enforceBF16 = modelConfig.find(InferenceEngine::PluginConfigParams::KEY_ENFORCE_BF16);
if (enforceBF16 == modelConfig.end()) { // not set for the model
return engineConfig.enforceBF16 && dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core); // use value from engine
}
if (enforceBF16->second == PluginConfigParams::YES) {
return dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core);
} else {
// For BF16 execution, the machine should have AVX512 at least
if (!dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core))
return false;
const auto& enforceBF16 = modelConfig.find(InferenceEngine::PluginConfigParams::KEY_ENFORCE_BF16);
const auto& inferPrec = modelConfig.find(ov::hint::inference_precision.name());
if (enforceBF16 != modelConfig.end()) {
return enforceBF16->second == PluginConfigParams::YES;
} else if (inferPrec != modelConfig.end()) {
return inferPrec->second == "bf16";
} else {
return engineConfig.enforceBF16;
}
}

View File

@@ -4,6 +4,7 @@
#include "snippets/mha.hpp"
#include "common_test_utils/test_constants.hpp"
#include "ie_plugin_config.hpp"
namespace ov {
namespace test {
@@ -24,9 +25,11 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA, MHA,
::testing::Combine(
::testing::ValuesIn(inputShapes),
::testing::ValuesIn({false, true}),
::testing::Values(ov::element::f32),
::testing::Values(1),
::testing::Values(1),
::testing::Values(CommonTestUtils::DEVICE_CPU)),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(std::map<std::string, std::string>{})),
MHA::getTestCaseName);
const std::vector<std::vector<ov::PartialShape>> inputShapeSelect = {
@@ -42,9 +45,11 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA, MHASelect,
::testing::Combine(
::testing::ValuesIn(inputShapeSelect),
::testing::Values(false), // Need to support True for graph builder in tests
::testing::Values(ov::element::f32),
::testing::Values(2), // Less + MHA
::testing::Values(2),
::testing::Values(CommonTestUtils::DEVICE_CPU)),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(std::map<std::string, std::string>{})),
MHA::getTestCaseName);
const std::vector<std::vector<ov::PartialShape>> inputShapesWOTranspose = {
@@ -55,9 +60,25 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAWOTransposeOnInputs, MHAWOTransposeOn
::testing::Combine(
::testing::ValuesIn(inputShapesWOTranspose),
::testing::ValuesIn({true}), // Need to support False for graph builder in tests
::testing::Values(ov::element::f32),
::testing::Values(1),
::testing::Values(1),
::testing::Values(CommonTestUtils::DEVICE_CPU)),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(std::map<std::string, std::string>{})),
MHA::getTestCaseName);
const std::map<std::string, std::string> cpuBF16PluginConfig = { { InferenceEngine::PluginConfigParams::KEY_ENFORCE_BF16,
InferenceEngine::PluginConfigParams::YES } };
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHABF16, MHAWOTranspose,
::testing::Combine(
::testing::ValuesIn(inputShapesWOTranspose),
::testing::ValuesIn({true}), // Need to support False for graph builder in tests
::testing::Values(ov::element::bf16),
::testing::Values(3),
::testing::Values(0), // CPU plugin doesn't support MHA pattern via Snippets on bf16
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(cpuBF16PluginConfig)),
MHA::getTestCaseName);