From a032d67cc780d3dfc2f8dc4677bf8d874fee46ce Mon Sep 17 00:00:00 2001 From: Alexandra Sidorova Date: Wed, 26 Apr 2023 16:13:01 +0400 Subject: [PATCH] [CPU] Fixed enforcebf16 condition for transformation pipeline (#17157) * [CPU] Fixed enforcebf16 condition for transformation pipeline * [Snippets][CPU][Tests] Added test with bf16 --- src/plugins/intel_cpu/src/plugin.cpp | 19 +++--- .../shared_tests_instances/snippets/mha.cpp | 27 +++++++- .../plugin/shared/include/snippets/mha.hpp | 19 +++--- .../plugin/shared/src/snippets/mha.cpp | 63 +++++++++++++++++-- .../include/subgraph_mha.hpp | 18 ++++++ .../src/subgraph_mha.cpp | 16 +++++ 6 files changed, 139 insertions(+), 23 deletions(-) diff --git a/src/plugins/intel_cpu/src/plugin.cpp b/src/plugins/intel_cpu/src/plugin.cpp index d82e0fb152f..7705efc3c3a 100644 --- a/src/plugins/intel_cpu/src/plugin.cpp +++ b/src/plugins/intel_cpu/src/plugin.cpp @@ -382,15 +382,18 @@ StreamCfg Engine::GetNumStreams(InferenceEngine::IStreamsExecutor::ThreadBinding } static bool shouldEnforceBF16(const std::map& modelConfig, const Config& engineConfig) { - const auto& enforceBF16 = modelConfig.find(InferenceEngine::PluginConfigParams::KEY_ENFORCE_BF16); - if (enforceBF16 == modelConfig.end()) { // not set for the model - return engineConfig.enforceBF16 && dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core); // use value from engine - } - - if (enforceBF16->second == PluginConfigParams::YES) { - return dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core); - } else { + // For BF16 execution, the machine should have AVX512 at least + if (!dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core)) return false; + + const auto& enforceBF16 = modelConfig.find(InferenceEngine::PluginConfigParams::KEY_ENFORCE_BF16); + const auto& inferPrec = modelConfig.find(ov::hint::inference_precision.name()); + if (enforceBF16 != modelConfig.end()) { + return enforceBF16->second == PluginConfigParams::YES; + } else if (inferPrec != modelConfig.end()) { + return inferPrec->second == "bf16"; + } else { + return engineConfig.enforceBF16; } } diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp index 11aeaebdcc2..13c3063d0d1 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp @@ -4,6 +4,7 @@ #include "snippets/mha.hpp" #include "common_test_utils/test_constants.hpp" +#include "ie_plugin_config.hpp" namespace ov { namespace test { @@ -24,9 +25,11 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA, MHA, ::testing::Combine( ::testing::ValuesIn(inputShapes), ::testing::ValuesIn({false, true}), + ::testing::Values(ov::element::f32), ::testing::Values(1), ::testing::Values(1), - ::testing::Values(CommonTestUtils::DEVICE_CPU)), + ::testing::Values(CommonTestUtils::DEVICE_CPU), + ::testing::Values(std::map{})), MHA::getTestCaseName); const std::vector> inputShapeSelect = { @@ -42,9 +45,11 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA, MHASelect, ::testing::Combine( ::testing::ValuesIn(inputShapeSelect), ::testing::Values(false), // Need to support True for graph builder in tests + ::testing::Values(ov::element::f32), ::testing::Values(2), // Less + MHA ::testing::Values(2), - ::testing::Values(CommonTestUtils::DEVICE_CPU)), + ::testing::Values(CommonTestUtils::DEVICE_CPU), + ::testing::Values(std::map{})), MHA::getTestCaseName); const std::vector> inputShapesWOTranspose = { @@ -55,9 +60,25 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAWOTransposeOnInputs, MHAWOTransposeOn ::testing::Combine( ::testing::ValuesIn(inputShapesWOTranspose), ::testing::ValuesIn({true}), // Need to support False for graph builder in tests + ::testing::Values(ov::element::f32), ::testing::Values(1), ::testing::Values(1), - ::testing::Values(CommonTestUtils::DEVICE_CPU)), + ::testing::Values(CommonTestUtils::DEVICE_CPU), + ::testing::Values(std::map{})), + MHA::getTestCaseName); + +const std::map cpuBF16PluginConfig = { { InferenceEngine::PluginConfigParams::KEY_ENFORCE_BF16, + InferenceEngine::PluginConfigParams::YES } }; + +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHABF16, MHAWOTranspose, + ::testing::Combine( + ::testing::ValuesIn(inputShapesWOTranspose), + ::testing::ValuesIn({true}), // Need to support False for graph builder in tests + ::testing::Values(ov::element::bf16), + ::testing::Values(3), + ::testing::Values(0), // CPU plugin doesn't support MHA pattern via Snippets on bf16 + ::testing::Values(CommonTestUtils::DEVICE_CPU), + ::testing::Values(cpuBF16PluginConfig)), MHA::getTestCaseName); diff --git a/src/tests/functional/plugin/shared/include/snippets/mha.hpp b/src/tests/functional/plugin/shared/include/snippets/mha.hpp index 9f95dcc30ac..7794c9b286d 100644 --- a/src/tests/functional/plugin/shared/include/snippets/mha.hpp +++ b/src/tests/functional/plugin/shared/include/snippets/mha.hpp @@ -11,16 +11,17 @@ namespace test { namespace snippets { typedef std::tuple< - std::vector, // Input shapes - bool, // With Multiply - size_t, // Expected num nodes - size_t, // Expected num subgraphs - std::string // Target Device + std::vector, // Input shapes + bool, // With Multiply + ov::element::Type, // Inference precision + size_t, // Expected num nodes + size_t, // Expected num subgraphs + std::string, // Target Device + std::map // Config > MHAParams; - class MHA : public testing::WithParamInterface, - virtual public ov::test::SnippetsTestsCommon { + virtual public ov::test::SnippetsTestsCommon { public: static std::string getTestCaseName(testing::TestParamInfo obj); @@ -42,6 +43,10 @@ protected: void SetUp() override; }; +class MHAWOTranspose : public MHA { + void SetUp() override; +}; + } // namespace snippets } // namespace test } // namespace ov diff --git a/src/tests/functional/plugin/shared/src/snippets/mha.cpp b/src/tests/functional/plugin/shared/src/snippets/mha.cpp index cf0075906c0..714038ac726 100644 --- a/src/tests/functional/plugin/shared/src/snippets/mha.cpp +++ b/src/tests/functional/plugin/shared/src/snippets/mha.cpp @@ -16,33 +16,50 @@ namespace snippets { std::string MHA::getTestCaseName(testing::TestParamInfo obj) { std::vector inputShapes; bool withMul; + ov::element::Type prc; std::string targetDevice; size_t num_nodes, num_subgraphs; - std::tie(inputShapes, withMul, num_nodes, num_subgraphs, targetDevice) = obj.param; + std::map additionalConfig; + std::tie(inputShapes, withMul, prc, num_nodes, num_subgraphs, targetDevice, additionalConfig) = obj.param; std::ostringstream result; for (size_t i = 0; i < inputShapes.size(); ++i) result << "IS[" << i << "]=" << CommonTestUtils::partialShape2str({inputShapes[i]}) << "_"; result << "Mul=" << withMul << "_"; + result << "PRC=" << prc << "_"; result << "#N=" << num_nodes << "_"; result << "#S=" << num_subgraphs << "_"; - result << "targetDevice=" << targetDevice; + result << "targetDevice=" << targetDevice << "_"; + + if (!additionalConfig.empty()) { + result << "_PluginConf"; + for (auto &item : additionalConfig) { + if (item.second == InferenceEngine::PluginConfigParams::YES) + result << "_" << item.first << "=" << item.second; + } + } return result.str(); } void MHA::SetUp() { std::vector inputShapes; bool withMul; - std::tie(inputShapes, withMul, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); + ov::element::Type prc; + std::map additionalConfig; + std::tie(inputShapes, withMul, prc, ref_num_nodes, ref_num_subgraphs, targetDevice, additionalConfig) = this->GetParam(); init_input_shapes(static_partial_shapes_to_test_representation(inputShapes)); auto f = ov::test::snippets::MHAFunction(inputDynamicShapes, withMul); function = f.getOriginal(); + configuration.insert(additionalConfig.begin(), additionalConfig.end()); if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE, InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK}); } + + setInferenceType(prc); + inType = outType = prc; } void MHA::generate_inputs(const std::vector& targetInputStaticShapes) { @@ -59,16 +76,22 @@ void MHA::generate_inputs(const std::vector& targetInputStaticSha void MHASelect::SetUp() { std::vector inputShapes; bool withMul; - std::tie(inputShapes, withMul, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); + ov::element::Type prc; + std::map additionalConfig; + std::tie(inputShapes, withMul, prc, ref_num_nodes, ref_num_subgraphs, targetDevice, additionalConfig) = this->GetParam(); init_input_shapes(static_partial_shapes_to_test_representation(inputShapes)); auto f = ov::test::snippets::MHASelectFunction(inputDynamicShapes); function = f.getOriginal(); + configuration.insert(additionalConfig.begin(), additionalConfig.end()); if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE, InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK}); } + + setInferenceType(prc); + inType = outType = prc; } void MHASelect::generate_inputs(const std::vector& targetInputStaticShapes) { @@ -92,16 +115,41 @@ void MHASelect::generate_inputs(const std::vector& targetInputSta void MHAWOTransposeOnInputs::SetUp() { std::vector inputShapes; bool withMul; - std::tie(inputShapes, withMul, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); + ov::element::Type prc; + std::map additionalConfig; + std::tie(inputShapes, withMul, prc, ref_num_nodes, ref_num_subgraphs, targetDevice, additionalConfig) = this->GetParam(); init_input_shapes(static_partial_shapes_to_test_representation(inputShapes)); auto f = ov::test::snippets::MHAWOTransposeOnInputsFunction(inputDynamicShapes); function = f.getOriginal(); + configuration.insert(additionalConfig.begin(), additionalConfig.end()); if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE, InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK}); } + + setInferenceType(prc); + inType = outType = prc; +} + +void MHAWOTranspose::SetUp() { + std::vector inputShapes; + bool withMul; + ov::element::Type prc; + std::map additionalConfig; + std::tie(inputShapes, withMul, prc, ref_num_nodes, ref_num_subgraphs, targetDevice, additionalConfig) = this->GetParam(); + init_input_shapes(static_partial_shapes_to_test_representation(inputShapes)); + + auto f = ov::test::snippets::MHAWOTransposeFunction(inputDynamicShapes); + function = f.getOriginal(); + + configuration.insert(additionalConfig.begin(), additionalConfig.end()); + + setInferenceType(prc); + inType = outType = prc; + if (prc == ov::element::bf16) + abs_threshold = 0.3; } @@ -120,6 +168,11 @@ TEST_P(MHAWOTransposeOnInputs, CompareWithRefImpl) { validateNumSubgraphs(); } +TEST_P(MHAWOTranspose, CompareWithRefImpl) { + run(); + validateNumSubgraphs(); +} + } // namespace snippets } // namespace test diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_mha.hpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_mha.hpp index 309a32e9145..5f4ceebf599 100644 --- a/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_mha.hpp +++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_mha.hpp @@ -126,6 +126,24 @@ protected: std::shared_ptr initOriginal() const override; }; +/* Graph: + * \ / + * MatMul0 + * | + * Softmax + * \ / + * MatMul1 + * | + */ +class MHAWOTransposeFunction : public SnippetsFunctionBase { +public: + explicit MHAWOTransposeFunction(const std::vector& inputShapes) : SnippetsFunctionBase(inputShapes) { + NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes"); + } +protected: + std::shared_ptr initOriginal() const override; +}; + } // namespace snippets } // namespace test } // namespace ov diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_mha.cpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_mha.cpp index ac38ea47624..b3b3c8d0f9b 100644 --- a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_mha.cpp +++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_mha.cpp @@ -343,6 +343,22 @@ std::shared_ptr MHAWOTransposeOnInputsFunction::initOriginal() const return std::make_shared(results, ngraphParam, "mha"); } +std::shared_ptr MHAWOTransposeFunction::initOriginal() const { + auto param0 = std::make_shared(precision, input_shapes[0]); + auto param1 = std::make_shared(precision, input_shapes[1]); + auto param2 = std::make_shared(precision, input_shapes[2]); + ngraph::ParameterVector ngraphParam = {param0, param1, param2}; + + float transA = false; + float transB = false; + const auto matMul0 = std::make_shared(param0, param1, transA, transB); + const auto softmax = std::make_shared(matMul0, 3); + const auto matMul1 = std::make_shared(softmax, param2, transA, transB); + + ngraph::ResultVector results{std::make_shared(matMul1)}; + return std::make_shared(results, ngraphParam, "mha"); +} + } // namespace snippets } // namespace test } // namespace ov