[CPU] MLAS backend integration (#17885)

- Currently enabled only for the FP32 FullyConnected node on x86 CPUs
Zhang Yi 2023-07-26 15:40:34 +08:00 committed by GitHub
parent 97b4b13074
commit 1c0c929231
20 changed files with 592 additions and 8 deletions

.gitmodules

@ -72,3 +72,6 @@
[submodule "ARMComputeLibrary"]
path = src/plugins/intel_cpu/thirdparty/ComputeLibrary
url = https://github.com/ARM-software/ComputeLibrary.git
[submodule "src/plugins/intel_cpu/thirdparty/mlas"]
path = src/plugins/intel_cpu/thirdparty/mlas
url = https://github.com/openvinotoolkit/mlas.git


@ -1399,3 +1399,29 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
-------------------------------------------------------------
21 MLAS (https://github.com/microsoft/onnxruntime)
MIT License
Copyright (c) Microsoft Corporation
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


@ -115,6 +115,7 @@ tolerance_map = {
"GPT2": {"atol": 5e-06, "rtol": 0.01},
"GPT-2-LM-HEAD": {"atol": 4e-06},
"test_retinanet_resnet101": {"atol": 1.3e-06},
"resnet34-v1-7" : {"atol": 1e-5}
}
def tolerance_map_key_in_model_path(path):


@ -20,6 +20,8 @@ elseif(OV_COMPILER_IS_CLANG)
ie_add_compiler_flags(-Wno-delete-non-abstract-non-virtual-dtor)
endif()
# enable MLAS for x86 CPUs only
ie_dependent_option(ENABLE_MLAS_FOR_CPU "MLAS GEMM for OpenVINO CPU Plugin" ON "X86 OR X86_64" OFF)
add_subdirectory(thirdparty)
if(WIN32)
@ -64,6 +66,10 @@ if(NOT (AARCH64 OR ARM))
list(APPEND EXCLUDE_PATHS ${CMAKE_CURRENT_SOURCE_DIR}/src/transformations/cpu_opset/arm/*)
endif()
if (NOT ENABLE_MLAS_FOR_CPU)
list(APPEND EXCLUDE_PATHS ${CMAKE_CURRENT_SOURCE_DIR}/src/mlas/*)
endif()
file(GLOB_RECURSE FILES_TO_REMOVE ${EXCLUDE_PATHS})
list(REMOVE_ITEM SOURCES ${FILES_TO_REMOVE})
list(REMOVE_ITEM HEADERS ${FILES_TO_REMOVE})
@ -94,8 +100,12 @@ target_link_libraries(${TARGET_NAME} PRIVATE dnnl
target_compile_definitions(${TARGET_NAME} PRIVATE IMPLEMENT_INFERENCE_EXTENSION_API)
target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src)
if (ENABLE_MLAS_FOR_CPU)
target_link_libraries(${TARGET_NAME} PRIVATE mlas)
target_include_directories(${TARGET_NAME} SYSTEM PRIVATE $<TARGET_PROPERTY:mlas,INCLUDE_DIRECTORIES>)
add_definitions(-DOV_CPU_WITH_MLAS)
endif()
target_include_directories(${TARGET_NAME} SYSTEM PRIVATE $<TARGET_PROPERTY:dnnl,INCLUDE_DIRECTORIES>)
# Cross compiled function
# TODO: The same for proposal, proposalONNX, topk
cross_compiled_file(${TARGET_NAME}
@ -133,6 +143,10 @@ if(BUILD_SHARED_LIBS)
$<TARGET_PROPERTY:openvino::conditional_compilation,INTERFACE_INCLUDE_DIRECTORIES>)
target_include_directories(${TARGET_NAME}_obj SYSTEM PUBLIC $<TARGET_PROPERTY:dnnl,INCLUDE_DIRECTORIES>)
if(ENABLE_MLAS_FOR_CPU)
target_include_directories(${TARGET_NAME}_obj SYSTEM PUBLIC $<TARGET_PROPERTY:mlas,INCLUDE_DIRECTORIES>)
endif()
set_ie_threading_interface_for(${TARGET_NAME}_obj)


@ -705,6 +705,9 @@ void GraphOptimizer::FuseFCAndConvertOnWeights(Graph& graph) {
if (parent->getType() == Type::Convert && parent->isConstant() && parent->getChildEdgeAt(0)->getChild()->getType() == Type::FullyConnected
&& parent->getOriginalInputPrecisionAtPort(0) == Precision::FP16
&& one_of(parent->getOriginalOutputPrecisionAtPort(0), Precision::FP32, Precision::BF16)) {
auto childNode = parent->getChildEdgeAt(0)->getChild();
// set correct weight precision
childNode->setOriginalInputPrecisionAtPort(1, parent->getOriginalInputPrecisionAtPort(0));
graph.DropNode(parent);
}
}


@ -0,0 +1,94 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "sgemm.hpp"
#include <string>
#include <vector>
#include "mlas.h"
#include "onednn/dnnl.h"
#include "openvino/core/parallel.hpp"
#include "thread_pool.hpp"
namespace ov {
namespace intel_cpu {
size_t mlas_sgemm_pack_get_size(const int64_t N, const int64_t K) {
return MlasGemmPackBSize(N, K);
}
void mlas_sgemm_pack(const char* transb,
const int64_t N,
const int64_t K,
const int64_t ldb,
const float* src,
float* dst) {
MlasGemmPackB(*transb == 'T' ? CblasTrans : CblasNoTrans, N, K, src, ldb, dst);
}
void mlas_sgemm(const char* transa,
const char* transb,
const int64_t M,
const int64_t N,
const int64_t K,
const float alpha,
const float* A,
const int64_t lda,
const float* B,
const int64_t ldb,
const float beta,
float* C,
const int64_t ldc,
size_t thread_num) {
// C = alpha * op(A) * op(B) + beta * C
MLAS_SGEMM_DATA_PARAMS sgemmParam;
sgemmParam.BIsPacked = false;
sgemmParam.A = A;
sgemmParam.lda = lda;
sgemmParam.B = B;
sgemmParam.ldb = ldb;
sgemmParam.C = C;
sgemmParam.ldc = ldc;
sgemmParam.alpha = alpha;
sgemmParam.beta = beta;
auto _transa = *transa == 'N' ? CblasNoTrans : CblasTrans;
auto _transb = *transb == 'N' ? CblasNoTrans : CblasTrans;
ov::cpu::OVMlasThreadPool threadPool(0 == thread_num ? parallel_get_num_threads() : thread_num);
MlasGemmBatch(_transa, _transb, M, N, K, &sgemmParam, 1, &threadPool);
}
void mlas_sgemm_compute(const char* transa,
const char* transb,
const int64_t M,
const int64_t N,
const int64_t K,
const float alpha,
const float* A,
const int64_t lda,
const float* B,
const int64_t ldb,
const float beta,
float* C,
const int64_t ldc,
const float* bias,
size_t thread_num) {
// C = alpha * op(A) * op(B) + beta * C
ov::cpu::OVMlasThreadPool threadPool(0 == thread_num ? parallel_get_num_threads() : thread_num);
MLAS_SGEMM_DATA_PARAMS sgemmParam;
sgemmParam.BIsPacked = true;
sgemmParam.A = A;
sgemmParam.lda = lda;
sgemmParam.B = B;
sgemmParam.ldb = ldb;
sgemmParam.C = C;
sgemmParam.ldc = ldc;
sgemmParam.alpha = alpha;
sgemmParam.beta = beta;
sgemmParam.bias = bias;
auto _transa = *transa == 'N' ? CblasNoTrans : CblasTrans;
auto _transb = *transb == 'N' ? CblasNoTrans : CblasTrans;
MlasGemmBatch(_transa, _transb, M, N, K, &sgemmParam, 1, &threadPool);
}
} // namespace intel_cpu
} // namespace ov


@ -0,0 +1,109 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <cstddef>
#include <cstdint>
namespace ov {
namespace intel_cpu {
/**
 * @brief Computes the length in bytes of the packed matrix B buffer (SGEMM).
*
* @param N Supplies the number of columns of matrix B.
* @param K Supplies the number of rows of matrix B.
* @return bytes of the packing buffer
*/
size_t mlas_sgemm_pack_get_size(const int64_t N, const int64_t K);
/**
* @brief Packs the contents of matrix B
*
 * @param transb T for transposed B, N for non-transposed B
* @param N Supplies the number of columns of matrix B and matrix C.
* @param K Supplies the number of columns of matrix A and the number
of rows of matrix B.
* @param ldb Supplies the first dimension of matrix B.
* @param src Supplies the address of matrix B
* @param dst Supplies pointer to prePacked B buffer
*/
void mlas_sgemm_pack(const char* transb,
const int64_t N,
const int64_t K,
const int64_t ldb,
const float* src,
float* dst);
/**
* @brief SGEMM with planar B matrix
*
 * @param transa T for transposed A, N for non-transposed A.
 * @param transb T for transposed B, N for non-transposed B.
* @param M Supplies the number of rows of matrix A and matrix C.
* @param N Supplies the number of columns of matrix B and matrix C.
* @param K Supplies the number of columns of matrix A and the number
of rows of matrix B.
* @param alpha Supplies the scalar alpha multiplier (see SGEMM definition)
* @param A Supplies the address of matrix A
* @param lda Supplies the first dimension of matrix A.
* @param B Supplies the address of matrix B
* @param ldb Supplies the first dimension of matrix B.
* @param beta Supplies the scalar beta multiplier (see SGEMM definition)
* @param C Supplies the address of matrix C
* @param ldc Supplies the first dimension of matrix C.
* @param thread_num 0 for all threads, otherwise use thread_num
*/
void mlas_sgemm(const char* transa,
const char* transb,
const int64_t M,
const int64_t N,
const int64_t K,
const float alpha,
const float* A,
const int64_t lda,
const float* B,
const int64_t ldb,
const float beta,
float* C,
const int64_t ldc,
size_t thread_num = 0);
/**
* @brief SGEMM with B matrix prepacked
*
 * @param transa T for transposed A, N for non-transposed A.
 * @param transb T for transposed B, N for non-transposed B.
* @param M Supplies the number of rows of matrix A and matrix C.
* @param N Supplies the number of columns of matrix B and matrix C.
* @param K Supplies the number of columns of matrix A and the number
of rows of matrix B.
* @param alpha Supplies the scalar alpha multiplier (see SGEMM definition)
* @param A Supplies the address of matrix A
* @param lda Supplies the first dimension of matrix A.
* @param B Supplies the address of matrix B
* @param ldb Supplies the first dimension of matrix B.
* @param beta Supplies the scalar beta multiplier (see SGEMM definition)
* @param C Supplies the address of matrix C
* @param ldc Supplies the first dimension of matrix C.
* @param bias Supplies the address of by-channel bias
* @param thread_num 0 for all threads, otherwise use thread_num
*/
void mlas_sgemm_compute(const char* transa,
const char* transb,
const int64_t M,
const int64_t N,
const int64_t K,
const float alpha,
const float* A,
const int64_t lda,
const float* B,
const int64_t ldb,
const float beta,
float* C,
const int64_t ldc,
const float* bias = nullptr,
size_t thread_num = 0);
} // namespace intel_cpu
} // namespace ov
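
For illustration only (not part of this commit): a minimal, self-contained use of the wrappers declared above, mirroring how the FullyConnected node calls them with weights stored row-major as [N, K] and prepacked once.

// Sketch: C[M,N] = A[M,K] * B^T, with B stored as [N, K] and prepacked ("T").
#include <cstdint>
#include <vector>
#include "sgemm.hpp"

void sgemm_usage_sketch() {
    const int64_t M = 4, N = 8, K = 16;
    std::vector<float> A(M * K, 1.0f);   // activations, lda = K
    std::vector<float> B(N * K, 0.5f);   // weights [N, K], ldb = K
    std::vector<float> C(M * N, 0.0f);   // output, ldc = N
    // Query the opaque packed-buffer size, then pack B once.
    std::vector<uint8_t> packedB(ov::intel_cpu::mlas_sgemm_pack_get_size(N, K));
    ov::intel_cpu::mlas_sgemm_pack("T", N, K, K, B.data(),
                                   reinterpret_cast<float*>(packedB.data()));
    // Reuse packedB for every subsequent call; bias and thread_num keep their defaults.
    ov::intel_cpu::mlas_sgemm_compute("N", "N", M, N, K, 1.0f, A.data(), K,
                                      reinterpret_cast<float*>(packedB.data()), K,
                                      0.0f, C.data(), N);
}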


@ -0,0 +1,33 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "thread_pool.hpp"
#include "onednn/dnnl.h"
#include "openvino/core/parallel.hpp"
// This function implements the forward declaration in MLAS
size_t getCacheSizeMlas(int level, bool perCore) {
return dnnl::utils::get_cache_size(level, perCore);
}
namespace ov {
namespace cpu {
size_t OVMlasThreadPool::DegreeOfParallelism() {
// MLAS treats a null thread pool as single-threaded; otherwise it uses this configured thread count
return threadNum;
}
void OVMlasThreadPool::TrySimpleParallelFor(const std::ptrdiff_t total, const std::function<void(std::ptrdiff_t)>& fn) {
ov::parallel_nt(threadNum, [&](const size_t ithr, const size_t nthr) {
std::ptrdiff_t start = 0, end = 0;
ov::splitter(total, nthr, ithr, start, end);
for (std::ptrdiff_t i = start; i < end; i++) {
fn(i);
}
});
}
}; // namespace cpu
}; // namespace ov


@ -0,0 +1,25 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <cstddef>
#include <cstdint>
#include <functional>
#include "mlas.h"
namespace ov {
namespace cpu {
class OVMlasThreadPool : public IMlasThreadPool {
public:
OVMlasThreadPool() = delete;
explicit OVMlasThreadPool(const size_t& threadNum) : threadNum(threadNum) {}
size_t DegreeOfParallelism() override;
void TrySimpleParallelFor(const std::ptrdiff_t total, const std::function<void(std::ptrdiff_t)>& fn) override;
public:
// the actual number of threads used for SGEMM
size_t threadNum = 0;
};
}; // namespace cpu
}; // namespace ov
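
For illustration only (not part of this commit): how a caller, normally MLAS itself, drives the adapter. ov::splitter hands each worker a contiguous [start, end) range, so every index runs exactly once.

#include <vector>
#include "thread_pool.hpp"

void thread_pool_usage_sketch() {
    ov::cpu::OVMlasThreadPool pool(/*threadNum=*/4);
    std::vector<int> out(1000, 0);
    pool.TrySimpleParallelFor(1000, [&](std::ptrdiff_t i) {
        out[i] = static_cast<int>(i) * 2;  // independent per-index work
    });
}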


@ -471,6 +471,7 @@ std::string Node::getPrimitiveDescriptorType() const {
SEARCH_TYPE(avx);
SEARCH_TYPE(sse42);
SEARCH_TYPE(blas);
SEARCH_TYPE(mlas);
SEARCH_TYPE(any);
SEARCH_TYPE(uni);


@ -29,6 +29,10 @@
#include <string>
#include <vector>
#ifdef OV_CPU_WITH_MLAS
#include "mlas/sgemm.hpp"
#endif
using namespace dnnl;
using namespace InferenceEngine;
@ -275,6 +279,39 @@ void FullyConnected::getSupportedDescriptors() {
inDims = isDynamicNode() ? makeDummyInputDims() : getInputShapeAtPort(DATA_ID).getStaticDims();
outDims = isDynamicNode() ? makeDummyOutputDims(inDims) : getOutputShapeAtPort(0).getStaticDims();
#ifdef OV_CPU_WITH_MLAS
// MLAS supports only FP32 and no post-op fusing (INT8 is not enabled yet),
// so disable MLAS whenever this FC node has fused post-ops
useMlas = !useSparseWeights &&
(inputDataType == memory::data_type::f32 && weightsDataType == memory::data_type::f32) &&
fusedWith.empty();
auto wgtDims = getInputShapeAtPort(WEIGHTS_ID).getStaticDims();
// MLAS cannot handle weights of rank > 2 (e.g. [1,64,9,9] * [10,64,9,9]) unless all trailing dims are 1
if (useMlas && wgtDims.size() > 2) {
bool allOnes = true;
for (size_t i = 2; i < wgtDims.size(); i++) {
allOnes = allOnes && wgtDims[i] == 1;
}
useMlas = useMlas && allOnes;
}
if (useMlas && withBiases) {
const auto& biasDims = getInputShapeAtPort(BIAS_ID).getStaticDims();
bool isByChannel = biasDims.back() == outDims.back();
for (size_t i = 0; i < biasDims.size() - 1; i++) {
isByChannel = isByChannel && biasDims[i] == 1;
}
useMlas = useMlas && isByChannel;
}
#endif
#ifdef CPU_DEBUG_CAPS
// Select the SGEMM backend via the OV_CPU_FC_EXEC_TYPE environment variable (MLAS or ONEDNN); MLAS is used by default
if (getenv("OV_CPU_FC_EXEC_TYPE")) {
if (std::string(getenv("OV_CPU_FC_EXEC_TYPE")) != "MLAS") {
useMlas = false;
}
}
#endif
if (useMlas) return;
for (auto format : getAvailableFormatsForDims(getInputShapeAtPort(0))) {
auto in_candidate = dnnl::memory::desc(DnnlExtensionUtils::convertToDnnlDims(inDims), inputDataType, format);
@ -284,7 +321,58 @@ void FullyConnected::getSupportedDescriptors() {
}
}
#ifdef OV_CPU_WITH_MLAS
void FullyConnected::prepackMLASWeight() {
auto prepareMLASWeight = [&](const int64_t N, const int64_t K) {
if (!getParentEdgeAt(WEIGHTS_ID)->getParent()->isConstant())
IE_THROW() << "Weight input is not const for node " << getName() << ".";
auto weightsMem = getParentEdgeAt(WEIGHTS_ID)->getMemoryPtr();
if (!weightsMem)
IE_THROW() << "Cannot get const weights edgeMem for node " << getName() << ".";
auto packedBsize = mlas_sgemm_pack_get_size(N, K);
MemoryPtr ptr;
auto create = [&]() {
float* weightPtr = reinterpret_cast<float*>(weightsMem->getData());
size_t ldb = K;
MemoryPtr _ptr =
std::make_shared<Memory>(getEngine(),
intel_cpu::CpuBlockedMemoryDesc(Precision::I8, intel_cpu::Shape{packedBsize}));
float* prepackedDst = reinterpret_cast<float*>(_ptr->getData());
mlas_sgemm_pack("T", N, K, ldb, weightPtr, prepackedDst);
return _ptr;
};
auto weightCache = context->getWeightsCache();
if (weightCache != nullptr) {
std::string format = "gemm_mlas_" + std::to_string(N) + "_" + std::to_string(K);
const std::string string_hash = getName() + "_" + format + "_" + std::to_string(weightsMem->getSize()) +
"_" + std::to_string(reinterpret_cast<uint64_t>(weightsMem->getData()));
ptr = *weightCache->findOrCreate(string_hash, create);
} else {
ptr = create();
}
return ptr;
};
const auto& wgtDims = getParentEdgeAt(WEIGHTS_ID)->getMemoryPtr()->getStaticDims();
// Weight is transposed by MatMulConstTransposesExtraction
// K is the IC of weight
// the weight is reshaped to [-1, K] in ConvertMatMulToFC
K = wgtDims[1];
N = wgtDims[0];
mlasPackedPtr = prepareMLASWeight(N, K);
}
#endif
void FullyConnected::createPrimitive() {
#ifdef OV_CPU_WITH_MLAS
if (useMlas) {
Node::createPrimitive();
prepackMLASWeight();
return;
}
#endif
setPostOps(attr, outDims);
attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
Node::createPrimitive();
@ -308,7 +396,18 @@ void FullyConnected::prepareParams() {
NodeDesc *selected_pd = getSelectedPrimitiveDescriptor();
if (selected_pd == nullptr)
IE_THROW() << "Preferable primitive descriptor is not set for node " << getName() << ".";
#ifdef OV_CPU_WITH_MLAS
// M must be normalized (leading dims collapsed) and updated from the runtime output shape
if (useMlas) {
outDims = dstMemPtr->getStaticDims();
if (outDims.size() > 2) {
M = std::accumulate(outDims.begin(), outDims.end() - 1, 1, std::multiplies<size_t>());
} else {
M = outDims[0];
}
return;
}
#endif
DnnlMemoryDescPtr weightDesc = MemoryDescUtils::convertToDnnlMemoryDesc(weightDescIP);
DnnlMemoryDescCPtr biasDesc = nullptr;
if (biasMemPtr) {
@ -438,7 +537,39 @@ void FullyConnected::prepareParams() {
}
}
#ifdef OV_CPU_WITH_MLAS
void FullyConnected::executeMLAS() {
const auto dstMemPtr = getChildEdgeAt(0)->getMemoryPtr();
const auto src0MemPtr = getParentEdgeAt(0)->getMemoryPtr();
const auto biasMemPtr = withBiases ? getParentEdgeAt(BIAS_ID)->getMemoryPtr() : nullptr;
int64_t lda = K;
int64_t ldb = K;
int64_t ldc = N;
mlas_sgemm_compute("N",
"N",
M,
N,
K,
1.0f,
reinterpret_cast<float*>(src0MemPtr->getData()),
lda,
reinterpret_cast<float*>(mlasPackedPtr->getData()),
ldb,
0.0f,
reinterpret_cast<float*>(dstMemPtr->getData()),
ldc,
withBiases ? reinterpret_cast<float*>(biasMemPtr->getData()) : nullptr);
}
#endif
void FullyConnected::execute(dnnl::stream strm) {
#ifdef OV_CPU_WITH_MLAS
if (useMlas) {
executeMLAS();
return;
}
#endif
if (!execPtr) {
IE_THROW() << "Can't execute FullyConnected node with name: " << getName() << ", because executor is not compiled";
}
@ -652,7 +783,22 @@ void FullyConnected::createDescriptor(const std::vector<MemoryDescPtr> &inputDes
void FullyConnected::initSupportedPrimitiveDescriptors() {
if (!supportedPrimitiveDescriptors.empty())
return;
if (useMlas) {
auto dataPrecision = getOriginalInputPrecisionAtPort(0);
if (withBiases) {
addSupportedPrimDesc({{LayoutType::ncsp, dataPrecision},
{LayoutType::ncsp, dataPrecision},
{LayoutType::ncsp, dataPrecision}},
{{LayoutType::ncsp, dataPrecision}},
impl_desc_type::gemm_mlas);
} else {
addSupportedPrimDesc({{LayoutType::ncsp, dataPrecision},
{LayoutType::ncsp, dataPrecision}},
{{LayoutType::ncsp, dataPrecision}},
impl_desc_type::gemm_mlas);
}
return;
}
// 3D FC requires implicit reshape so strides should be defined
auto supportsUndefStridesAndOffset = [&]() {
return getOutputShapeAtPort(0).getRank() == 2;
@ -928,6 +1074,7 @@ bool FullyConnected::useSparseWeightsDecompression() {
return true;
}
} // namespace node
} // namespace intel_cpu
} // namespace ov
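
For illustration only (not part of this commit): the M normalization done in prepareParams() collapses any leading output dims into a single GEMM row count, e.g. a dst shape of [2, 32, 768] runs as one SGEMM with M = 64.

#include <functional>
#include <numeric>
#include <vector>

// Standalone sketch of the M computation; the node itself reads the dims
// from dstMemPtr->getStaticDims().
size_t normalize_m(const std::vector<size_t>& outDims) {
    if (outDims.size() > 2)
        return std::accumulate(outDims.begin(), outDims.end() - 1, size_t{1},
                               std::multiplies<size_t>());
    return outDims[0];
}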


@ -100,6 +100,13 @@ private:
float weiSparseRate = 0.f;
bool useSparseWeightsDecompression();
VectorDims expectedBiasDims {};
bool useMlas = false;
#ifdef OV_CPU_WITH_MLAS
int64_t M, N, K;
MemoryPtr mlasPackedPtr = nullptr;
void executeMLAS();
void prepackMLASWeight();
#endif
};
} // namespace node


@ -29,6 +29,7 @@ impl_desc_type parse_impl_name(std::string impl_desc_name) {
if ((res & impl_desc_type::brgemm) != impl_desc_type::brgemm)
SEARCH_WORD(gemm);
SEARCH_WORD(blas);
SEARCH_WORD(mlas);
SEARCH_WORD(sse42);
SEARCH_WORD_2(sse41, sse42);
SEARCH_WORD(avx2);
@ -118,6 +119,7 @@ const char* impl_type_to_string(impl_desc_type type) {
CASE(dw_acl);
CASE(gemm_acl);
CASE(winograd_acl);
CASE(gemm_mlas);
#undef CASE
return "unknown";


@ -39,6 +39,8 @@ enum impl_desc_type {
winograd = 1<<24,
// sparse
sparse = 1<<25,
// MLAS backend
mlas = 1<<26,
// real types
ref_any = ref | any,
@ -98,6 +100,7 @@ enum impl_desc_type {
dw_acl = _dw | acl,
gemm_acl = gemm | acl,
winograd_acl = winograd | acl,
gemm_mlas = gemm | mlas
};
const char * impl_type_to_string(impl_desc_type type);
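
For illustration only: impl_desc_type values are bit flags, so the composite gemm_mlas entry satisfies both SEARCH_WORD(gemm) and SEARCH_WORD(mlas) in parse_impl_name(). A minimal sketch; the gemm bit value below is assumed for the example, only mlas = 1 << 26 comes from this diff.

enum sketch_impl_type : unsigned {
    sketch_gemm = 1 << 6,                        // assumed value, for illustration
    sketch_mlas = 1 << 26,                       // matches the new mlas bit above
    sketch_gemm_mlas = sketch_gemm | sketch_mlas,
};
static_assert((sketch_gemm_mlas & sketch_mlas) != 0, "mlas bit is set");
static_assert((sketch_gemm_mlas & sketch_gemm) != 0, "gemm bit is set");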


@ -182,7 +182,6 @@ TEST_P(MatMulLayerCPUTest, CompareWithRefs) {
}
}
}
run();
CheckPluginRelatedResults(compiledModel, cpuNodeType);
}
@ -193,14 +192,22 @@ namespace {
std::map<std::string, std::string> emptyAdditionalConfig;
std::vector<std::map<std::string, std::string>> additionalConfig {
#ifndef OV_CPU_WITH_MLAS
// FP32 precision is covered by MLAS
std::map<std::string, std::string>{/* empty config */},
#endif
{{PluginConfigParams::KEY_ENFORCE_BF16, PluginConfigParams::YES}}
};
std::vector<std::map<std::string, std::string>> filterAdditionalConfig_Brgemm() {
#ifndef OV_CPU_WITH_MLAS
// FP32 precision is covered by MLAS
std::vector<std::map<std::string, std::string>> additionalConfig = {
std::map<std::string, std::string>{/* empty config */}
};
#else
std::vector<std::map<std::string, std::string>> additionalConfig = {};
#endif
if (with_cpu_x86_bfloat16()) {
additionalConfig.push_back({{PluginConfigParams::KEY_ENFORCE_BF16, PluginConfigParams::YES}});
}
@ -257,6 +264,13 @@ std::vector<CPUSpecificParams> filterSpecificParams_Brgconv1x1() {
return specificParams;
}
std::vector<CPUSpecificParams> filterSpecificParams_MLAS() {
// use the mlas primitive type as the expected implementation
std::vector<CPUSpecificParams> specificParams;
specificParams.push_back(CPUSpecificParams{{}, {}, {"gemm_mlas"}, "gemm_mlas"});
return specificParams;
}
/* ============= FullyConnected ============= */
namespace fullyConnected {
@ -326,22 +340,30 @@ const std::vector<ShapeRelatedParams> IS2D_nightly = {
};
std::vector<fusingSpecificParams> fusingParamsSet2D_smoke {
// The following three patterns are covered by the MLAS tests
#ifndef OV_CPU_WITH_MLAS
emptyFusingSpec,
fusingBias,
fusingMultiplyPerChannel,
#endif
fusingFakeQuantizePerTensorRelu,
};
std::vector<fusingSpecificParams> fusingParamsSet2D_Brgemm_smoke {
// The following three patterns are covered by the MLAS tests
#ifndef OV_CPU_WITH_MLAS
emptyFusingSpec,
fusingBias,
fusingMultiplyPerChannel,
#endif
fusingFakeQuantizePerTensorRelu,
};
std::vector<fusingSpecificParams> fusingParamsSet2D_nightly {
fusingRelu,
#ifndef OV_CPU_WITH_MLAS
fusingScaleShift, // EltwiseMulAdd fusing, covered by MLAS
#endif
fusingPReluPerTensor,
fusingFakeQuantizePerChannelRelu,
};
@ -378,6 +400,26 @@ const auto testParams2DBF16_smoke = ::testing::Combine(::testing::Combine(::test
INSTANTIATE_TEST_SUITE_P(smoke_FC_2D, MatMulLayerCPUTest, testParams2D_smoke, MatMulLayerCPUTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_FC_2D_BF16, MatMulLayerCPUTest, testParams2DBF16_smoke, MatMulLayerCPUTest::getTestCaseName);
#ifdef OV_CPU_WITH_MLAS
std::vector<fusingSpecificParams> fusingParamsSet2D_MLAS_smoke {
emptyFusingSpec,
fusingBias,
fusingMultiplyPerChannel
};
const auto testParams2D_MLAS_smoke = ::testing::Combine(::testing::Combine(::testing::ValuesIn(IS2D_smoke),
::testing::Values(ElementType::f32),
::testing::Values(ElementType::undefined),
::testing::Values(ElementType::undefined),
::testing::Values(helpers::InputLayerType::CONSTANT),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(emptyAdditionalConfig)),
::testing::Values(MatMulNodeType::FullyConnected),
::testing::ValuesIn(fusingParamsSet2D_MLAS_smoke),
::testing::ValuesIn(filterSpecificParams_MLAS()));
INSTANTIATE_TEST_SUITE_P(smoke_FC_2D_MLAS, MatMulLayerCPUTest, testParams2D_MLAS_smoke, MatMulLayerCPUTest::getTestCaseName);
#endif
const auto testParams2D_nightly = ::testing::Combine(::testing::Combine(::testing::ValuesIn(IS2D_nightly),
::testing::Values(ElementType::f32),
::testing::Values(ElementType::undefined),
@ -403,6 +445,24 @@ const auto testParams2DBF16_nightly = ::testing::Combine(::testing::Combine(::te
INSTANTIATE_TEST_SUITE_P(nightly_FC_2D, MatMulLayerCPUTest, testParams2D_nightly, MatMulLayerCPUTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(nightly_FC_2D_BF16, MatMulLayerCPUTest, testParams2DBF16_nightly, MatMulLayerCPUTest::getTestCaseName);
#ifdef OV_CPU_WITH_MLAS
std::vector<fusingSpecificParams> fusingParamsSet2D_MLAS_nightly {
fusingScaleShift
};
const auto testParams2D_MLAS_nightly = ::testing::Combine(::testing::Combine(::testing::ValuesIn(IS2D_nightly),
::testing::Values(ElementType::f32),
::testing::Values(ElementType::undefined),
::testing::Values(ElementType::undefined),
::testing::Values(helpers::InputLayerType::CONSTANT),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(emptyAdditionalConfig)),
::testing::Values(MatMulNodeType::FullyConnected),
::testing::ValuesIn(fusingParamsSet2D_MLAS_nightly),
::testing::ValuesIn(filterSpecificParams_MLAS()));
INSTANTIATE_TEST_SUITE_P(nightly_FC_2D_MLAS, MatMulLayerCPUTest, testParams2D_MLAS_nightly, MatMulLayerCPUTest::getTestCaseName);
#endif
const std::vector<ShapeRelatedParams> IS3D_smoke = {
{static_shapes_to_test_representation({{1, 32, 120}, {120, 5}}), {false, false}},
{static_shapes_to_test_representation({{1, 32, 120}, {120, 5}}), {false, true}},
@ -470,9 +530,12 @@ const std::vector<ShapeRelatedParams> IS3D_nightly = {
};
std::vector<fusingSpecificParams> fusingParamsSet3D_smoke {
// The following three patterns are covered by the MLAS tests
#ifndef OV_CPU_WITH_MLAS
emptyFusingSpec,
fusingBias,
fusingMultiplyPerChannel,
#endif
fusingFakeQuantizePerChannel,
fusingScaleShiftAndFakeQuantizePerChannel,
};
@ -516,6 +579,20 @@ const auto testParams3DBF16_smoke = ::testing::Combine(fullyConnectedParams3DBF1
INSTANTIATE_TEST_SUITE_P(smoke_FC_3D, MatMulLayerCPUTest, testParams3D_smoke, MatMulLayerCPUTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_FC_3D_BF16, MatMulLayerCPUTest, testParams3DBF16_smoke, MatMulLayerCPUTest::getTestCaseName);
#ifdef OV_CPU_WITH_MLAS
std::vector<fusingSpecificParams> fusingParamsSet3D_MLAS_smoke {
emptyFusingSpec,
fusingBias,
fusingMultiplyPerChannel
};
const auto testParams3D_MLAS_smoke = ::testing::Combine(fullyConnectedParams3D_smoke,
::testing::Values(MatMulNodeType::FullyConnected),
::testing::ValuesIn(fusingParamsSet3D_MLAS_smoke),
::testing::ValuesIn(filterSpecificParams_MLAS()));
INSTANTIATE_TEST_SUITE_P(smoke_FC_3D_MLAS, MatMulLayerCPUTest, testParams3D_MLAS_smoke, MatMulLayerCPUTest::getTestCaseName);
#endif
const auto fullyConnectedParams3D_nightly = ::testing::Combine(::testing::ValuesIn(IS3D_nightly),
::testing::Values(ElementType::f32),
::testing::Values(ElementType::undefined),


@ -108,17 +108,28 @@ const std::vector<ReshapeFcSpecParams> reshFcParams = {
}
};
static std::vector<fusingSpecificParams> filterFusingParams(const std::vector<fusingSpecificParams>& orig) {
#ifdef OV_CPU_WITH_MLAS
return {emptyFusingSpec, fusingBias};
#else
return orig;
#endif
}
std::vector<fusingSpecificParams> fusingParamsSet {
emptyFusingSpec,
fusingBias,
fusingMultiplyPerChannel
};
#ifdef OV_CPU_WITH_MLAS
const auto gemmParam = CPUSpecificParams{{}, {}, {"gemm_mlas"}, "gemm_mlas"};
#else
const auto gemmParam = CPUSpecificParams{{}, {}, {"jit_gemm"}, "jit_gemm"};
#endif
const auto params = ::testing::Combine(
::testing::ValuesIn(reshFcParams),
::testing::ValuesIn(filterFusingParams(fusingParamsSet)),
::testing::Values(gemmParam));
INSTANTIATE_TEST_SUITE_P(smoke_ReshapeFc, ReshapeFcCPUTest, params, ReshapeFcCPUTest::getTestCaseName);


@ -24,6 +24,13 @@ else()
${CMAKE_CURRENT_SOURCE_DIR}/nodes/eltwise_node_test.cpp)
endif()
if (NOT ENABLE_MLAS_FOR_CPU)
set(MLAS_LIBRARY "")
list(APPEND EXCLUDED_SOURCE_PATHS_FOR_UNIT_TEST ${CMAKE_CURRENT_SOURCE_DIR}/gemm_api_test.cpp)
else()
set(MLAS_LIBRARY "mlas")
endif()
addIeTargetTest(
NAME ${TARGET_NAME}
ROOT ${CMAKE_CURRENT_SOURCE_DIR}
@ -53,6 +60,7 @@ addIeTargetTest(
ngraphFunctions
snippetsNgraphFunctions
snippets_test_utils
${MLAS_LIBRARY}
ADD_CPPLINT
LABELS
CPU


@ -0,0 +1,14 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <utility>
#include <gtest/gtest.h>
#include "mlas/sgemm.hpp"
// This test checks that the MLAS GEMM library compiles and links successfully
TEST(GemmTests, getPackedSize) {
int N = 51864;
int K = 384;
ASSERT_NO_THROW(ov::intel_cpu::mlas_sgemm_pack_get_size(N, K));
}
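
For illustration only (a hypothetical extension, not in this commit): a round-trip check that packs a tiny B and compares one SGEMM result against a naive reference; it needs <vector> in addition to the includes above.

#include <vector>

TEST(GemmTests, packedSgemmMatchesReferenceSketch) {
    const int64_t M = 2, N = 3, K = 4;
    std::vector<float> A(M * K), B(N * K);  // B stored as [N, K]
    for (size_t i = 0; i < A.size(); i++) A[i] = 0.1f * i;
    for (size_t i = 0; i < B.size(); i++) B[i] = 0.2f * i;
    std::vector<uint8_t> packed(ov::intel_cpu::mlas_sgemm_pack_get_size(N, K));
    ov::intel_cpu::mlas_sgemm_pack("T", N, K, K, B.data(),
                                   reinterpret_cast<float*>(packed.data()));
    std::vector<float> C(M * N, 0.0f);
    ov::intel_cpu::mlas_sgemm_compute("N", "N", M, N, K, 1.0f, A.data(), K,
                                      reinterpret_cast<float*>(packed.data()), K,
                                      0.0f, C.data(), N);
    for (int64_t m = 0; m < M; m++)
        for (int64_t n = 0; n < N; n++) {
            float ref = 0.0f;  // C[m,n] = sum_k A[m,k] * B[n,k]
            for (int64_t k = 0; k < K; k++)
                ref += A[m * K + k] * B[n * K + k];
            EXPECT_NEAR(ref, C[m * N + n], 1e-5f);
        }
}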


@ -131,4 +131,9 @@ function(ov_add_onednn)
endif()
endfunction()
if(ENABLE_MLAS_FOR_CPU)
add_subdirectory(mlas)
ov_install_static_lib(mlas cpu)
endif()
ov_add_onednn()

@ -0,0 +1 @@
Subproject commit b9414454951f2ef7c6e491f47b46318320d65e5c