[CPU] Enable oneDNN avx2 brgemm impls for Matmul/FullyConnected operations (#18467)

parent 5a945a6219, commit 5f5df36b60
@@ -103,6 +103,7 @@ tolerance_map = {
     "resnet152v2": {"atol": 1e-05, "rtol": 0.001},
     "resnet18v2": {"atol": 1e-05, "rtol": 0.001},
     "resnet34v2": {"atol": 1e-05, "rtol": 0.001},
+    "resnet34-v1-7": {"atol": 1e-06, "rtol": 0.001},
     "vgg16": {"atol": 1e-05, "rtol": 0.001},
     "vgg19-bn": {"atol": 1e-05, "rtol": 0.001},
     "test_tiny_yolov2": {"atol": 1e-05, "rtol": 0.001},
@@ -94,6 +94,13 @@ using ov::with_cpu_x86_avx;
  */
 using ov::with_cpu_x86_avx2;
 
+/**
+ * @brief Checks whether CPU supports AVX2_VNNI capability
+ * @ingroup ie_dev_api_system_conf
+ * @return `True` if AVX2_VNNI instructions are available, `false` otherwise
+ */
+using ov::with_cpu_x86_avx2_vnni;
+
 /**
  * @brief Checks whether CPU supports AVX 512 capability
  * @ingroup ie_dev_api_system_conf
@@ -82,6 +82,13 @@ OPENVINO_RUNTIME_API bool with_cpu_x86_avx();
  */
 OPENVINO_RUNTIME_API bool with_cpu_x86_avx2();
 
+/**
+ * @brief Checks whether CPU supports AVX2_VNNI capability
+ * @ingroup ov_dev_api_system_conf
+ * @return `True` if AVX2_VNNI instructions are available, `false` otherwise
+ */
+OPENVINO_RUNTIME_API bool with_cpu_x86_avx2_vnni();
+
 /**
  * @brief Checks whether CPU supports AVX 512 capability
  * @ingroup ov_dev_api_system_conf
@@ -56,6 +56,10 @@ bool with_cpu_x86_avx2() {
     return get_cpu_info().has(Xbyak::util::Cpu::tAVX2);
 }
 
+bool with_cpu_x86_avx2_vnni() {
+    return get_cpu_info().has(Xbyak::util::Cpu::tAVX2 | Xbyak::util::Cpu::tAVX_VNNI);
+}
+
 bool with_cpu_x86_avx512f() {
     return get_cpu_info().has(Xbyak::util::Cpu::tAVX512F);
 }
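Callers can branch on the new capability check at runtime. Below is a minimal sketch of such a dispatch, assuming the dev-API header path openvino/runtime/system_conf.hpp from the hunks above; the dispatch logic itself is illustrative and not part of this commit.

#include <iostream>

#include "openvino/runtime/system_conf.hpp"

int main() {
    if (ov::with_cpu_x86_avx2_vnni()) {
        // tAVX2 and tAVX_VNNI are both present, so int8 brgemm_avx2 kernels are eligible.
        std::cout << "avx2_vnni available: int8 brgemm_avx2 kernels can be selected\n";
    } else {
        std::cout << "avx2_vnni not available: falling back to jit_gemm/other impls\n";
    }
    return 0;
}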
@@ -152,7 +152,6 @@ Node::Node(const std::shared_ptr<ngraph::Node>& op,
                 str != "cpu:unknown")
                 IE_THROW() << "Unsupported CPU implementation " << str << " for node " << getName();
         }
-        // add default primitive priorities as a fallback for the custom ones
        const auto& defaultImplPriorities = getDefaultImplPriority();
        customImplPriorities.insert(customImplPriorities.end(), defaultImplPriorities.begin(), defaultImplPriorities.end());
    }
@@ -666,11 +665,11 @@ void Node::initSupportedPrimitiveDescriptors() {
     };
 
     /* When custom implementation priorities are NOT defined, it is enough to
      * just use the first implementation from the priority list.
      * When custom implementation priorities are defined, all the implementations should be considered,
      * since custom implementations may not be available at all, so a fallback to the default ones must happen.
      * To achieve the fallback, it is necessary to create a supported primitive descriptor for each implementation,
      * since the oneDNN primitive is mutating while iterating */
 
     for (auto& desc : descs) {
         auto first_desc = dnnl::primitive_desc(DnnlExtensionUtils::clone_primitive_desc(desc.get()));
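The comment above describes the selection policy in prose. A simplified, self-contained sketch of that policy follows; it uses a plain struct and string priorities rather than the actual Node/oneDNN types, so it is an illustration of the idea, not the node code.

#include <string>
#include <vector>

struct Impl { std::string type; };

std::vector<Impl> collectCandidates(const std::vector<Impl>& available,
                                    const std::vector<std::string>& customPriorities) {
    std::vector<Impl> candidates;
    const bool first_match = customPriorities.empty();
    for (const auto& impl : available) {
        candidates.push_back(impl);  // one supported descriptor per implementation
        if (first_match)
            break;                   // no custom priorities: the first implementation wins
    }
    return candidates;               // with custom priorities, all impls remain as fallbacks
}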
@@ -1563,7 +1563,7 @@ void Convolution::initializeInputZeroPoints(const uint8_t* inputZpData, const si
         if (inputZpData[j] != inputZpData[0])
             inputZeroPointType = zpType::PerChannel;
     }
-    // Only enable per-tensor zero point on avx512-amx and avx512-core.
+    // Only enable per-tensor zero point on avx512-amx and avx512-core-vnni.
     // If zero point is per-tensor, both legacy zp and stock zp
     // would be passed into conv node. The conv node would determine how to create
     // post-ops attribute and prioritize to choose final onednn kernel.
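For clarity, the classification logic shown in this hunk reduces to the following self-contained sketch. The loop mirrors the diff; the free-function name and signature are assumptions made for illustration.

#include <cstddef>
#include <cstdint>

enum class zpType { None, PerTensor, PerChannel };

zpType classifyZeroPoints(const uint8_t* inputZpData, size_t inputZpSize) {
    // Per-tensor unless any channel's zero point differs from the first channel's.
    zpType type = zpType::PerTensor;
    for (size_t j = 0; j < inputZpSize; j++) {
        if (inputZpData[j] != inputZpData[0])
            type = zpType::PerChannel;
    }
    return type;
}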
@@ -661,6 +661,7 @@ const std::vector<impl_desc_type>& FullyConnected::getDefaultImplPriority() {
         impl_desc_type::brgemm_sparse_avx512_amx,
         impl_desc_type::brgemm_avx512_amx,
         impl_desc_type::brgemm_avx512,
+        impl_desc_type::brgemm_avx2,
         impl_desc_type::gemm_blas,
         impl_desc_type::gemm_avx512,
         impl_desc_type::gemm_avx2,
@@ -833,7 +834,6 @@ void FullyConnected::initSupportedPrimitiveDescriptors() {
 
     for (auto& desc : descs) {
         auto first_desc = dnnl::primitive_desc(DnnlExtensionUtils::clone_primitive_desc(desc.get()));
-
         const bool first_match = customImplPriorities.empty();
         DnnlExtensionUtils::for_each_implementation(desc,
                                                     first_match,
@@ -537,7 +537,6 @@ void MatMul::initSupportedPrimitiveDescriptors() {
 
     for (auto& desc : descs) {
         auto first_desc = dnnl::primitive_desc(DnnlExtensionUtils::clone_primitive_desc(desc.get()));
-
         const bool first_match = customImplPriorities.empty();
         DnnlExtensionUtils::for_each_implementation(desc,
                                                     first_match,
@@ -702,6 +701,7 @@ const std::vector<impl_desc_type>& MatMul::getDefaultImplPriority() {
         impl_desc_type::unknown,
         impl_desc_type::brgemm_avx512_amx,
         impl_desc_type::brgemm_avx512,
+        impl_desc_type::brgemm_avx2,
         impl_desc_type::gemm_acl,
         impl_desc_type::gemm_blas,
         impl_desc_type::gemm_avx512,
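Both default priority lists now place brgemm_avx2 after the avx512 brgemm variants and ahead of the gemm implementations. A hedged sketch of how such a list resolves is shown below; the support predicate is a stand-in for the real oneDNN capability query, not part of this commit.

#include <functional>
#include <string>
#include <vector>

std::string pickImpl(const std::vector<std::string>& priority,
                     const std::function<bool(const std::string&)>& isSupported) {
    for (const auto& impl : priority) {
        if (isSupported(impl))
            return impl;  // first supported entry wins
    }
    return "ref";         // reference fallback
}

// Example: on a machine without AVX-512, brgemm_avx512 is rejected and the
// newly added brgemm_avx2 is chosen ahead of gemm_avx2.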
@@ -652,7 +652,6 @@ void Pooling::initSupportedPrimitiveDescriptors() {
 
     for (auto& desc : descs) {
         auto first_desc = dnnl::primitive_desc(DnnlExtensionUtils::clone_primitive_desc(desc.get()));
-
         const bool first_match = customImplPriorities.empty();
         DnnlExtensionUtils::for_each_implementation(desc,
                                                     first_match,
@@ -10,10 +10,12 @@
 #include "ngraph_functions/builders.hpp"
 #include "openvino/core/visibility.hpp"
 #include <shared_test_classes/single_layer/convolution.hpp>
+#include "utils/general_utils.h"
 
 using namespace InferenceEngine;
 using namespace CPUTestUtils;
 using namespace ov::test;
+using namespace ov::intel_cpu;
 
 namespace CPULayerTestsDefinitions {
 using LayerTestsDefinitions::convSpecificParams;
@@ -219,8 +221,13 @@ TEST_P(ConvolutionLayerCPUTest, CompareWithRefs) {
     }
 
     if (!priority.empty()) {
+        // Skip all the brgconv avx2 tests for now: brgconv_avx2 is disabled due to a perf regression [CVS-105756].
+        // This convolution test code already covers the brgconv avx2 primitive.
+        // @todo: Remove this once brgconv_avx2 is enabled for the convolution node.
+        if (priority[0].find("brgconv_avx2") != std::string::npos)
+            GTEST_SKIP() << "Test disabled because brgconv_avx2 is not enabled." << std::endl;
         // Skip tests for brgconv convolution where kernel size = 1x1
-        if (priority[0] == "brgconv_avx512" || priority[0] == "brgconv_avx512_amx") {
+        if (one_of(priority[0], "brgconv_avx512", "brgconv_avx512_amx", "brgconv_avx2")) {
             bool is_1x1 = true;
             for (const auto &i : kernel) {
                 if (i != 1) {
@@ -826,6 +833,7 @@ const std::vector<CPUSpecificParams> CPUParams_1D = {
         conv_avx512_1D,
         conv_sse42_1D_nspc,
         conv_avx2_1D_nspc,
+        conv_avx2_1D_nspc_brgconv,
         conv_avx512_1D_nspc,
         conv_avx512_1D_nspc_brgconv
 };
@@ -934,6 +942,7 @@ const std::vector<CPUSpecificParams> CPUParams_2D = {
         conv_avx512_2D,
         conv_sse42_2D_nspc,
         conv_avx2_2D_nspc,
+        conv_avx2_2D_nspc_brgconv,
        conv_avx512_2D_nspc,
        conv_avx512_2D_nspc_brgconv
 };
@@ -1211,6 +1220,7 @@ const std::vector<CPUSpecificParams> CPUParams_3D = {
         conv_avx2_3D,
         conv_avx512_3D,
         conv_avx2_3D_nspc,
+        conv_avx2_3D_nspc_brgconv,
         conv_avx512_3D_nspc,
         conv_avx512_3D_nspc_brgconv
 };
@@ -1394,6 +1404,7 @@ const std::vector<CPUSpecificParams> CPUParams_1x1_1D = {
         conv_avx512_1D_1x1,
         conv_sse42_1D_1x1_nspc,
         conv_avx2_1D_1x1_nspc,
+        conv_avx2_1D_1x1_nspc_brgconv,
         conv_avx512_1D_1x1_nspc,
         conv_avx512_1D_1x1_nspc_brgconv
 };
@@ -1459,6 +1470,7 @@ const std::vector<CPUSpecificParams> CPUParams_1x1_2D = {
         conv_avx512_2D_1x1,
         conv_sse42_2D_1x1_nspc,
         conv_avx2_2D_1x1_nspc,
+        conv_avx2_2D_1x1_nspc_brgconv,
         conv_avx512_2D_1x1_nspc,
         conv_avx512_2D_1x1_nspc_brgconv
 };
@@ -236,10 +236,14 @@ std::vector<CPUSpecificParams> filterSpecificParams() {
     return specificParams;
 }
 
-std::vector<CPUSpecificParams> filterSpecificParams_Brgemm() {
+// For FP32 precision, FullyConnected has brgemm avx2 support but MatMul doesn't.
+// tryBrgAVX2 therefore has to be set per test case.
+std::vector<CPUSpecificParams> filterSpecificParams_Brgemm(bool tryBrgAVX2 = false) {
     std::vector<CPUSpecificParams> specificParams;
     if (with_cpu_x86_avx512_core()) {
         specificParams.push_back(CPUSpecificParams{{}, {}, {"brgemm_avx512"}, "brgemm_avx512"});
+    } else if (tryBrgAVX2 && with_cpu_x86_avx2()) {
+        specificParams.push_back(CPUSpecificParams{{}, {}, {"brgemm_avx2"}, "brgemm_avx2"});
     }
 
     return specificParams;
@@ -683,7 +687,7 @@ const auto fullyConnectedParams2D_Brgemm_smoke = ::testing::Combine(::testing::V
 const auto testParams2D_Brgemm_smoke = ::testing::Combine(fullyConnectedParams2D_Brgemm_smoke,
                                                           ::testing::Values(MatMulNodeType::FullyConnected),
                                                           ::testing::ValuesIn(fusingParamsSet2D_Brgemm_smoke),
-                                                          ::testing::ValuesIn(filterSpecificParams_Brgemm()));
+                                                          ::testing::ValuesIn(filterSpecificParams_Brgemm(true)));
 
 INSTANTIATE_TEST_SUITE_P(smoke_FC_2D_Brgemm, MatMulLayerCPUTest, testParams2D_Brgemm_smoke, MatMulLayerCPUTest::getTestCaseName);
 
@@ -836,7 +840,7 @@ const auto fullyConnectedParams2D_Brgemm_nightly = ::testing::Combine(::testing:
 const auto testParams2D_Brgemm_nightly = ::testing::Combine(fullyConnectedParams2D_Brgemm_nightly,
                                                             ::testing::Values(MatMulNodeType::FullyConnected),
                                                             ::testing::ValuesIn(fusingParamsSet2D_nightly),
-                                                            ::testing::ValuesIn(filterSpecificParams_Brgemm()));
+                                                            ::testing::ValuesIn(filterSpecificParams_Brgemm(true)));
 
 INSTANTIATE_TEST_SUITE_P(nightly_FC_2D_Brgemm, MatMulLayerCPUTest, testParams2D_Brgemm_nightly, MatMulLayerCPUTest::getTestCaseName);
 
@@ -0,0 +1,161 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "test_utils/cpu_test_utils.hpp"
+#include "test_utils/fusing_test_utils.hpp"
+#include "ngraph_functions/builders.hpp"
+#include "common_test_utils/common_utils.hpp"
+
+#include <algorithm>
+#include <cassert>
+
+using namespace ngraph;
+using namespace InferenceEngine;
+using namespace CPUTestUtils;
+
+namespace SubgraphTestsDefinitions {
+
+using ElementType = ov::element::Type_t;
+using MatmulBrgemmInt8TestParams = std::tuple<SizeVector,        // input shape
+                                              bool,              // true: FullyConnected, false: MatMul
+                                              ElementType,       // input u8/s8
+                                              ElementType,       // output f32/u8/s8
+                                              CPUSpecificParams  // brgemm/jit primitive implementation type
+                                              >;
+
+// subgraph:
+//   fq -> MatMul/FullyConnected -> [fq]
+// can cover brgemm avx2:
+//   (u8/s8 + s8) -> f32
+//   (u8/s8 + s8) -> u8/s8
+class MatmulBrgemmInt8Test : public testing::WithParamInterface<MatmulBrgemmInt8TestParams>, public CpuTestWithFusing,
+                             virtual public LayerTestsUtils::LayerTestsCommon {
+public:
+    static std::string getTestCaseName(testing::TestParamInfo<MatmulBrgemmInt8TestParams> obj) {
+        SizeVector supportedInputShapes;
+        bool isFC;
+        ElementType inType;
+        ElementType outType;
+        CPUSpecificParams cpuParams;
+        std::tie(supportedInputShapes, isFC, inType, outType, cpuParams) = obj.param;
+
+        std::ostringstream result;
+        result << "IS=" << CommonTestUtils::vec2str(supportedInputShapes) << "_";
+        result << (isFC ? "FullyConnected" : "MatMul") << "_";
+        result << "InputType=" << inType << "_";
+        result << "OutputType=" << outType << "_";
+        result << CPUTestsBase::getTestCaseName(cpuParams);
+
+        return result.str();
+    }
+
+protected:
+    bool isFC;
+    std::string nameMatmul = "TestedMatmul";
+    ElementType inType;
+    ElementType outType;
+    void SetUp() override {
+        targetDevice = CommonTestUtils::DEVICE_CPU;
+        SizeVector inShapes;
+        CPUSpecificParams cpuParams;
+        std::tie(inShapes, isFC, inType, outType, cpuParams) = this->GetParam();
+        std::tie(inFmts, outFmts, priority, selectedType) = cpuParams;
+        const auto ngPrec = element::f32;
+        auto inputParams = builder::makeParams(ngPrec, {inShapes});
+
+        std::shared_ptr<Node> fq1;
+        std::shared_ptr<Node> matMul;
+        std::shared_ptr<Node> nodeBeforeConv;
+        selectedType = makeSelectedTypeStr(selectedType, ElementType::i8);
+        if (inType == ElementType::u8)
+            fq1 = ngraph::builder::makeFakeQuantize(inputParams[0], ngPrec, 256, {}, {0.0f}, {2.55f}, {0.0f}, {2.55f});
+        else
+            fq1 = ngraph::builder::makeFakeQuantize(inputParams[0], ngPrec, 256, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f});
+
+        if (isFC) {
+            ngraph::Shape weightShape = inShapes;
+            std::swap(weightShape[0], weightShape[1]);
+            auto weightsNode = ngraph::builder::makeConstant(ngPrec, weightShape, std::vector<float>{0.0f}, true);
+            auto fq2 = ngraph::builder::makeFakeQuantize(weightsNode, ngPrec, 256, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f});
+            auto fc = std::make_shared<ngraph::opset1::MatMul>(fq1, fq2, false, false);
+            fc->get_rt_info() = getCPUInfo();
+            fc->set_friendly_name(nameMatmul);
+            auto biasWeightsNode = ngraph::builder::makeConstant(ngPrec, {}, std::vector<float>{0.0f}, true);
+            matMul = std::make_shared<ngraph::opset1::Add>(fc, biasWeightsNode);
+        } else {
+            auto fq2 = ngraph::builder::makeFakeQuantize(inputParams[0], ngPrec, 256, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f});
+            matMul = builder::makeMatMul(fq1, fq2, false, true);
+            matMul->get_rt_info() = getCPUInfo();
+            matMul->set_friendly_name(nameMatmul);
+        }
+        if (outType == ElementType::u8)
+            nodeBeforeConv = ngraph::builder::makeFakeQuantize(matMul, ngPrec, 256, {}, {0.0f}, {2.55f}, {0.0f}, {2.55f});
+        else if (outType == ElementType::i8)
+            nodeBeforeConv = ngraph::builder::makeFakeQuantize(matMul, ngPrec, 256, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f});
+        else
+            nodeBeforeConv = matMul;
+
+        // matmul -> fq -> matmul can cover the x8*s8 -> x8 case
+        auto filterWeightsShape = matMul->get_output_shape(0);
+        auto filterWeightsNode = ngraph::builder::makeConstant(element::f32, filterWeightsShape, std::vector<float>{}, true);
+        auto fq3 = ngraph::builder::makeFakeQuantize(filterWeightsNode, ngPrec, 256, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f});
+        // only matmul avx2 supports s8*s8 input
+        auto matMul2 = builder::makeMatMul(nodeBeforeConv, fq3, false, false);
+
+        function = makeNgraphFunction(ngPrec, inputParams, matMul2, "MatmulBrgemmInt8");
+    }
+
+    void CheckNode(std::shared_ptr<const ov::Model> function, const std::string& nodeName) {
+        ASSERT_NE(nullptr, function);
+        for (const auto &node : function->get_ops()) {
+            const auto & rtInfo = node->get_rt_info();
+            auto getExecValue = [&rtInfo](const std::string & paramName) -> std::string {
+                auto it = rtInfo.find(paramName);
+                IE_ASSERT(rtInfo.end() != it);
+                return it->second.as<std::string>();
+            };
+            if (node->get_friendly_name() == nodeName) {
+                auto primType = getExecValue(ExecGraphInfoSerialization::IMPL_TYPE);
+                ASSERT_TRUE(primTypeCheck(primType)) << "primType is unexpected: " << primType << " Expected: " << selectedType;
+                ASSERT_EQ(node->get_output_element_type(0), outType);
+                ASSERT_EQ(node->get_input_element_type(0), inType);
+            }
+        }
+    }
+};
+
+TEST_P(MatmulBrgemmInt8Test, CompareWithRefs) {
+    // only cover avx2_vnni
+    if (InferenceEngine::with_cpu_x86_avx512_core() || !InferenceEngine::with_cpu_x86_avx2_vnni())
+        GTEST_SKIP();
+
+    Run();
+    InferenceEngine::CNNNetwork execGraphInfo = executableNetwork.GetExecGraphInfo();
+    auto exec = execGraphInfo.getFunction();
+    CheckNode(exec, nameMatmul);
+}
+
+namespace {
+
+const std::vector<SizeVector> supportedInputShapes = {
+    {16, 32},
+    {17, 15},
+};
+
+const std::vector<CPUSpecificParams> matmulSpecificFilterParams = {
+    {{}, {}, {"brgemm_avx2"}, "brgemm_avx2"},
+    {{}, {}, {"jit_gemm"}, "jit_gemm"}
+};
+
+INSTANTIATE_TEST_SUITE_P(smoke_matmulBrgemmInt8, MatmulBrgemmInt8Test,
+                         ::testing::Combine(::testing::ValuesIn(supportedInputShapes),
+                                            ::testing::ValuesIn({true, false}),
+                                            ::testing::ValuesIn({ElementType::u8, ElementType::i8}),
+                                            ::testing::ValuesIn({ElementType::f32, ElementType::u8, ElementType::i8}),
+                                            ::testing::ValuesIn(matmulSpecificFilterParams)),
+                         MatmulBrgemmInt8Test::getTestCaseName);
+
+}  // namespace
+
+}  // namespace SubgraphTestsDefinitions
@@ -58,6 +58,10 @@ namespace CPUTestUtils {
     const auto conv_avx2_dw_2D_nspc = CPUSpecificParams{{nhwc}, {nhwc}, {"jit_avx2_dw"}, "jit_avx2_dw"};
     const auto conv_avx2_dw_3D_nspc = CPUSpecificParams{{ndhwc}, {ndhwc}, {"jit_avx2_dw"}, "jit_avx2_dw"};
 
+    const auto conv_avx2_1D_nspc_brgconv = CPUSpecificParams{{nwc}, {nwc}, {"brgconv_avx2"}, "brgconv_avx2"};
+    const auto conv_avx2_2D_nspc_brgconv = CPUSpecificParams{{nhwc}, {nhwc}, {"brgconv_avx2"}, "brgconv_avx2"};
+    const auto conv_avx2_3D_nspc_brgconv = CPUSpecificParams{{ndhwc}, {ndhwc}, {"brgconv_avx2"}, "brgconv_avx2"};
+
     const auto conv_avx512_1D = CPUSpecificParams{{nCw16c}, {nCw16c}, {"jit_avx512"}, "jit_avx512"};
     const auto conv_avx512_2D = CPUSpecificParams{{nChw16c}, {nChw16c}, {"jit_avx512"}, "jit_avx512"};
     const auto conv_avx512_3D = CPUSpecificParams{{nCdhw16c}, {nCdhw16c}, {"jit_avx512"}, "jit_avx512"};
@@ -97,6 +101,7 @@ namespace CPUTestUtils {
 
     const auto conv_sse42_1D_1x1_nspc = CPUSpecificParams{{nwc}, {nwc}, {"jit_sse42_1x1"}, "jit_sse42_1x1"};
     const auto conv_avx2_1D_1x1_nspc = CPUSpecificParams{{nwc}, {nwc}, {"jit_avx2_1x1"}, "jit_avx2_1x1"};
+    const auto conv_avx2_1D_1x1_nspc_brgconv = CPUSpecificParams{{nwc}, {nwc}, {"brgconv_avx2_1x1"}, "brgconv_avx2_1x1"};
     const auto conv_avx512_1D_1x1_nspc = CPUSpecificParams{{nwc}, {nwc}, {"jit_avx512_1x1"}, "jit_avx512_1x1"};
     const auto conv_avx512_1D_1x1_nspc_brgconv = CPUSpecificParams{{nwc}, {nwc}, {"brgconv_avx512_1x1"}, "brgconv_avx512_1x1"};
     const auto conv_avx512_1D_1x1_nspc_brgconv_amx = CPUSpecificParams{{nwc}, {nwc}, {"brgconv_avx512_amx_1x1"}, "brgconv_avx512_amx_1x1"};
@@ -107,6 +112,7 @@ namespace CPUTestUtils {
 
     const auto conv_sse42_2D_1x1_nspc = CPUSpecificParams{{nhwc}, {nhwc}, {"jit_sse42_1x1"}, "jit_sse42_1x1"};
     const auto conv_avx2_2D_1x1_nspc = CPUSpecificParams{{nhwc}, {nhwc}, {"jit_avx2_1x1"}, "jit_avx2_1x1"};
+    const auto conv_avx2_2D_1x1_nspc_brgconv = CPUSpecificParams{{nhwc}, {nhwc}, {"brgconv_avx2_1x1"}, "brgconv_avx2_1x1"};
     const auto conv_avx512_2D_1x1_nspc = CPUSpecificParams{{nhwc}, {nhwc}, {"jit_avx512_1x1"}, "jit_avx512_1x1"};
     const auto conv_avx512_2D_1x1_nspc_brgconv = CPUSpecificParams{{nhwc}, {nhwc}, {"brgconv_avx512_1x1"}, "brgconv_avx512_1x1"};
     const auto conv_avx512_2D_1x1_nspc_brgconv_amx = CPUSpecificParams{{nhwc}, {nhwc}, {"brgconv_avx512_amx_1x1"}, "brgconv_avx512_amx_1x1"};
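For reference, each CPUSpecificParams entry above bundles four fields: the input memory formats, the output memory formats, the acceptable implementation-type strings, and the selected implementation type the tests assert on. As a hedged illustration only (this alias is not part of the commit), a 3D 1x1 avx2 brgconv entry would follow the same pattern:

    // Hypothetical example, not in the file: field order is
    // {input formats}, {output formats}, {accepted impl types}, "expected impl".
    const auto conv_avx2_3D_1x1_nspc_brgconv = CPUSpecificParams{{ndhwc}, {ndhwc}, {"brgconv_avx2_1x1"}, "brgconv_avx2_1x1"};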