[CPU] Fixed enforcebf16 condition for transformation pipeline (#17157)
* [CPU] Fixed enforcebf16 condition for transformation pipeline * [Snippets][CPU][Tests] Added test with bf16
This commit is contained in:
committed by
GitHub
parent
ca92eb96ad
commit
a032d67cc7
@@ -382,15 +382,18 @@ StreamCfg Engine::GetNumStreams(InferenceEngine::IStreamsExecutor::ThreadBinding
|
||||
}
|
||||
|
||||
static bool shouldEnforceBF16(const std::map<std::string, std::string>& modelConfig, const Config& engineConfig) {
|
||||
const auto& enforceBF16 = modelConfig.find(InferenceEngine::PluginConfigParams::KEY_ENFORCE_BF16);
|
||||
if (enforceBF16 == modelConfig.end()) { // not set for the model
|
||||
return engineConfig.enforceBF16 && dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core); // use value from engine
|
||||
}
|
||||
|
||||
if (enforceBF16->second == PluginConfigParams::YES) {
|
||||
return dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core);
|
||||
} else {
|
||||
// For BF16 execution, the machine should have AVX512 at least
|
||||
if (!dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core))
|
||||
return false;
|
||||
|
||||
const auto& enforceBF16 = modelConfig.find(InferenceEngine::PluginConfigParams::KEY_ENFORCE_BF16);
|
||||
const auto& inferPrec = modelConfig.find(ov::hint::inference_precision.name());
|
||||
if (enforceBF16 != modelConfig.end()) {
|
||||
return enforceBF16->second == PluginConfigParams::YES;
|
||||
} else if (inferPrec != modelConfig.end()) {
|
||||
return inferPrec->second == "bf16";
|
||||
} else {
|
||||
return engineConfig.enforceBF16;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
#include "snippets/mha.hpp"
|
||||
#include "common_test_utils/test_constants.hpp"
|
||||
#include "ie_plugin_config.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace test {
|
||||
@@ -24,9 +25,11 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA, MHA,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(inputShapes),
|
||||
::testing::ValuesIn({false, true}),
|
||||
::testing::Values(ov::element::f32),
|
||||
::testing::Values(1),
|
||||
::testing::Values(1),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU)),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU),
|
||||
::testing::Values(std::map<std::string, std::string>{})),
|
||||
MHA::getTestCaseName);
|
||||
|
||||
const std::vector<std::vector<ov::PartialShape>> inputShapeSelect = {
|
||||
@@ -42,9 +45,11 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA, MHASelect,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(inputShapeSelect),
|
||||
::testing::Values(false), // Need to support True for graph builder in tests
|
||||
::testing::Values(ov::element::f32),
|
||||
::testing::Values(2), // Less + MHA
|
||||
::testing::Values(2),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU)),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU),
|
||||
::testing::Values(std::map<std::string, std::string>{})),
|
||||
MHA::getTestCaseName);
|
||||
|
||||
const std::vector<std::vector<ov::PartialShape>> inputShapesWOTranspose = {
|
||||
@@ -55,9 +60,25 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAWOTransposeOnInputs, MHAWOTransposeOn
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(inputShapesWOTranspose),
|
||||
::testing::ValuesIn({true}), // Need to support False for graph builder in tests
|
||||
::testing::Values(ov::element::f32),
|
||||
::testing::Values(1),
|
||||
::testing::Values(1),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU)),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU),
|
||||
::testing::Values(std::map<std::string, std::string>{})),
|
||||
MHA::getTestCaseName);
|
||||
|
||||
const std::map<std::string, std::string> cpuBF16PluginConfig = { { InferenceEngine::PluginConfigParams::KEY_ENFORCE_BF16,
|
||||
InferenceEngine::PluginConfigParams::YES } };
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHABF16, MHAWOTranspose,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(inputShapesWOTranspose),
|
||||
::testing::ValuesIn({true}), // Need to support False for graph builder in tests
|
||||
::testing::Values(ov::element::bf16),
|
||||
::testing::Values(3),
|
||||
::testing::Values(0), // CPU plugin doesn't support MHA pattern via Snippets on bf16
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU),
|
||||
::testing::Values(cpuBF16PluginConfig)),
|
||||
MHA::getTestCaseName);
|
||||
|
||||
|
||||
|
||||
@@ -11,16 +11,17 @@ namespace test {
|
||||
namespace snippets {
|
||||
|
||||
typedef std::tuple<
|
||||
std::vector<ov::PartialShape>, // Input shapes
|
||||
bool, // With Multiply
|
||||
size_t, // Expected num nodes
|
||||
size_t, // Expected num subgraphs
|
||||
std::string // Target Device
|
||||
std::vector<ov::PartialShape>, // Input shapes
|
||||
bool, // With Multiply
|
||||
ov::element::Type, // Inference precision
|
||||
size_t, // Expected num nodes
|
||||
size_t, // Expected num subgraphs
|
||||
std::string, // Target Device
|
||||
std::map<std::string, std::string> // Config
|
||||
> MHAParams;
|
||||
|
||||
|
||||
class MHA : public testing::WithParamInterface<ov::test::snippets::MHAParams>,
|
||||
virtual public ov::test::SnippetsTestsCommon {
|
||||
virtual public ov::test::SnippetsTestsCommon {
|
||||
public:
|
||||
static std::string getTestCaseName(testing::TestParamInfo<ov::test::snippets::MHAParams> obj);
|
||||
|
||||
@@ -42,6 +43,10 @@ protected:
|
||||
void SetUp() override;
|
||||
};
|
||||
|
||||
class MHAWOTranspose : public MHA {
|
||||
void SetUp() override;
|
||||
};
|
||||
|
||||
} // namespace snippets
|
||||
} // namespace test
|
||||
} // namespace ov
|
||||
|
||||
@@ -16,33 +16,50 @@ namespace snippets {
|
||||
std::string MHA::getTestCaseName(testing::TestParamInfo<ov::test::snippets::MHAParams> obj) {
|
||||
std::vector<ov::PartialShape> inputShapes;
|
||||
bool withMul;
|
||||
ov::element::Type prc;
|
||||
std::string targetDevice;
|
||||
size_t num_nodes, num_subgraphs;
|
||||
std::tie(inputShapes, withMul, num_nodes, num_subgraphs, targetDevice) = obj.param;
|
||||
std::map<std::string, std::string> additionalConfig;
|
||||
std::tie(inputShapes, withMul, prc, num_nodes, num_subgraphs, targetDevice, additionalConfig) = obj.param;
|
||||
|
||||
std::ostringstream result;
|
||||
for (size_t i = 0; i < inputShapes.size(); ++i)
|
||||
result << "IS[" << i << "]=" << CommonTestUtils::partialShape2str({inputShapes[i]}) << "_";
|
||||
result << "Mul=" << withMul << "_";
|
||||
result << "PRC=" << prc << "_";
|
||||
result << "#N=" << num_nodes << "_";
|
||||
result << "#S=" << num_subgraphs << "_";
|
||||
result << "targetDevice=" << targetDevice;
|
||||
result << "targetDevice=" << targetDevice << "_";
|
||||
|
||||
if (!additionalConfig.empty()) {
|
||||
result << "_PluginConf";
|
||||
for (auto &item : additionalConfig) {
|
||||
if (item.second == InferenceEngine::PluginConfigParams::YES)
|
||||
result << "_" << item.first << "=" << item.second;
|
||||
}
|
||||
}
|
||||
return result.str();
|
||||
}
|
||||
|
||||
void MHA::SetUp() {
|
||||
std::vector<ov::PartialShape> inputShapes;
|
||||
bool withMul;
|
||||
std::tie(inputShapes, withMul, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam();
|
||||
ov::element::Type prc;
|
||||
std::map<std::string, std::string> additionalConfig;
|
||||
std::tie(inputShapes, withMul, prc, ref_num_nodes, ref_num_subgraphs, targetDevice, additionalConfig) = this->GetParam();
|
||||
init_input_shapes(static_partial_shapes_to_test_representation(inputShapes));
|
||||
|
||||
auto f = ov::test::snippets::MHAFunction(inputDynamicShapes, withMul);
|
||||
function = f.getOriginal();
|
||||
|
||||
configuration.insert(additionalConfig.begin(), additionalConfig.end());
|
||||
if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) {
|
||||
configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE,
|
||||
InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK});
|
||||
}
|
||||
|
||||
setInferenceType(prc);
|
||||
inType = outType = prc;
|
||||
}
|
||||
|
||||
void MHA::generate_inputs(const std::vector<ngraph::Shape>& targetInputStaticShapes) {
|
||||
@@ -59,16 +76,22 @@ void MHA::generate_inputs(const std::vector<ngraph::Shape>& targetInputStaticSha
|
||||
void MHASelect::SetUp() {
|
||||
std::vector<ov::PartialShape> inputShapes;
|
||||
bool withMul;
|
||||
std::tie(inputShapes, withMul, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam();
|
||||
ov::element::Type prc;
|
||||
std::map<std::string, std::string> additionalConfig;
|
||||
std::tie(inputShapes, withMul, prc, ref_num_nodes, ref_num_subgraphs, targetDevice, additionalConfig) = this->GetParam();
|
||||
init_input_shapes(static_partial_shapes_to_test_representation(inputShapes));
|
||||
|
||||
auto f = ov::test::snippets::MHASelectFunction(inputDynamicShapes);
|
||||
function = f.getOriginal();
|
||||
|
||||
configuration.insert(additionalConfig.begin(), additionalConfig.end());
|
||||
if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) {
|
||||
configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE,
|
||||
InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK});
|
||||
}
|
||||
|
||||
setInferenceType(prc);
|
||||
inType = outType = prc;
|
||||
}
|
||||
|
||||
void MHASelect::generate_inputs(const std::vector<ngraph::Shape>& targetInputStaticShapes) {
|
||||
@@ -92,16 +115,41 @@ void MHASelect::generate_inputs(const std::vector<ngraph::Shape>& targetInputSta
|
||||
void MHAWOTransposeOnInputs::SetUp() {
|
||||
std::vector<ov::PartialShape> inputShapes;
|
||||
bool withMul;
|
||||
std::tie(inputShapes, withMul, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam();
|
||||
ov::element::Type prc;
|
||||
std::map<std::string, std::string> additionalConfig;
|
||||
std::tie(inputShapes, withMul, prc, ref_num_nodes, ref_num_subgraphs, targetDevice, additionalConfig) = this->GetParam();
|
||||
init_input_shapes(static_partial_shapes_to_test_representation(inputShapes));
|
||||
|
||||
auto f = ov::test::snippets::MHAWOTransposeOnInputsFunction(inputDynamicShapes);
|
||||
function = f.getOriginal();
|
||||
|
||||
configuration.insert(additionalConfig.begin(), additionalConfig.end());
|
||||
if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) {
|
||||
configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE,
|
||||
InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK});
|
||||
}
|
||||
|
||||
setInferenceType(prc);
|
||||
inType = outType = prc;
|
||||
}
|
||||
|
||||
void MHAWOTranspose::SetUp() {
|
||||
std::vector<ov::PartialShape> inputShapes;
|
||||
bool withMul;
|
||||
ov::element::Type prc;
|
||||
std::map<std::string, std::string> additionalConfig;
|
||||
std::tie(inputShapes, withMul, prc, ref_num_nodes, ref_num_subgraphs, targetDevice, additionalConfig) = this->GetParam();
|
||||
init_input_shapes(static_partial_shapes_to_test_representation(inputShapes));
|
||||
|
||||
auto f = ov::test::snippets::MHAWOTransposeFunction(inputDynamicShapes);
|
||||
function = f.getOriginal();
|
||||
|
||||
configuration.insert(additionalConfig.begin(), additionalConfig.end());
|
||||
|
||||
setInferenceType(prc);
|
||||
inType = outType = prc;
|
||||
if (prc == ov::element::bf16)
|
||||
abs_threshold = 0.3;
|
||||
}
|
||||
|
||||
|
||||
@@ -120,6 +168,11 @@ TEST_P(MHAWOTransposeOnInputs, CompareWithRefImpl) {
|
||||
validateNumSubgraphs();
|
||||
}
|
||||
|
||||
TEST_P(MHAWOTranspose, CompareWithRefImpl) {
|
||||
run();
|
||||
validateNumSubgraphs();
|
||||
}
|
||||
|
||||
|
||||
} // namespace snippets
|
||||
} // namespace test
|
||||
|
||||
@@ -126,6 +126,24 @@ protected:
|
||||
std::shared_ptr<ov::Model> initOriginal() const override;
|
||||
};
|
||||
|
||||
/* Graph:
|
||||
* \ /
|
||||
* MatMul0
|
||||
* |
|
||||
* Softmax
|
||||
* \ /
|
||||
* MatMul1
|
||||
* |
|
||||
*/
|
||||
class MHAWOTransposeFunction : public SnippetsFunctionBase {
|
||||
public:
|
||||
explicit MHAWOTransposeFunction(const std::vector<PartialShape>& inputShapes) : SnippetsFunctionBase(inputShapes) {
|
||||
NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes");
|
||||
}
|
||||
protected:
|
||||
std::shared_ptr<ov::Model> initOriginal() const override;
|
||||
};
|
||||
|
||||
} // namespace snippets
|
||||
} // namespace test
|
||||
} // namespace ov
|
||||
|
||||
@@ -343,6 +343,22 @@ std::shared_ptr<ov::Model> MHAWOTransposeOnInputsFunction::initOriginal() const
|
||||
return std::make_shared<ov::Model>(results, ngraphParam, "mha");
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Model> MHAWOTransposeFunction::initOriginal() const {
|
||||
auto param0 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
|
||||
auto param1 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
|
||||
auto param2 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
|
||||
ngraph::ParameterVector ngraphParam = {param0, param1, param2};
|
||||
|
||||
float transA = false;
|
||||
float transB = false;
|
||||
const auto matMul0 = std::make_shared<ngraph::opset3::MatMul>(param0, param1, transA, transB);
|
||||
const auto softmax = std::make_shared<ngraph::opset1::Softmax>(matMul0, 3);
|
||||
const auto matMul1 = std::make_shared<ngraph::opset3::MatMul>(softmax, param2, transA, transB);
|
||||
|
||||
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(matMul1)};
|
||||
return std::make_shared<ov::Model>(results, ngraphParam, "mha");
|
||||
}
|
||||
|
||||
} // namespace snippets
|
||||
} // namespace test
|
||||
} // namespace ov
|
||||
|
||||
Reference in New Issue
Block a user