[CPU] Fixed enforcebf16 condition for transformation pipeline (#17157)

* [CPU] Fixed enforcebf16 condition for transformation pipeline

* [Snippets][CPU][Tests] Added test with bf16
This commit is contained in:
Alexandra Sidorova
2023-04-26 16:13:01 +04:00
committed by GitHub
parent ca92eb96ad
commit a032d67cc7
6 changed files with 139 additions and 23 deletions

View File

@@ -382,15 +382,18 @@ StreamCfg Engine::GetNumStreams(InferenceEngine::IStreamsExecutor::ThreadBinding
}
// Decides whether BF16 inference should be enforced for a compiled model.
// Priority: model-level legacy KEY_ENFORCE_BF16 > model-level ov::hint::inference_precision
// > engine-wide setting. In all cases BF16 additionally requires AVX512-core support.
static bool shouldEnforceBF16(const std::map<std::string, std::string>& modelConfig, const Config& engineConfig) {
    // For BF16 execution, the machine should have AVX512 at least
    if (!dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core))
        return false;

    const auto& enforceBF16 = modelConfig.find(InferenceEngine::PluginConfigParams::KEY_ENFORCE_BF16);
    const auto& inferPrec = modelConfig.find(ov::hint::inference_precision.name());
    if (enforceBF16 != modelConfig.end()) {
        // Explicit legacy key on the model has the highest priority
        return enforceBF16->second == PluginConfigParams::YES;
    } else if (inferPrec != modelConfig.end()) {
        return inferPrec->second == "bf16";
    } else {
        // Nothing set on the model: fall back to the engine-wide configuration
        return engineConfig.enforceBF16;
    }
}

View File

@@ -4,6 +4,7 @@
#include "snippets/mha.hpp"
#include "common_test_utils/test_constants.hpp"
#include "ie_plugin_config.hpp"
namespace ov {
namespace test {
@@ -24,9 +25,11 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA, MHA,
::testing::Combine(
::testing::ValuesIn(inputShapes),
::testing::ValuesIn({false, true}),
::testing::Values(ov::element::f32),
::testing::Values(1),
::testing::Values(1),
::testing::Values(CommonTestUtils::DEVICE_CPU)),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(std::map<std::string, std::string>{})),
MHA::getTestCaseName);
const std::vector<std::vector<ov::PartialShape>> inputShapeSelect = {
@@ -42,9 +45,11 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA, MHASelect,
::testing::Combine(
::testing::ValuesIn(inputShapeSelect),
::testing::Values(false), // Need to support True for graph builder in tests
::testing::Values(ov::element::f32),
::testing::Values(2), // Less + MHA
::testing::Values(2),
::testing::Values(CommonTestUtils::DEVICE_CPU)),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(std::map<std::string, std::string>{})),
MHA::getTestCaseName);
const std::vector<std::vector<ov::PartialShape>> inputShapesWOTranspose = {
@@ -55,9 +60,25 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAWOTransposeOnInputs, MHAWOTransposeOn
::testing::Combine(
::testing::ValuesIn(inputShapesWOTranspose),
::testing::ValuesIn({true}), // Need to support False for graph builder in tests
::testing::Values(ov::element::f32),
::testing::Values(1),
::testing::Values(1),
::testing::Values(CommonTestUtils::DEVICE_CPU)),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(std::map<std::string, std::string>{})),
MHA::getTestCaseName);
// Plugin config that explicitly enables BF16 enforcement for the model.
const std::map<std::string, std::string> cpuBF16PluginConfig = { { InferenceEngine::PluginConfigParams::KEY_ENFORCE_BF16,
InferenceEngine::PluginConfigParams::YES } };
// BF16 variant of the transpose-free MHA test: expects 3 nodes and zero Snippets
// subgraphs because the pattern is not tokenized on bf16 (see comment below).
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHABF16, MHAWOTranspose,
::testing::Combine(
::testing::ValuesIn(inputShapesWOTranspose),
::testing::ValuesIn({true}), // Need to support False for graph builder in tests
::testing::Values(ov::element::bf16),
::testing::Values(3),
::testing::Values(0), // CPU plugin doesn't support MHA pattern via Snippets on bf16
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(cpuBF16PluginConfig)),
MHA::getTestCaseName);

View File

@@ -11,16 +11,17 @@ namespace test {
namespace snippets {
// Parameter tuple shared by all MHA snippets tests.
typedef std::tuple<
        std::vector<ov::PartialShape>,      // Input shapes
        bool,                               // With Multiply
        ov::element::Type,                  // Inference precision
        size_t,                             // Expected num nodes
        size_t,                             // Expected num subgraphs
        std::string,                        // Target Device
        std::map<std::string, std::string>  // Config
> MHAParams;
class MHA : public testing::WithParamInterface<ov::test::snippets::MHAParams>,
virtual public ov::test::SnippetsTestsCommon {
virtual public ov::test::SnippetsTestsCommon {
public:
static std::string getTestCaseName(testing::TestParamInfo<ov::test::snippets::MHAParams> obj);
@@ -42,6 +43,10 @@ protected:
void SetUp() override;
};
// MHA test fixture for the transpose-free pattern; builds the model via
// MHAWOTransposeFunction instead of MHAFunction.
class MHAWOTranspose : public MHA {
void SetUp() override;
};
} // namespace snippets
} // namespace test
} // namespace ov

View File

@@ -16,33 +16,50 @@ namespace snippets {
// Builds a human-readable test-case name from the MHAParams tuple:
// input shapes, Multiply flag, precision, expected node/subgraph counts,
// target device, and any explicitly enabled plugin-config entries.
std::string MHA::getTestCaseName(testing::TestParamInfo<ov::test::snippets::MHAParams> obj) {
    std::vector<ov::PartialShape> inputShapes;
    bool withMul;
    ov::element::Type prc;
    std::string targetDevice;
    size_t num_nodes, num_subgraphs;
    std::map<std::string, std::string> additionalConfig;
    std::tie(inputShapes, withMul, prc, num_nodes, num_subgraphs, targetDevice, additionalConfig) = obj.param;

    std::ostringstream result;
    for (size_t i = 0; i < inputShapes.size(); ++i)
        result << "IS[" << i << "]=" << CommonTestUtils::partialShape2str({inputShapes[i]}) << "_";
    result << "Mul=" << withMul << "_";
    result << "PRC=" << prc << "_";
    result << "#N=" << num_nodes << "_";
    result << "#S=" << num_subgraphs << "_";
    result << "targetDevice=" << targetDevice << "_";
    // Only config entries explicitly set to YES contribute to the name
    if (!additionalConfig.empty()) {
        result << "_PluginConf";
        for (auto &item : additionalConfig) {
            if (item.second == InferenceEngine::PluginConfigParams::YES)
                result << "_" << item.first << "=" << item.second;
        }
    }
    return result.str();
}
// Unpacks the test parameters, builds the MHA model, merges the extra plugin
// config, and forces Snippets tokenization unless the test overrides the mode.
void MHA::SetUp() {
    std::vector<ov::PartialShape> inputShapes;
    bool withMul;
    ov::element::Type prc;
    std::map<std::string, std::string> additionalConfig;
    std::tie(inputShapes, withMul, prc, ref_num_nodes, ref_num_subgraphs, targetDevice, additionalConfig) = this->GetParam();
    init_input_shapes(static_partial_shapes_to_test_representation(inputShapes));

    auto f = ov::test::snippets::MHAFunction(inputDynamicShapes, withMul);
    function = f.getOriginal();

    configuration.insert(additionalConfig.begin(), additionalConfig.end());
    // Ignore the plugin's tokenization callback unless the test set the mode itself
    if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) {
        configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE,
                              InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK});
    }

    setInferenceType(prc);
    inType = outType = prc;
}
void MHA::generate_inputs(const std::vector<ngraph::Shape>& targetInputStaticShapes) {
@@ -59,16 +76,22 @@ void MHA::generate_inputs(const std::vector<ngraph::Shape>& targetInputStaticSha
// Same flow as MHA::SetUp but builds the MHA-with-Select model
// (MHASelectFunction ignores the Multiply flag).
void MHASelect::SetUp() {
    std::vector<ov::PartialShape> inputShapes;
    bool withMul;
    ov::element::Type prc;
    std::map<std::string, std::string> additionalConfig;
    std::tie(inputShapes, withMul, prc, ref_num_nodes, ref_num_subgraphs, targetDevice, additionalConfig) = this->GetParam();
    init_input_shapes(static_partial_shapes_to_test_representation(inputShapes));

    auto f = ov::test::snippets::MHASelectFunction(inputDynamicShapes);
    function = f.getOriginal();

    configuration.insert(additionalConfig.begin(), additionalConfig.end());
    // Ignore the plugin's tokenization callback unless the test set the mode itself
    if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) {
        configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE,
                              InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK});
    }

    setInferenceType(prc);
    inType = outType = prc;
}
void MHASelect::generate_inputs(const std::vector<ngraph::Shape>& targetInputStaticShapes) {
@@ -92,16 +115,41 @@ void MHASelect::generate_inputs(const std::vector<ngraph::Shape>& targetInputSta
// Same flow as MHA::SetUp but builds the model without Transposes on inputs
// (MHAWOTransposeOnInputsFunction ignores the Multiply flag).
void MHAWOTransposeOnInputs::SetUp() {
    std::vector<ov::PartialShape> inputShapes;
    bool withMul;
    ov::element::Type prc;
    std::map<std::string, std::string> additionalConfig;
    std::tie(inputShapes, withMul, prc, ref_num_nodes, ref_num_subgraphs, targetDevice, additionalConfig) = this->GetParam();
    init_input_shapes(static_partial_shapes_to_test_representation(inputShapes));

    auto f = ov::test::snippets::MHAWOTransposeOnInputsFunction(inputDynamicShapes);
    function = f.getOriginal();

    configuration.insert(additionalConfig.begin(), additionalConfig.end());
    // Ignore the plugin's tokenization callback unless the test set the mode itself
    if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) {
        configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE,
                              InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK});
    }

    setInferenceType(prc);
    inType = outType = prc;
}
void MHAWOTranspose::SetUp() {
std::vector<ov::PartialShape> inputShapes;
bool withMul;
ov::element::Type prc;
std::map<std::string, std::string> additionalConfig;
std::tie(inputShapes, withMul, prc, ref_num_nodes, ref_num_subgraphs, targetDevice, additionalConfig) = this->GetParam();
init_input_shapes(static_partial_shapes_to_test_representation(inputShapes));
auto f = ov::test::snippets::MHAWOTransposeFunction(inputDynamicShapes);
function = f.getOriginal();
configuration.insert(additionalConfig.begin(), additionalConfig.end());
setInferenceType(prc);
inType = outType = prc;
if (prc == ov::element::bf16)
abs_threshold = 0.3;
}
@@ -120,6 +168,11 @@ TEST_P(MHAWOTransposeOnInputs, CompareWithRefImpl) {
validateNumSubgraphs();
}
// Runs inference on the transpose-free MHA model, compares against the
// reference, and validates the expected number of Snippets subgraphs.
TEST_P(MHAWOTranspose, CompareWithRefImpl) {
run();
validateNumSubgraphs();
}
} // namespace snippets
} // namespace test

View File

@@ -126,6 +126,24 @@ protected:
std::shared_ptr<ov::Model> initOriginal() const override;
};
/* Graph:
* \ /
* MatMul0
* |
* Softmax
* \ /
* MatMul1
* |
*/
// Builder for the transpose-free MHA pattern (MatMul0 -> Softmax -> MatMul1).
// Requires exactly three input shapes: the first two feed MatMul0, the third
// is the second input of MatMul1 (see initOriginal()).
class MHAWOTransposeFunction : public SnippetsFunctionBase {
public:
explicit MHAWOTransposeFunction(const std::vector<PartialShape>& inputShapes) : SnippetsFunctionBase(inputShapes) {
NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes");
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;
};
} // namespace snippets
} // namespace test
} // namespace ov

View File

@@ -343,6 +343,22 @@ std::shared_ptr<ov::Model> MHAWOTransposeOnInputsFunction::initOriginal() const
return std::make_shared<ov::Model>(results, ngraphParam, "mha");
}
// Builds the original transpose-free MHA model:
//   MatMul0(param0, param1) -> Softmax(axis=3) -> MatMul1(softmax, param2)
std::shared_ptr<ov::Model> MHAWOTransposeFunction::initOriginal() const {
    auto param0 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
    auto param1 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
    auto param2 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
    ngraph::ParameterVector ngraphParam = {param0, param1, param2};

    // MatMul transpose flags are booleans; declaring them as float (as the
    // original did) only worked via implicit conversion.
    bool transA = false;
    bool transB = false;
    const auto matMul0 = std::make_shared<ngraph::opset3::MatMul>(param0, param1, transA, transB);
    const auto softmax = std::make_shared<ngraph::opset1::Softmax>(matMul0, 3);
    const auto matMul1 = std::make_shared<ngraph::opset3::MatMul>(softmax, param2, transA, transB);

    ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(matMul1)};
    return std::make_shared<ov::Model>(results, ngraphParam, "mha");
}
} // namespace snippets
} // namespace test
} // namespace ov