[CPU Tests] migrate matmul test cases to be api 2.0 (#21332)

* [CPU Tests] migrate matmul test cases to be api 2.0

* Update

* Handle convert2OutputVector inplace

---------

Co-authored-by: Vitaliy Urusovskij <vitaliy.urusovskij@intel.com>
River Li 2023-12-05 22:42:47 +08:00 committed by GitHub
parent 65b8bdf892
commit bd315f4b6a
6 changed files with 232 additions and 240 deletions

View File

@@ -3,14 +3,11 @@
 //
 #include "functional_test_utils/skip_tests_config.hpp"
-#include <ie_system_conf.h>
+#include "openvino/runtime/system_conf.hpp"
 
 #include <string>
 #include <vector>
 
-#include "ie_parallel.hpp"
-
 std::vector<std::string> disabledTestPatterns() {
     std::vector<std::string> retVector{
         // TODO: Issue 31841
@@ -314,7 +311,7 @@ std::vector<std::string> disabledTestPatterns() {
     retVector.emplace_back(R"(.*LoadNetworkCompiledKernelsCacheTest.*CanCreateCacheDirAndDumpBinariesUnicodePath.*)");
 #endif
-    if (!InferenceEngine::with_cpu_x86_avx512_core()) {
+    if (!ov::with_cpu_x86_avx512_core()) {
         // on platforms which do not support bfloat16, we are disabling bf16 tests since there are no bf16 primitives,
         // tests are useless on such platforms
         retVector.emplace_back(R"(.*(BF|bf)16.*)");
@@ -325,7 +322,7 @@ std::vector<std::string> disabledTestPatterns() {
         retVector.emplace_back(R"(.*Snippets.*(MatMul|Matmul).*)");
     }
 #if defined(OPENVINO_ARCH_X86) || defined(OPENVINO_ARCH_X86_64)
-    if (!InferenceEngine::with_cpu_x86_avx512_core_fp16()) {
+    if (!ov::with_cpu_x86_avx512_core_fp16()) {
         // Skip fp16 tests for platforms that don't support fp16 precision
         retVector.emplace_back(R"(.*INFERENCE_PRECISION_HINT=(F|f)16.*)");
     }
@@ -339,7 +336,7 @@ std::vector<std::string> disabledTestPatterns() {
             R"(.*EltwiseLayerCPUTest.*IS=\(\[1\.\.10\.2\.5\.6\]_\).*eltwiseOpType=SqDiff.*_configItem=INFERENCE_PRECISION_HINT=f16.*)");
 #    endif  // OV_CPU_ARM_ENABLE_FP16
 #endif
-    if (!InferenceEngine::with_cpu_x86_avx512_core_vnni() && !InferenceEngine::with_cpu_x86_avx512_core_amx_int8()) {
+    if (!ov::with_cpu_x86_avx512_core_vnni() && !ov::with_cpu_x86_avx512_core_amx_int8()) {
         // MatMul in Snippets uses BRGEMM that supports i8 only on platforms with VNNI or AMX instructions
         retVector.emplace_back(R"(.*Snippets.*MatMulFQ.*)");
         retVector.emplace_back(R"(.*Snippets.*MatMul.*Quantized.*)");
@@ -347,11 +344,11 @@ std::vector<std::string> disabledTestPatterns() {
         retVector.emplace_back(R"(.*Snippets.*MHAINT8.*)");
         retVector.emplace_back(R"(.*Snippets.*MHAQuant.*)");
     }
-    if (!InferenceEngine::with_cpu_x86_avx512_core_amx_int8())
+    if (!ov::with_cpu_x86_avx512_core_amx_int8())
         // TODO: Issue 92895
         // on platforms which do not support AMX, we are disabling I8 input tests
         retVector.emplace_back(R"(smoke_LPT/FakeQuantizeWithNotOptimalTransformation.CompareWithRefImpl.*CPU.*i8.*)");
-    if (!InferenceEngine::with_cpu_x86_avx512_core_amx_bf16() && !InferenceEngine::with_cpu_x86_bfloat16()) {
+    if (!ov::with_cpu_x86_avx512_core_amx_bf16() && !ov::with_cpu_x86_bfloat16()) {
         // ignored for not supported bf16 platforms
         retVector.emplace_back(R"(.*smoke_Snippets_EnforcePrecision_bf16.*)");
         retVector.emplace_back(R"(.*smoke_Snippets_MHAWOTransposeEnforceBF16.*)");

View File

@@ -2,17 +2,15 @@
 // SPDX-License-Identifier: Apache-2.0
 //
-#include "test_utils/fusing_test_utils.hpp"
 #include "ov_models/builders.hpp"
 #include "shared_test_classes/base/ov_subgraph.hpp"
+#include "test_utils/fusing_test_utils.hpp"
 #include "transformations/rt_info/decompression.hpp"
 
-using namespace ngraph;
-using namespace InferenceEngine;
 using namespace CPUTestUtils;
-using namespace ov::test;
 
-namespace SubgraphTestsDefinitions {
+namespace ov {
+namespace test {
 
 /* This test checks MatMul weights constant folding on CPU plugin side and covers two optimizations:
    1. Decompressing Convert FP16 -> FP32 CF (FuseFCAndConvertOnWeights in cpu graph optimizer)
@@ -82,22 +80,21 @@ namespace SubgraphTestsDefinitions {
        --------
 */
-using MatMulDecompressConvertParams = std::tuple<
-        std::vector<InputShape>,            // input shapes
-        std::pair<bool, bool>,              // transposeA, transposeB
-        ElementType,                        // weights precision
-        std::map<std::string, std::string>, // additional config
-        CPUSpecificParams
->;
+using MatMulDecompressConvertParams = std::tuple<std::vector<InputShape>,  // input shapes
+                                                 std::pair<bool, bool>,    // transposeA, transposeB
+                                                 ElementType,              // weights precision
+                                                 ov::AnyMap,               // additional config
+                                                 CPUSpecificParams>;
 
 class MatMulDecompressConvertTest : public testing::WithParamInterface<MatMulDecompressConvertParams>,
-                                    virtual public SubgraphBaseTest, public CPUTestsBase {
+                                    virtual public SubgraphBaseTest,
+                                    public CPUTestsBase {
 public:
     static std::string getTestCaseName(testing::TestParamInfo<MatMulDecompressConvertParams> obj) {
         std::vector<InputShape> inputShapes;
         std::pair<bool, bool> transpose;
         ElementType weiElemType;
-        std::map<std::string, std::string> additionalConfig;
+        ov::AnyMap additionalConfig;
         CPUSpecificParams cpuParams;
 
         std::tie(inputShapes, transpose, weiElemType, additionalConfig, cpuParams) = obj.param;
@@ -124,7 +121,7 @@ public:
         result << "config=(";
         for (const auto& configEntry : additionalConfig) {
-            result << configEntry.first << ", " << configEntry.second << ":";
+            result << configEntry.first << ", " << configEntry.second.as<std::string>() << ":";
         }
         result << ")";
@@ -134,14 +131,14 @@ public:
     }
 
 protected:
-    template<typename T>
-    void transposeShape(T& shape) {
+    template <typename T>
+    void transpose_shape(T& shape) {
         OPENVINO_ASSERT(shape.size() > 1);
         std::swap(*(shape.end() - 1), *(shape.end() - 2));
     }
 
-    void CheckFCWeightsPrecision(ElementType expectedWeiElemType) const {
-        auto getExecValue = [](const ov::Node::RTMap& rtInfo, const std::string &paramName) -> std::string {
+    void check_fc_weights_precision(ElementType expectedWeiElemType) const {
+        auto getExecValue = [](const ov::Node::RTMap& rtInfo, const std::string& paramName) -> std::string {
             auto it = rtInfo.find(paramName);
             OPENVINO_ASSERT(rtInfo.end() != it);
             return it->second.as<std::string>();
@@ -149,10 +146,11 @@ protected:
         const auto execFunction = compiledModel.get_runtime_model();
         ASSERT_NE(nullptr, execFunction);
-        for (const auto &fcNode : execFunction->get_ops()) {
+        for (const auto& fcNode : execFunction->get_ops()) {
             if (getExecValue(fcNode->get_rt_info(), ExecGraphInfoSerialization::LAYER_TYPE) == "FullyConnected") {
-                const auto &constNode = fcNode->get_input_node_shared_ptr(1);
-                element::Type expectedType(getExecValue(constNode->get_rt_info(), ExecGraphInfoSerialization::OUTPUT_PRECISIONS));
+                const auto& constNode = fcNode->get_input_node_shared_ptr(1);
+                ov::element::Type expectedType(
+                    getExecValue(constNode->get_rt_info(), ov::exec_model_info::OUTPUT_PRECISIONS));
                 ASSERT_EQ(expectedType, expectedWeiElemType);
             }
         }
@@ -164,7 +162,7 @@ protected:
         std::vector<InputShape> inputShapes;
         std::pair<bool, bool> transpose;
         ElementType weiConstElemType;
-        std::map<std::string, std::string> additionalConfig;
+        ov::AnyMap additionalConfig;
         CPUSpecificParams cpuParams;
 
         std::tie(inputShapes, transpose, weiConstElemType, additionalConfig, cpuParams) = this->GetParam();
@@ -175,19 +173,21 @@ protected:
         bool transpA = transpose.first;
         bool transpB = transpose.second;
 
-        if (transpA) transposeCount++;
-        if (!transpB) transposeCount++;
+        if (transpA)
+            transposeCount++;
+        if (!transpB)
+            transposeCount++;
 
         if (transpA) {
-            transposeShape(inputDynamicShapes[0]);
+            transpose_shape(inputDynamicShapes[0]);
             for (auto& shapes : targetStaticShapes) {
-                transposeShape(shapes[0]);
+                transpose_shape(shapes[0]);
             }
         }
         if (transpB) {
-            transposeShape(inputDynamicShapes[1]);
+            transpose_shape(inputDynamicShapes[1]);
             for (auto& shapes : targetStaticShapes) {
-                transposeShape(shapes[1]);
+                transpose_shape(shapes[1]);
             }
         }
@@ -198,7 +198,8 @@ protected:
         ElementType netType = ElementType::f32;
         ElementType convertOutType = ElementType::f32;
-        if (additionalConfig[PluginConfigParams::KEY_ENFORCE_BF16] == PluginConfigParams::YES) {
+        auto it = additionalConfig.find(ov::hint::inference_precision.name());
+        if (it != additionalConfig.end() && it->second.as<ov::element::Type>() == ov::element::bf16) {
             convertOutType = inType = outType = netType = ElementType::bf16;
             weiConstElemType = (weiConstElemType != ElementType::f32) ? weiConstElemType : ElementType::bf16;
         } else {
@@ -209,9 +210,10 @@ protected:
         selectedType = makeSelectedTypeStr(selectedType, outType);
 
         ov::ParameterVector params{std::make_shared<ov::op::v0::Parameter>(inType, inShapeA)};
-        std::shared_ptr<Node> inputB = builder::makeConstant<float>(weiConstElemType, inShapeB.get_shape(), {}, true);
+        std::shared_ptr<ov::Node> inputB =
+            ngraph::builder::makeConstant<float>(weiConstElemType, inShapeB.get_shape(), {}, true);
         if (weiConstElemType == ElementType::f16) {
-            inputB = std::make_shared<opset1::Convert>(inputB, convertOutType);
+            inputB = std::make_shared<ov::op::v0::Convert>(inputB, convertOutType);
             mark_as_decompression(inputB);
         }
         expectedWeiConstElemType = weiConstElemType;
@@ -221,13 +223,13 @@ protected:
         function = CPUTestsBase::makeNgraphFunction(netType, params, matMul, cpuNodeType);
     }
 
-    void CheckExecutionGraph() {
+    void check_execution_graph() {
         CheckPluginRelatedResults(compiledModel, "FullyConnected");
         CheckNumberOfNodesWithType(compiledModel, "FullyConnected", fullyConnectedCount);
         CheckNumberOfNodesWithType(compiledModel, "Transpose", transposeCount);
         CheckNumberOfNodesWithType(compiledModel, "Convert", 0);
         CheckNumberOfNodesWithType(compiledModel, "Reorder", 0);
-        CheckFCWeightsPrecision(expectedWeiConstElemType);
+        check_fc_weights_precision(expectedWeiConstElemType);
     }
 
     size_t fullyConnectedCount = 1;
@@ -238,7 +240,7 @@ protected:
 TEST_P(MatMulDecompressConvertTest, CompareWithRefs) {
     SKIP_IF_CURRENT_TEST_IS_DISABLED();
     run();
-    CheckExecutionGraph();
+    check_execution_graph();
 }
 
 namespace {
@@ -252,41 +254,29 @@ const std::vector<std::pair<bool, bool>> transposeParams = {
 const std::vector<std::vector<InputShape>> inputShapes2D = {
     static_shapes_to_test_representation({{2, 3}, {3, 4}}),
-    {
-        {{-1, -1}, {{2, 3}, {5, 3}}},
-        {{3, 4}, {{3, 4}, {3, 4}}}
-    },
+    {{{-1, -1}, {{2, 3}, {5, 3}}}, {{3, 4}, {{3, 4}, {3, 4}}}},
 };
 
 const std::vector<std::vector<InputShape>> inputShapes3D = {
     static_shapes_to_test_representation({{2, 2, 3}, {3, 4}}),
     static_shapes_to_test_representation({{2, 3}, {1, 3, 4}}),
     static_shapes_to_test_representation({{1, 2, 3}, {1, 3, 4}}),
-    {
-        {{-1, -1, -1}, {{2, 2, 3}, {3, 5, 3}}},
-        {{3, 4}, {{3, 4}, {3, 4}}}
-    },
-    {
-        {{-1, -1}, {{2, 3}, {5, 3}}},
-        {{1, 3, 4}, {{1, 3, 4}, {1, 3, 4}}}
-    },
-    {
-        {{-1, -1, -1}, {{1, 2, 3}, {1, 5, 3}}},
-        {{1, 3, 4}, {{1, 3, 4}, {1, 3, 4}}}
-    },
+    {{{-1, -1, -1}, {{2, 2, 3}, {3, 5, 3}}}, {{3, 4}, {{3, 4}, {3, 4}}}},
+    {{{-1, -1}, {{2, 3}, {5, 3}}}, {{1, 3, 4}, {{1, 3, 4}, {1, 3, 4}}}},
+    {{{-1, -1, -1}, {{1, 2, 3}, {1, 5, 3}}}, {{1, 3, 4}, {{1, 3, 4}, {1, 3, 4}}}},
 };
 
-std::map<std::string, std::string> emptyConfig = {/* empty config */};
+ov::AnyMap emptyConfig = {/* empty config */};
 
-std::vector<std::map<std::string, std::string>> filterAdditionalConfig_BF16() {
-    std::vector<std::map<std::string, std::string>> additionalConfig;
-    if (with_cpu_x86_avx512_core()) {
-        additionalConfig.push_back({{PluginConfigParams::KEY_ENFORCE_BF16, PluginConfigParams::YES}});
+std::vector<ov::AnyMap> filter_additional_config_bf16() {
+    std::vector<ov::AnyMap> additionalConfig;
+    if (ov::with_cpu_x86_avx512_core()) {
+        additionalConfig.push_back({{ov::hint::inference_precision(ov::element::bf16)}});
     }
     return additionalConfig;
 }
 
-std::vector<CPUSpecificParams> filterSpecificParams(bool trySetMlas) {
+std::vector<CPUSpecificParams> filter_specific_params(bool trySetMlas) {
     std::vector<CPUSpecificParams> specificParams;
     if (trySetMlas) {
 #ifdef OV_CPU_WITH_MLAS
@@ -295,9 +285,9 @@ std::vector<CPUSpecificParams> filterSpecificParams(bool trySetMlas) {
     }
     // try set onednn jit params if we can't or shouldn't use mlas
     if (specificParams.empty()) {
-        if (with_cpu_x86_avx512_core()) {
+        if (ov::with_cpu_x86_avx512_core()) {
             specificParams.push_back(CPUSpecificParams{{}, {}, {"brgemm_avx512"}, "brgemm_avx512"});
-        } else if (with_cpu_x86_avx2()) {
+        } else if (ov::with_cpu_x86_avx2()) {
             specificParams.push_back(CPUSpecificParams{{}, {}, {"brgemm_avx2"}, "brgemm_avx2"});
         }
     }
@@ -305,84 +295,84 @@ std::vector<CPUSpecificParams> filterSpecificParams(bool trySetMlas) {
     return specificParams;
 }
 
-std::vector<CPUSpecificParams> filterSpecificParams_BF16() {
+std::vector<CPUSpecificParams> filter_specific_params_bf16() {
     std::vector<CPUSpecificParams> specificParams;
     specificParams.push_back(CPUSpecificParams{{}, {}, {"jit_gemm"}, "jit_gemm"});
     return specificParams;
 }
 
-const auto testParams2D_FP32_smoke = ::testing::Combine(
-        ::testing::ValuesIn(inputShapes2D),
-        ::testing::ValuesIn(transposeParams),
-        ::testing::Values(ElementType::f32),
-        ::testing::Values(emptyConfig),
-        ::testing::ValuesIn(filterSpecificParams(true)));
-
-INSTANTIATE_TEST_SUITE_P(smoke_FC_2D_FP32, MatMulDecompressConvertTest, testParams2D_FP32_smoke,
-                         MatMulDecompressConvertTest::getTestCaseName);
-
-const auto testParams2D_FP16_smoke = ::testing::Combine(
-        ::testing::ValuesIn(inputShapes2D),
-        ::testing::ValuesIn(transposeParams),
-        ::testing::Values(ElementType::f16),
-        ::testing::Values(emptyConfig),
-        ::testing::ValuesIn(filterSpecificParams(false)));
-
-INSTANTIATE_TEST_SUITE_P(smoke_FC_2D_FP16, MatMulDecompressConvertTest, testParams2D_FP16_smoke,
-                         MatMulDecompressConvertTest::getTestCaseName);
-
-const auto testParams2D_BF16_smoke = ::testing::Combine(
-        ::testing::ValuesIn(inputShapes2D),
-        ::testing::ValuesIn(transposeParams),
-        ::testing::Values(ElementType::f32, ElementType::f16),
-        ::testing::ValuesIn(filterAdditionalConfig_BF16()),
-        ::testing::ValuesIn(filterSpecificParams_BF16()));
-
-INSTANTIATE_TEST_SUITE_P(smoke_FC_2D_BF16, MatMulDecompressConvertTest, testParams2D_BF16_smoke,
-                         MatMulDecompressConvertTest::getTestCaseName);
-
-const auto testParams3D_FP32_smoke = ::testing::Combine(
-        ::testing::ValuesIn(inputShapes3D),
-        ::testing::ValuesIn(transposeParams),
-        ::testing::Values(ElementType::f32),
-        ::testing::Values(emptyConfig),
-        ::testing::ValuesIn(filterSpecificParams(true)));
-
-INSTANTIATE_TEST_SUITE_P(smoke_FC_3D_FP32, MatMulDecompressConvertTest, testParams3D_FP32_smoke,
-                         MatMulDecompressConvertTest::getTestCaseName);
-
-const auto testParams3D_FP16_smoke = ::testing::Combine(
-        ::testing::ValuesIn(inputShapes3D),
-        ::testing::ValuesIn(transposeParams),
-        ::testing::Values(ElementType::f16),
-        ::testing::Values(emptyConfig),
-        ::testing::ValuesIn(filterSpecificParams(false)));
-
-INSTANTIATE_TEST_SUITE_P(smoke_FC_3D_FP16, MatMulDecompressConvertTest, testParams3D_FP16_smoke,
-                         MatMulDecompressConvertTest::getTestCaseName);
-
-const auto testParams3D_BF16_smoke = ::testing::Combine(
-        ::testing::ValuesIn(inputShapes3D),
-        ::testing::ValuesIn(transposeParams),
-        ::testing::Values(ElementType::f32, ElementType::f16),
-        ::testing::ValuesIn(filterAdditionalConfig_BF16()),
-        ::testing::ValuesIn(filterSpecificParams_BF16()));
-
-INSTANTIATE_TEST_SUITE_P(smoke_FC_3D_BF16, MatMulDecompressConvertTest, testParams3D_BF16_smoke,
-                         MatMulDecompressConvertTest::getTestCaseName);
+const auto testParams2D_FP32_smoke = ::testing::Combine(::testing::ValuesIn(inputShapes2D),
+                                                        ::testing::ValuesIn(transposeParams),
+                                                        ::testing::Values(ElementType::f32),
+                                                        ::testing::Values(emptyConfig),
+                                                        ::testing::ValuesIn(filter_specific_params(true)));
+
+INSTANTIATE_TEST_SUITE_P(smoke_FC_2D_FP32,
+                         MatMulDecompressConvertTest,
+                         testParams2D_FP32_smoke,
+                         MatMulDecompressConvertTest::getTestCaseName);
+
+const auto testParams2D_FP16_smoke = ::testing::Combine(::testing::ValuesIn(inputShapes2D),
+                                                        ::testing::ValuesIn(transposeParams),
+                                                        ::testing::Values(ElementType::f16),
+                                                        ::testing::Values(emptyConfig),
+                                                        ::testing::ValuesIn(filter_specific_params(false)));
+
+INSTANTIATE_TEST_SUITE_P(smoke_FC_2D_FP16,
+                         MatMulDecompressConvertTest,
+                         testParams2D_FP16_smoke,
+                         MatMulDecompressConvertTest::getTestCaseName);
+
+const auto testParams2D_BF16_smoke = ::testing::Combine(::testing::ValuesIn(inputShapes2D),
+                                                        ::testing::ValuesIn(transposeParams),
+                                                        ::testing::Values(ElementType::f32, ElementType::f16),
+                                                        ::testing::ValuesIn(filter_additional_config_bf16()),
+                                                        ::testing::ValuesIn(filter_specific_params_bf16()));
+
+INSTANTIATE_TEST_SUITE_P(smoke_FC_2D_BF16,
+                         MatMulDecompressConvertTest,
+                         testParams2D_BF16_smoke,
+                         MatMulDecompressConvertTest::getTestCaseName);
+
+const auto testParams3D_FP32_smoke = ::testing::Combine(::testing::ValuesIn(inputShapes3D),
+                                                        ::testing::ValuesIn(transposeParams),
+                                                        ::testing::Values(ElementType::f32),
+                                                        ::testing::Values(emptyConfig),
+                                                        ::testing::ValuesIn(filter_specific_params(true)));
+
+INSTANTIATE_TEST_SUITE_P(smoke_FC_3D_FP32,
+                         MatMulDecompressConvertTest,
+                         testParams3D_FP32_smoke,
+                         MatMulDecompressConvertTest::getTestCaseName);
+
+const auto testParams3D_FP16_smoke = ::testing::Combine(::testing::ValuesIn(inputShapes3D),
+                                                        ::testing::ValuesIn(transposeParams),
+                                                        ::testing::Values(ElementType::f16),
+                                                        ::testing::Values(emptyConfig),
+                                                        ::testing::ValuesIn(filter_specific_params(false)));
+
+INSTANTIATE_TEST_SUITE_P(smoke_FC_3D_FP16,
+                         MatMulDecompressConvertTest,
+                         testParams3D_FP16_smoke,
+                         MatMulDecompressConvertTest::getTestCaseName);
+
+const auto testParams3D_BF16_smoke = ::testing::Combine(::testing::ValuesIn(inputShapes3D),
+                                                        ::testing::ValuesIn(transposeParams),
+                                                        ::testing::Values(ElementType::f32, ElementType::f16),
+                                                        ::testing::ValuesIn(filter_additional_config_bf16()),
+                                                        ::testing::ValuesIn(filter_specific_params_bf16()));
+
+INSTANTIATE_TEST_SUITE_P(smoke_FC_3D_BF16,
+                         MatMulDecompressConvertTest,
+                         testParams3D_BF16_smoke,
+                         MatMulDecompressConvertTest::getTestCaseName);
 
 }  // namespace
 
 /* In case of Convert has 2 or more consumers there is a problem with memory allocation in CPU plug-in (see Edge::init()
  method). Maybe we can just remove the check (edgePtr->getParent()->isConstant() && !edgePtr->getChild()->isConstant())
  and everything will be OK, But this solution should be additionally checked. For now, for these cases we will not be
-doing CF on the CPU side and it should be done on the ngraph side.
+doing CF on the CPU side and it should be done on the graph side.
 
 * Graph before:
   ------------    ------------    ------------
@@ -422,13 +412,11 @@ INSTANTIATE_TEST_SUITE_P(smoke_FC_3D_BF16, MatMulDecompressConvertTest, testPara
        |Output|
        --------
 */
-using MatMulDecompressConvertParams2 = std::tuple<
-        std::vector<InputShape>,            // input shapes
-        std::pair<bool, bool>,              // transposeA, transposeB
-        ElementType,                        // weights precision
-        std::map<std::string, std::string>, // additional config
-        CPUSpecificParams
->;
+using MatMulDecompressConvertParams2 = std::tuple<std::vector<InputShape>,  // input shapes
+                                                  std::pair<bool, bool>,    // transposeA, transposeB
+                                                  ElementType,              // weights precision
+                                                  ov::AnyMap,               // additional property
+                                                  CPUSpecificParams>;
 
 class MatMulDecompressConvertTest2 : public MatMulDecompressConvertTest {
 protected:
@@ -438,7 +426,7 @@ protected:
         std::vector<InputShape> inputShapes;
         std::pair<bool, bool> transpose;
         ElementType weiConstElemType;
-        std::map<std::string, std::string> additionalConfig;
+        ov::AnyMap additionalConfig;
         CPUSpecificParams cpuParams;
 
         std::tie(inputShapes, transpose, weiConstElemType, additionalConfig, cpuParams) = this->GetParam();
@@ -450,23 +438,25 @@ protected:
         bool transpB = transpose.second;
 
         fullyConnectedCount = 2;
-        if (transpA) transposeCount += 2;
-        if (!transpB) transposeCount++;
+        if (transpA)
+            transposeCount += 2;
+        if (!transpB)
+            transposeCount++;
 
         if (transpA) {
-            transposeShape(inputDynamicShapes[0]);
+            transpose_shape(inputDynamicShapes[0]);
             for (auto& shapes : targetStaticShapes) {
-                transposeShape(shapes[0]);
+                transpose_shape(shapes[0]);
             }
-            transposeShape(inputDynamicShapes[1]);
+            transpose_shape(inputDynamicShapes[1]);
             for (auto& shapes : targetStaticShapes) {
-                transposeShape(shapes[1]);
+                transpose_shape(shapes[1]);
             }
         }
         if (transpB) {
-            transposeShape(inputDynamicShapes[2]);
+            transpose_shape(inputDynamicShapes[2]);
             for (auto& shapes : targetStaticShapes) {
-                transposeShape(shapes[2]);
+                transpose_shape(shapes[2]);
             }
         }
@@ -478,7 +468,8 @@ protected:
         ElementType netType = ElementType::f32;
         ElementType convertOutType = ElementType::f32;
-        if (additionalConfig[PluginConfigParams::KEY_ENFORCE_BF16] == PluginConfigParams::YES) {
+        auto it = additionalConfig.find(ov::hint::inference_precision.name());
+        if (it != additionalConfig.end() && it->second.as<ov::element::Type>() == ov::element::bf16) {
             convertOutType = inType = outType = netType = ElementType::bf16;
             weiConstElemType = (weiConstElemType != ElementType::f32) ? weiConstElemType : ElementType::bf16;
         } else {
@@ -492,12 +483,13 @@ protected:
         for (auto&& shape : {inShapeFC0, inShapeFC1}) {
             params.push_back(std::make_shared<ov::op::v0::Parameter>(inType, shape));
         }
-        std::shared_ptr<Node> inputWeights = builder::makeConstant<float>(weiConstElemType, inShapeWeights.get_shape(), {}, true);
+        std::shared_ptr<ov::Node> inputWeights =
+            ngraph::builder::makeConstant<float>(weiConstElemType, inShapeWeights.get_shape(), {}, true);
         if (weiConstElemType == ElementType::f16) {
-            inputWeights = std::make_shared<opset1::Convert>(inputWeights, convertOutType);
+            inputWeights = std::make_shared<ov::op::v0::Convert>(inputWeights, convertOutType);
             mark_as_decompression(inputWeights);
         }
-        // In this test, convert must be folded on the ngraph side, so the constant with fp32 precision is expected
+        // In this test, convert must be folded on the graph side, so the constant with fp32 precision is expected
         expectedWeiConstElemType = ElementType::f32;
 
         auto matMul0 = std::make_shared<ov::op::v0::MatMul>(params[0], inputWeights, transpA, transpB);
@@ -512,21 +504,24 @@ protected:
 TEST_P(MatMulDecompressConvertTest2, CompareWithRefs) {
     SKIP_IF_CURRENT_TEST_IS_DISABLED();
     run();
-    CheckExecutionGraph();
+    check_execution_graph();
 }
 
 namespace {
-const auto testParams2D_FP16_2_smoke = ::testing::Combine(
-    ::testing::Values(static_shapes_to_test_representation({{2, 3}, {2, 3}, {3, 4}})),
-    ::testing::Values(std::pair<bool, bool>{false, true}),
-    ::testing::Values(ElementType::f16),
-    ::testing::Values(emptyConfig),
-    ::testing::ValuesIn(filterSpecificParams(true)));
+const auto testParams2D_FP16_2_smoke =
+    ::testing::Combine(::testing::Values(static_shapes_to_test_representation({{2, 3}, {2, 3}, {3, 4}})),
+                       ::testing::Values(std::pair<bool, bool>{false, true}),
+                       ::testing::Values(ElementType::f16),
+                       ::testing::Values(emptyConfig),
+                       ::testing::ValuesIn(filter_specific_params(true)));
 
-INSTANTIATE_TEST_SUITE_P(smoke_FC_2D_FP16_2, MatMulDecompressConvertTest2, testParams2D_FP16_2_smoke,
-                         MatMulDecompressConvertTest2::getTestCaseName);
+INSTANTIATE_TEST_SUITE_P(smoke_FC_2D_FP16_2,
+                         MatMulDecompressConvertTest2,
+                         testParams2D_FP16_2_smoke,
+                         MatMulDecompressConvertTest2::getTestCaseName);
 
 }  // namespace
-}  // namespace SubgraphTestsDefinitions
+}  // namespace test
+}  // namespace ov
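The config migration above is the recurring pattern of this commit: std::map<std::string, std::string> with PluginConfigParams::KEY_ENFORCE_BF16 becomes an ov::AnyMap carrying the ov::hint::inference_precision property. A minimal sketch of writing and reading that property, assuming standard OpenVINO 2.0 headers (the helper names are illustrative):

    #include "openvino/core/any.hpp"
    #include "openvino/runtime/properties.hpp"

    // Populate the map the way filter_additional_config_bf16() does.
    ov::AnyMap makeBf16Config() {
        return {ov::hint::inference_precision(ov::element::bf16)};
    }

    // Read it back the way the rewritten SetUp() does.
    bool bf16Requested(const ov::AnyMap& config) {
        auto it = config.find(ov::hint::inference_precision.name());
        return it != config.end() && it->second.as<ov::element::Type>() == ov::element::bf16;
    }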

View File

@@ -6,18 +6,18 @@
 #include "test_utils/fusing_test_utils.hpp"
 #include "ov_models/builders.hpp"
 #include "common_test_utils/common_utils.hpp"
+#include "shared_test_classes/base/ov_subgraph.hpp"
 
 #include <algorithm>
 #include <cassert>
 
-using namespace ngraph;
-using namespace InferenceEngine;
 using namespace CPUTestUtils;
 
-namespace SubgraphTestsDefinitions {
+namespace ov {
+namespace test {
 
 using ElementType = ov::element::Type_t;
-using MatmulBrgemmInt8TestParams = std::tuple<SizeVector,        // input shape
+using MatmulBrgemmInt8TestParams = std::tuple<ov::Shape,         // input shape
                                               bool,              // true: FullyConnected false: Matmul
                                               ElementType,       // input u8/s8
                                               ElementType,       // output f32/u8/s8
@@ -30,10 +30,10 @@ using MatmulBrgemmInt8TestParams = std::tuple<SizeVector,        // input shape
 // (u8/s8 + s8)->f32
 // (u8/s8 + s8)->u8/s8
 class MatmulBrgemmInt8Test : public testing::WithParamInterface<MatmulBrgemmInt8TestParams>, public CpuTestWithFusing,
-                  virtual public LayerTestsUtils::LayerTestsCommon {
+                  virtual public ov::test::SubgraphBaseStaticTest {
 public:
     static std::string getTestCaseName(testing::TestParamInfo<MatmulBrgemmInt8TestParams> obj) {
-        SizeVector supportedInputShapes;
+        ov::Shape supportedInputShapes;
         bool isFC;
         ElementType inType;
         ElementType outType;
@@ -41,7 +41,7 @@ public:
         std::tie(supportedInputShapes, isFC, inType, outType, cpuParams) = obj.param;
 
         std::ostringstream result;
-        result << "IS=" << ov::test::utils::vec2str(supportedInputShapes) << "_";
+        result << "IS=" << supportedInputShapes.to_string() << "_";
         result << (isFC ? "FullyConnected" : "MatMul") << "_";
         result << "InputType=" << inType << "_";
         result << "OutputType=" << outType << "_";
@@ -57,16 +57,16 @@ protected:
     ElementType outType;
     void SetUp() override {
         targetDevice = ov::test::utils::DEVICE_CPU;
-        SizeVector inShapes;
+        ov::Shape inShapes;
         CPUSpecificParams cpuParams;
         std::tie(inShapes, isFC, inType, outType, cpuParams) = this->GetParam();
         std::tie(inFmts, outFmts, priority, selectedType) = cpuParams;
-        const auto ngPrec = element::f32;
+        const auto ngPrec = ov::element::f32;
         ov::ParameterVector inputParams {std::make_shared<ov::op::v0::Parameter>(ngPrec, ov::Shape(inShapes))};
 
-        std::shared_ptr<Node> fq1;
-        std::shared_ptr<Node> matMul;
-        std::shared_ptr<Node> nodeBeforeConv;
+        std::shared_ptr<ov::Node> fq1;
+        std::shared_ptr<ov::Node> matMul;
+        std::shared_ptr<ov::Node> nodeBeforeConv;
         selectedType = makeSelectedTypeStr(selectedType, ElementType::i8);
         if (inType == ElementType::u8)
             fq1 = ngraph::builder::makeFakeQuantize(inputParams[0], ngPrec, 256, {}, {0.0f}, {2.55f}, {0.0f}, {2.55f});
@@ -74,15 +74,15 @@ protected:
             fq1 = ngraph::builder::makeFakeQuantize(inputParams[0], ngPrec, 256, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f});
 
         if (isFC) {
-            ngraph::Shape weightShape = inShapes;
+            ov::Shape weightShape = inShapes;
             std::swap(weightShape[0], weightShape[1]);
             auto weightsNode = ngraph::builder::makeConstant(ngPrec, weightShape, std::vector<float>{0.0f}, true);
             auto fq2 = ngraph::builder::makeFakeQuantize(weightsNode, ngPrec, 256, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f});
-            auto fc = std::make_shared<ngraph::opset1::MatMul>(fq1, fq2, false, false);
+            auto fc = std::make_shared<ov::op::v0::MatMul>(fq1, fq2, false, false);
             fc->get_rt_info() = getCPUInfo();
             fc->set_friendly_name(nameMatmul);
             auto biasWeightsNode = ngraph::builder::makeConstant(ngPrec, {}, std::vector<float>{0.0f}, true);
-            matMul = std::make_shared<ngraph::opset1::Add>(fc, biasWeightsNode);
+            matMul = std::make_shared<ov::op::v1::Add>(fc, biasWeightsNode);
         } else {
             auto fq2 = ngraph::builder::makeFakeQuantize(inputParams[0], ngPrec, 256, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f});
             matMul = std::make_shared<ov::op::v0::MatMul>(fq1, fq2, false, true);
@@ -98,7 +98,7 @@ protected:
         // matmul->fq->matmul can cover x8*s8->x8 case
         auto filterWeightsShape = matMul->get_output_shape(0);
-        auto filterWeightsNode = ngraph::builder::makeConstant(element::f32, filterWeightsShape, std::vector<float>{}, true);
+        auto filterWeightsNode = ngraph::builder::makeConstant(ov::element::f32, filterWeightsShape, std::vector<float>{}, true);
         auto fq3 = ngraph::builder::makeFakeQuantize(filterWeightsNode, ngPrec, 256, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f});
         // only matmul avx2 support s8*s8 input
         auto matMul2 = std::make_shared<ov::op::v0::MatMul>(nodeBeforeConv, fq3, false, false);
@@ -106,7 +106,7 @@ protected:
         function = makeNgraphFunction(ngPrec, inputParams, matMul2, "MatmulBrgemmInt8");
     }
 
-    void CheckNode(std::shared_ptr<const ov::Model> function, const std::string& nodeName) {
+    void check_node(std::shared_ptr<const ov::Model> function, const std::string& nodeName) {
         ASSERT_NE(nullptr, function);
         for (const auto &node : function->get_ops()) {
             const auto & rtInfo = node->get_rt_info();
@@ -127,18 +127,17 @@ protected:
 TEST_P(MatmulBrgemmInt8Test, CompareWithRefs) {
     // only cover avx2_vnni
-    if (InferenceEngine::with_cpu_x86_avx512_core() || !InferenceEngine::with_cpu_x86_avx2_vnni())
+    if (ov::with_cpu_x86_avx512_core() || !ov::with_cpu_x86_avx2_vnni())
         GTEST_SKIP();
-    Run();
-    InferenceEngine::CNNNetwork execGraphInfo = executableNetwork.GetExecGraphInfo();
-    auto exec = execGraphInfo.getFunction();
-    CheckNode(exec, nameMatmul);
+    run();
+    auto exec = compiledModel.get_runtime_model();
+    check_node(exec, nameMatmul);
 }
 
 namespace {
-const std::vector<SizeVector> supportedInputShapes = {
+const std::vector<ov::Shape> supportedInputShapes = {
     {16, 32},
     {17, 15},
 };
@@ -148,7 +147,8 @@ const std::vector<CPUSpecificParams>matmulSpecificFilterParams = {
     {{}, {}, {"jit_gemm"}, "jit_gemm"}
 };
 
-INSTANTIATE_TEST_SUITE_P(smoke_matmulBrgemmInt8, MatmulBrgemmInt8Test,
+INSTANTIATE_TEST_SUITE_P(smoke_matmulBrgemmInt8,
+                         MatmulBrgemmInt8Test,
                          ::testing::Combine(::testing::ValuesIn(supportedInputShapes),
                                             ::testing::ValuesIn({true, false}),
                                             ::testing::ValuesIn({ElementType::u8, ElementType::i8}),
@@ -156,6 +156,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_matmulBrgemmInt8, MatmulBrgemmInt8Test,
                                             ::testing::ValuesIn(matmulSpecificFilterParams)),
                          MatmulBrgemmInt8Test::getTestCaseName);
 
 }  // namespace
-}  // namespace SubgraphTestsDefinitions
+}  // namespace test
+}  // namespace ov
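In API 2.0 the execution graph comes straight from the compiled model, as in the rewritten test body above: compiledModel.get_runtime_model() replaces executableNetwork.GetExecGraphInfo().getFunction(). A minimal sketch of walking it, assuming the constants from "openvino/runtime/exec_model_info.hpp" (the helper name is illustrative):

    #include <iostream>

    #include "openvino/runtime/compiled_model.hpp"
    #include "openvino/runtime/exec_model_info.hpp"

    // Prints the plugin-assigned layer type of every executed node.
    void dumpLayerTypes(const ov::CompiledModel& compiled) {
        const auto model = compiled.get_runtime_model();  // replaces GetExecGraphInfo()
        for (const auto& node : model->get_ops()) {
            const auto& rt = node->get_rt_info();
            const auto it = rt.find(ov::exec_model_info::LAYER_TYPE);
            if (it != rt.end())
                std::cout << node->get_friendly_name() << ": " << it->second.as<std::string>() << "\n";
        }
    }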

View File

@@ -2,59 +2,62 @@
 // SPDX-License-Identifier: Apache-2.0
 //
-#include "test_utils/cpu_test_utils.hpp"
 #include "ov_models/builders.hpp"
+#include "shared_test_classes/base/ov_subgraph.hpp"
+#include "test_utils/cpu_test_utils.hpp"
 
-using namespace ngraph;
-using namespace InferenceEngine;
 using namespace CPUTestUtils;
 
-namespace SubgraphTestsDefinitions {
+namespace ov {
+namespace test {
 
-using MatmulStridedInputsOutputsTestParams = Precision;
+using MatmulStridedInputsOutputsTestParams = ov::element::Type;
 
 class MatmulStridedInputsOutputsTest : public testing::WithParamInterface<MatmulStridedInputsOutputsTestParams>,
                                        public CPUTestsBase,
-                                       virtual public LayerTestsUtils::LayerTestsCommon {
+                                       virtual public SubgraphBaseStaticTest {
 public:
     static std::string getTestCaseName(testing::TestParamInfo<MatmulStridedInputsOutputsTestParams> obj) {
-        Precision netPrecision;
+        ov::element::Type netPrecision;
         netPrecision = obj.param;
 
         std::ostringstream result;
-        result << "netPRC=" << netPrecision.name() << "_";
+        result << "netPRC=" << netPrecision.to_string() << "_";
 
         return result.str();
     }
 
 protected:
     void SetUp() override {
-        targetDevice = ov::test::utils::DEVICE_CPU;
-        Precision netPrecision;
-        netPrecision = this->GetParam();
-        const auto ngPrec = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
+        targetDevice = utils::DEVICE_CPU;
+        const auto ngPrec = this->GetParam();
 
-        SizeVector splitShape{1, 2, 1, 16};
+        ov::Shape splitShape{1, 2, 1, 16};
         ov::ParameterVector splitInputParams {std::make_shared<ov::op::v0::Parameter>(ngPrec, ov::Shape(splitShape))};
         auto split_axis_op = std::make_shared<ov::op::v0::Constant>(ov::element::Type_t::i64, ov::Shape{}, std::vector<int64_t>{1});
         auto split = std::make_shared<ov::op::v1::Split>(splitInputParams[0], split_axis_op, 2);
 
         std::vector<ov::Shape> concatShapes{{1, 1, 8, 8}, {1, 1, 8, 8}};
-        ov::ParameterVector concatInputParams {std::make_shared<ov::op::v0::Parameter>(ngPrec, concatShapes[0]),
-                                               std::make_shared<ov::op::v0::Parameter>(ngPrec, concatShapes[1])};
-        const auto concatOutputNodes = helpers::convert2OutputVector(helpers::castOps2Nodes<op::Parameter>(concatInputParams));
+        ov::ParameterVector concatInputParams{std::make_shared<ov::op::v0::Parameter>(ngPrec, concatShapes[0]),
+                                              std::make_shared<ov::op::v0::Parameter>(ngPrec, concatShapes[1])};
+        ov::OutputVector concatOutputNodes;
+        for (auto&& node : concatInputParams) {
+            for (auto&& param : node->outputs())
+                concatOutputNodes.push_back(param);
+        }
         const auto concat = std::make_shared<ov::op::v0::Concat>(concatOutputNodes, 2);
 
         const auto matMul1 = std::make_shared<ov::op::v0::MatMul>(split->output(0), concat, false, false);
 
-        SizeVector matmulShape{1, 1, 16, 8};
+        ov::Shape matmulShape{1, 1, 16, 8};
         ov::ParameterVector matmulInputParams {std::make_shared<ov::op::v0::Parameter>(ngPrec, ov::Shape(matmulShape))};
 
         const auto matMul2 = std::make_shared<ov::op::v0::MatMul>(split->output(1), matmulInputParams[0], false, false);
 
         const auto concatMatMuls = std::make_shared<ov::op::v0::Concat>(ov::NodeVector{matMul1, matMul2}, 2 /* 3rd axis */);
 
-        ngraph::ParameterVector inputParams = {splitInputParams[0], concatInputParams[0], concatInputParams[1], matmulInputParams[0]};
+        ov::ParameterVector inputParams = {splitInputParams[0], concatInputParams[0], concatInputParams[1], matmulInputParams[0]};
         function = makeNgraphFunction(ngPrec, inputParams, concatMatMuls, "MatmulStridedInputsOutputs");
     }
 };
@@ -84,16 +87,17 @@ protected:
 */
 
 TEST_P(MatmulStridedInputsOutputsTest, CompareWithRefs) {
-    Run();
+    run();
 }
 
 namespace {
 
-INSTANTIATE_TEST_SUITE_P(smoke_Check, MatmulStridedInputsOutputsTest,
-                         ::testing::Values(Precision::FP32,
-                                           Precision::BF16),
+INSTANTIATE_TEST_SUITE_P(smoke_Check,
+                         MatmulStridedInputsOutputsTest,
+                         ::testing::Values(ov::element::f32, ov::element::bf16),
                          MatmulStridedInputsOutputsTest::getTestCaseName);
 
 }  // namespace
-}  // namespace SubgraphTestsDefinitions
+}  // namespace test
+}  // namespace ov
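The loop introduced above is the inplace replacement flagged in the commit message ("Handle convert2OutputVector inplace"): instead of the removed ngraph::helpers::convert2OutputVector / castOps2Nodes pair, every output of every parameter is appended to an ov::OutputVector that Concat can consume. Factored into a helper, the same logic reads as follows (a sketch; the function name is illustrative):

    #include "openvino/core/node.hpp"
    #include "openvino/op/parameter.hpp"

    // Collects all outputs of all parameters, in order.
    ov::OutputVector toOutputVector(const ov::ParameterVector& params) {
        ov::OutputVector outputs;
        for (const auto& param : params)
            for (const auto& out : param->outputs())
                outputs.push_back(out);
        return outputs;
    }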

View File

@@ -2,17 +2,15 @@
 // SPDX-License-Identifier: Apache-2.0
 //
-#include "test_utils/fusing_test_utils.hpp"
 #include "ov_models/builders.hpp"
 #include "shared_test_classes/base/ov_subgraph.hpp"
+#include "test_utils/fusing_test_utils.hpp"
 #include "transformations/rt_info/decompression.hpp"
 
-using namespace ngraph;
-using namespace InferenceEngine;
 using namespace CPUTestUtils;
-using namespace ov::test;
 
-namespace SubgraphTestsDefinitions {
+namespace ov {
+namespace test {
 
 /*
  * WP - weights precision
  * DP - decompression precision
@@ -58,7 +56,7 @@ using MatmulWeightsDecompressionParams = std::tuple<ShapeParams,
                                                      bool,  // transpose on weights
                                                      bool,  // decompression subtract
                                                      bool,  // reshape on decompression constants
-                                                     std::map<std::string, std::string>,  // additional config
+                                                     ov::AnyMap,  // additional config
                                                      fusingSpecificParams,
                                                      bool>;  // should use decompression implementation
@@ -73,7 +71,7 @@ public:
         bool transpose;
         bool decompression_sub;
         bool reshape_on_decompression;
-        std::map<std::string, std::string> additional_config;
+        ov::AnyMap additional_config;
         fusingSpecificParams fusing_params;
         bool should_fuse;
@@ -99,7 +97,7 @@ public:
         result << "config=(";
         for (const auto& configEntry : additional_config) {
-            result << configEntry.first << ", " << configEntry.second << ":";
+            result << configEntry.first << ", " << configEntry.second.as<std::string>() << ":";
         }
         result << ")";
         result << CpuTestWithFusing::getTestCaseName(fusing_params);
@@ -145,7 +143,7 @@ protected:
         auto weights = ngraph::builder::makeConstant<int8_t>(weights_precision, transformed_weights_shape, {}, true, 7);
         weights->set_friendly_name("Compressed_weights");
-        auto weights_convert = std::make_shared<ngraph::opset1::Convert>(weights, decompression_precision);
+        auto weights_convert = std::make_shared<ov::op::v0::Convert>(weights, decompression_precision);
 
         std::shared_ptr<ov::Node> mul_parent = weights_convert;
         auto output_channels = *weights_shape.rbegin();
@@ -166,7 +164,7 @@ protected:
         scaleshift_const_shape.erase(std::remove(scaleshift_const_shape.begin(), scaleshift_const_shape.end(), 1), scaleshift_const_shape.end());
         if (add_subtract) {
             auto shift_const = ngraph::builder::makeConstant<uint8_t>(weights_precision, scaleshift_const_shape, {}, true, 7);
-            std::shared_ptr<ov::Node> shift_convert = std::make_shared<ngraph::opset1::Convert>(shift_const, decompression_precision);
+            std::shared_ptr<ov::Node> shift_convert = std::make_shared<ov::op::v0::Convert>(shift_const, decompression_precision);
             if (reshape_on_decompression_constant) {
                 auto shift_reshape_const = ov::opset10::Constant::create(ov::element::i32, {scaleshift_target_shape.size()}, scaleshift_target_shape);
                 auto shift_reshape = std::make_shared<ov::opset10::Reshape>(shift_convert, shift_reshape_const, false);
@@ -234,7 +232,7 @@ protected:
         bool transpose_weights;
         bool decompression_sub;
         bool reshape_on_decompression;
-        std::map<std::string, std::string> additional_config;
+        ov::AnyMap additional_config;
         fusingSpecificParams fusing_params;
         bool should_fuse;
@@ -252,7 +250,7 @@ protected:
         std::tie(postOpMgrPtr, fusedOps) = fusing_params;
         init_input_shapes({shape_params.data_shape, {{}, {{shape_params.weights_shape}}}});
 
-        ElementType netType = element::f32;
+        ElementType netType = ov::element::f32;
         inType = outType = netType;
 
         function = initSubgraph(inputDynamicShapes[0],
@@ -266,7 +264,7 @@ protected:
                                 reshape_on_decompression);
     }
 
-    void checkResults() {
+    void check_results() {
         const auto& test_param = GetParam();
         const auto& weights_precision = std::get<1>(test_param);
@@ -290,19 +288,19 @@ protected:
 TEST_P(MatmulWeightsDecompression, CompareWithRefs) {
     SKIP_IF_CURRENT_TEST_IS_DISABLED()
     run();
-    checkResults();
+    check_results();
 }
 
 namespace {
 
-std::vector<std::map<std::string, std::string>> filterAdditionalConfigBasic() {
-    std::vector<std::map<std::string, std::string>> additional_config = {CPUTestUtils::cpuEmptyPluginConfig};
+std::vector<ov::AnyMap> filter_additional_config_basic() {
+    std::vector<ov::AnyMap> additional_config = {CPUTestUtils::empty_plugin_config};
     return additional_config;
 }
-std::vector<std::map<std::string, std::string>> filterAdditionalConfigAMX() {
-    std::vector<std::map<std::string, std::string>> additional_config = {};
-    if (with_cpu_x86_avx512_core_amx())
-        additional_config.push_back({{PluginConfigParams::KEY_ENFORCE_BF16, PluginConfigParams::YES}});
+std::vector<ov::AnyMap> filter_additional_config_amx() {
+    std::vector<ov::AnyMap> additional_config = {};
+    if (ov::with_cpu_x86_avx512_core_amx())
+        additional_config.push_back({{ov::hint::inference_precision(ov::element::bf16)}});
     return additional_config;
 }
@@ -331,11 +329,7 @@ const std::vector<ShapeParams> input_shapes_amx = {
     {{{}, {{11, 339, 577}}}, {577, 335}},
     {{{}, {{1, 1, 256}}}, {256, 128}, 64ul},
 };
-const std::vector<fusingSpecificParams> fusing_params {
-    emptyFusingSpec,
-    fusingBias,
-    fusingFakeQuantizePerTensorRelu
-};
+const std::vector<fusingSpecificParams> fusing_params{emptyFusingSpec, fusingBias, fusingFakeQuantizePerTensorRelu};
 
 INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_basic,
                          MatmulWeightsDecompression,
@@ -345,7 +339,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_basic,
                                             ::testing::Values(true),
                                             ::testing::Values(true),
                                             ::testing::Values(true),
-                                            ::testing::ValuesIn(filterAdditionalConfigBasic()),
+                                            ::testing::ValuesIn(filter_additional_config_basic()),
                                             ::testing::ValuesIn(fusing_params),
                                             ::testing::Values(true)),
                          MatmulWeightsDecompression::getTestCaseName);
@@ -358,7 +352,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_amx,
                                             ::testing::Values(true),
                                             ::testing::Values(true),
                                             ::testing::Values(true),
-                                            ::testing::ValuesIn(filterAdditionalConfigAMX()),
+                                            ::testing::ValuesIn(filter_additional_config_amx()),
                                             ::testing::ValuesIn(fusing_params),
                                             ::testing::Values(true)),
                          MatmulWeightsDecompression::getTestCaseName);
@@ -387,7 +381,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_corner_cases_basic,
                                             ::testing::ValuesIn(transpose_weights),
                                             ::testing::ValuesIn(add_decompression_sub),
                                             ::testing::ValuesIn(reshape_on_decompression),
-                                            ::testing::ValuesIn(filterAdditionalConfigBasic()),
+                                            ::testing::ValuesIn(filter_additional_config_basic()),
                                             ::testing::Values(emptyFusingSpec),
                                             ::testing::Values(true)),
                          MatmulWeightsDecompression::getTestCaseName);
@@ -400,9 +394,10 @@ INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_corner_cases_amx,
                                             ::testing::ValuesIn(transpose_weights),
                                             ::testing::ValuesIn(add_decompression_sub),
                                             ::testing::ValuesIn(reshape_on_decompression),
-                                            ::testing::ValuesIn(filterAdditionalConfigAMX()),
+                                            ::testing::ValuesIn(filter_additional_config_amx()),
                                             ::testing::Values(emptyFusingSpec),
                                             ::testing::Values(true)),
                          MatmulWeightsDecompression::getTestCaseName);
 
 }  // namespace
-}  // namespace SubgraphTestsDefinitions
+}  // namespace test
+}  // namespace ov
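For orientation, the decompression subgraph this suite builds follows the pattern compressed Constant -> Convert -> optional Subtract -> Multiply -> MatMul, with the Convert marked so the plugin keeps it instead of constant-folding it. A reduced sketch under the same assumptions as the test; shapes and constant values are illustrative only:

    #include "openvino/op/constant.hpp"
    #include "openvino/op/convert.hpp"
    #include "openvino/op/multiply.hpp"
    #include "transformations/rt_info/decompression.hpp"

    std::shared_ptr<ov::Node> makeDecompressedWeights() {
        auto weights = ov::op::v0::Constant::create(ov::element::u8, {16, 32}, {7});
        auto convert = std::make_shared<ov::op::v0::Convert>(weights, ov::element::f32);
        ov::mark_as_decompression(convert);  // keep the Convert for the CPU plugin to fuse
        auto scale = ov::op::v0::Constant::create(ov::element::f32, {16, 1}, {0.01f});
        return std::make_shared<ov::op::v1::Multiply>(convert, scale);
    }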

View File

@@ -155,9 +155,9 @@ protected:
      * @param lastNode The last node of the initial graph.
      * @return The last node of the modified graph.
      */
-    virtual std::shared_ptr<ov::Node> modifyGraph(const ov::element::Type &ngPrc,
-                                                  ov::ParameterVector &params,
-                                                  const std::shared_ptr<ov::Node> &lastNode);
+    virtual std::shared_ptr<ov::Node> modifyGraph(const ov::element::Type& ngPrc,
+                                                  ov::ParameterVector& params,
+                                                  const std::shared_ptr<ov::Node>& lastNode);
 
     virtual bool primTypeCheck(std::string primType) const;