[CPU] FullyConnected op acceleration with 8bit weights decompression (#18915)

This commit is contained in:
Gorokhov Dmitriy 2023-08-04 09:34:48 +04:00 committed by GitHub
parent 2c3e17ef2a
commit 80a807e26c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 70 additions and 12 deletions

View File

@ -251,5 +251,33 @@ void DnnlPostOpsComposer::appendClip(const std::vector<float>& low, const std::v
} }
} }
void DnnlPostOpsComposer::appendDecompressionScales(const std::vector<float>& scales) {
if (scales.empty())
return;
int mask = scales.size() > 1 ? weightScaleMaskPerChannel : 0;
DEBUG_LOG("Set weights scales mask ", "DNNL_ARG: ", DNNL_ARG_WEIGHTS, " mask: ", mask);
attr.set_scales_mask(DNNL_ARG_WEIGHTS, mask);
DnnlBlockedMemoryDesc memoryDesc(InferenceEngine::Precision::FP32, Shape({scales.size()}));
auto mem = std::make_shared<Memory>(engine, memoryDesc);
memcpy(mem->getData(), scales.data(), scales.size() * sizeof(float));
args[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] = mem;
}
void DnnlPostOpsComposer::appendDecompressionZeroPoints(const std::vector<float>& zero_points) {
if (zero_points.empty())
return;
int mask = zero_points.size() > 1 ? weightScaleMaskPerChannel : 0;
DEBUG_LOG("Set weights zero points mask ", "DNNL_ARG: ", DNNL_ARG_WEIGHTS, " mask: ", mask);
attr.set_zero_points_mask(DNNL_ARG_WEIGHTS, mask);
DnnlBlockedMemoryDesc memoryDesc(InferenceEngine::Precision::FP32, Shape({zero_points.size()}));
auto mem = std::make_shared<Memory>(engine, memoryDesc);
memcpy(mem->getData(), zero_points.data(), zero_points.size() * sizeof(float));
args[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS] = mem;
}
} // namespace intel_cpu } // namespace intel_cpu
} // namespace ov } // namespace ov

View File

@ -42,6 +42,9 @@ public:
bool appendLinear(const std::vector<float>& scale, const std::vector<float>& shift, bool isLastPostOp, bool allowBinary = true); bool appendLinear(const std::vector<float>& scale, const std::vector<float>& shift, bool isLastPostOp, bool allowBinary = true);
void appendClip(const std::vector<float>& low, const std::vector<float>& high); void appendClip(const std::vector<float>& low, const std::vector<float>& high);
void appendDecompressionScales(const std::vector<float>& scales);
void appendDecompressionZeroPoints(const std::vector<float>& zero_points);
const VectorDims& getOutputDims() { const VectorDims& getOutputDims() {
return outputDims; return outputDims;
} }

View File

@ -75,6 +75,10 @@ void GraphOptimizer::ApplyCommonGraphOptimizations(Graph &graph) {
FuseConvMatmulFCDeconvAndDQScales(graph); FuseConvMatmulFCDeconvAndDQScales(graph);
graph.RemoveDroppedNodes(); graph.RemoveDroppedNodes();
OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "FuseFCAndWeightsDecompression");
FuseFCAndWeightsDecompression(graph);
graph.RemoveDroppedNodes();
OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "FuseConvolutionAndBias"); OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "FuseConvolutionAndBias");
FuseConvolutionMatMulDeconvAndBias(graph); FuseConvolutionMatMulDeconvAndBias(graph);
graph.RemoveDroppedNodes(); graph.RemoveDroppedNodes();
@ -283,6 +287,9 @@ void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) {
return node->getType() == expectedType && node->getChildEdges().size() == 1; return node->getType() == expectedType && node->getChildEdges().size() == 1;
}; };
if (!impl::cpu::x64::mayiuse(impl::cpu::x64::avx2))
return;
auto& graphNodes = graph.GetNodes(); auto& graphNodes = graph.GetNodes();
for (size_t i = 0; i < graphNodes.size(); i++) { for (size_t i = 0; i < graphNodes.size(); i++) {
const auto fcNode = dynamic_cast<node::FullyConnected*>(graphNodes[i].get()); const auto fcNode = dynamic_cast<node::FullyConnected*>(graphNodes[i].get());
@ -323,6 +330,8 @@ void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) {
continue; continue;
// Precision limitations // Precision limitations
if (fcNode->getOriginalInputPrecisionAtPort(0) != Precision::FP32)
continue;
if (multiplyConstNode->getOriginalOutputPrecisionAtPort(0) != Precision::FP32) if (multiplyConstNode->getOriginalOutputPrecisionAtPort(0) != Precision::FP32)
continue; continue;
if (supportedWeightsPrecisions.find(weightsNode->getOriginalOutputPrecisionAtPort(0)) == supportedWeightsPrecisions.end()) if (supportedWeightsPrecisions.find(weightsNode->getOriginalOutputPrecisionAtPort(0)) == supportedWeightsPrecisions.end())

View File

@ -192,10 +192,6 @@ void FullyConnected::getSupportedDescriptors() {
if (getChildEdges().empty()) if (getChildEdges().empty())
IE_THROW()<< errorPrefix << " has incorrect number of output edges"; IE_THROW()<< errorPrefix << " has incorrect number of output edges";
withBiases = getOriginalInputsNumber() == 3;
useSparseWeights = useSparseWeightsDecompression();
auto inputDataType = DnnlExtensionUtils::IEPrecisionToDataType(getOriginalInputPrecisionAtPort(DATA_ID)); auto inputDataType = DnnlExtensionUtils::IEPrecisionToDataType(getOriginalInputPrecisionAtPort(DATA_ID));
outputDataType = DnnlExtensionUtils::IEPrecisionToDataType(getOriginalOutputPrecisionAtPort(DATA_ID)); outputDataType = DnnlExtensionUtils::IEPrecisionToDataType(getOriginalOutputPrecisionAtPort(DATA_ID));
@ -203,6 +199,13 @@ void FullyConnected::getSupportedDescriptors() {
outputDataType = DnnlExtensionUtils::IEPrecisionToDataType(fusedWith[fusedWith.size() - 1]->getOriginalOutputPrecisionAtPort(0)); outputDataType = DnnlExtensionUtils::IEPrecisionToDataType(fusedWith[fusedWith.size() - 1]->getOriginalOutputPrecisionAtPort(0));
} }
auto weightsDataType = DnnlExtensionUtils::IEPrecisionToDataType(getOriginalInputPrecisionAtPort(WEIGHTS_ID)); auto weightsDataType = DnnlExtensionUtils::IEPrecisionToDataType(getOriginalInputPrecisionAtPort(WEIGHTS_ID));
withBiases = getOriginalInputsNumber() == 3;
useSparseWeights = useSparseWeightsDecompression();
useWeightsDecompressionImpl = dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2) &&
inputDataType == memory::data_type::f32 && weightsDataType == memory::data_type::u8;
// revert back outputDataType on special cases // revert back outputDataType on special cases
if (inputDataType == memory::data_type::f32) { if (inputDataType == memory::data_type::f32) {
// oneDNN only support f32 output when input is f32, even if FQ is fused // oneDNN only support f32 output when input is f32, even if FQ is fused
@ -240,7 +243,7 @@ void FullyConnected::getSupportedDescriptors() {
#if defined(OV_CPU_WITH_MLAS) && (defined(OPENVINO_ARCH_X86) || defined(OPENVINO_ARCH_X86_64)) #if defined(OV_CPU_WITH_MLAS) && (defined(OPENVINO_ARCH_X86) || defined(OPENVINO_ARCH_X86_64))
// MLAS doesn't support post-ops fusing and only supports FP32. INT8 is not enabled yet // MLAS doesn't support post-ops fusing and only supports FP32. INT8 is not enabled yet
// Disable MLAS when FC could fuse post-ops // Disable MLAS when FC could fuse post-ops
useMlas = !useSparseWeights && useMlas = !useSparseWeights && !useWeightsDecompressionImpl &&
(inputDataType == memory::data_type::f32 && weightsDataType == memory::data_type::f32) && (inputDataType == memory::data_type::f32 && weightsDataType == memory::data_type::f32) &&
fusedWith.empty(); fusedWith.empty();
auto wgtDims = getInputShapeAtPort(WEIGHTS_ID).getStaticDims(); auto wgtDims = getInputShapeAtPort(WEIGHTS_ID).getStaticDims();
@ -587,6 +590,11 @@ void FullyConnected::setPostOps(dnnl::primitive_attr& attr, const VectorDims& di
DnnlPostOpsComposer dnnlpoc(getEngine(), attr, ops, postOpsArgs, dims, dims.size() - 1, canBeExecutedInInt8(), DnnlPostOpsComposer dnnlpoc(getEngine(), attr, ops, postOpsArgs, dims, dims.size() - 1, canBeExecutedInInt8(),
1 << 0, getDQScales(), withBiases); 1 << 0, getDQScales(), withBiases);
if (!decompressionMultiply.empty())
dnnlpoc.appendDecompressionScales(decompressionMultiply);
if (!decompressionSubtract.empty())
dnnlpoc.appendDecompressionZeroPoints(decompressionSubtract);
for (size_t i = 0; i < fusedWith.size(); ++i) { for (size_t i = 0; i < fusedWith.size(); ++i) {
auto& node = fusedWith[i]; auto& node = fusedWith[i];
bool isLastPostOp = (i == (fusedWith.size() - 1)); bool isLastPostOp = (i == (fusedWith.size() - 1));
@ -673,11 +681,14 @@ void FullyConnected::createDescriptorInternal(const dnnl::memory::desc &inputDes
dnnl::memory::data_type wdt = indt; dnnl::memory::data_type wdt = indt;
dnnl::memory::data_type bdt = outdt; dnnl::memory::data_type bdt = outdt;
if (one_of(indt, dnnl::memory::data_type::bf16, dnnl::memory::data_type::f16)) { if (useWeightsDecompressionImpl) {
//oneDNN ARM InnerProduct primitive supports only identical in/out data types // Weights decompression case
wdt = DnnlExtensionUtils::IEPrecisionToDataType(getOriginalInputPrecisionAtPort(WEIGHTS_ID));
} else if (one_of(indt, dnnl::memory::data_type::bf16, dnnl::memory::data_type::f16)) {
#if defined(OPENVINO_ARCH_X86_64) #if defined(OPENVINO_ARCH_X86_64)
bdt = dnnl::memory::data_type::f32; bdt = dnnl::memory::data_type::f32;
#else #else
// oneDNN ARM InnerProduct primitive supports only identical in/out data types
bdt = dnnl::memory::data_type::f16; bdt = dnnl::memory::data_type::f16;
#endif #endif
} else if (indt == dnnl::memory::data_type::u8 || indt == dnnl::memory::data_type::s8) { } else if (indt == dnnl::memory::data_type::u8 || indt == dnnl::memory::data_type::s8) {
@ -939,6 +950,9 @@ bool FullyConnected::canBeExecutedInConv1x1() const {
bool retVal = false; bool retVal = false;
const auto inRank = getInputShapeAtPort(DATA_ID).getRank(); const auto inRank = getInputShapeAtPort(DATA_ID).getRank();
const auto weightRank = getInputShapeAtPort(WEIGHTS_ID).getRank(); const auto weightRank = getInputShapeAtPort(WEIGHTS_ID).getRank();
if (useWeightsDecompressionImpl) {
return false;
}
// disable rank=4: // disable rank=4:
// if layout is nhwc: // if layout is nhwc:
// A matrix: N * IC * H * W --> N * (IC*H*W), the M, N', K of matrix multiply will be: // A matrix: N * IC * H * W --> N * (IC*H*W), the M, N', K of matrix multiply will be:

View File

@ -115,6 +115,7 @@ private:
void prepackMLASWeight(); void prepackMLASWeight();
#endif #endif
bool useWeightsDecompressionImpl = false;
std::vector<float> decompressionSubtract; std::vector<float> decompressionSubtract;
std::vector<float> decompressionMultiply; std::vector<float> decompressionMultiply;
}; };

View File

@ -188,8 +188,6 @@ std::vector<std::string> disabledTestPatterns() {
// New plugin API doesn't support changes of pre-processing // New plugin API doesn't support changes of pre-processing
R"(.*(Auto|Multi|Hetero).*InferRequestPreprocessTest.*SetPreProcessToInputInfo.*)", R"(.*(Auto|Multi|Hetero).*InferRequestPreprocessTest.*SetPreProcessToInputInfo.*)",
R"(.*(Auto|Multi|Hetero).*InferRequestPreprocessTest.*SetPreProcessToInferRequest.*)", R"(.*(Auto|Multi|Hetero).*InferRequestPreprocessTest.*SetPreProcessToInferRequest.*)",
// Issue: 113727
R"(.*MatMulCompressedWeights.*)",
// TODO: for 22.2 (CVS-68949) // TODO: for 22.2 (CVS-68949)
R"(.*smoke_AutoBatching_CPU/AutoBatching_Test_DetectionOutput.*)", R"(.*smoke_AutoBatching_CPU/AutoBatching_Test_DetectionOutput.*)",
}; };

View File

@ -189,7 +189,9 @@ protected:
} }
std::map<std::string, std::string> additional_config = std::get<5>(test_param); std::map<std::string, std::string> additional_config = std::get<5>(test_param);
const size_t expected_count = additional_config[PluginConfigParams::KEY_ENFORCE_BF16] == PluginConfigParams::YES ? 1 : 0; const size_t expected_count =
InferenceEngine::with_cpu_x86_avx2() &&
additional_config[PluginConfigParams::KEY_ENFORCE_BF16] != PluginConfigParams::YES ? 0 : 1;
CheckNumberOfNodesWithType(compiledModel, "Convert", expected_count); CheckNumberOfNodesWithType(compiledModel, "Convert", expected_count);
CheckNumberOfNodesWithType(compiledModel, "Eltwise", expected_count); CheckNumberOfNodesWithType(compiledModel, "Eltwise", expected_count);
CheckNumberOfNodesWithType(compiledModel, "Subgraph", 0); CheckNumberOfNodesWithType(compiledModel, "Subgraph", 0);
@ -220,6 +222,9 @@ const std::vector<std::vector<InputShape>> input_shapes_basic = {
{{{}, {{1, 4, 48}}}, {{}, {{48, 256}}}}, {{{}, {{1, 4, 48}}}, {{}, {{48, 256}}}},
{{{}, {{1, 4, 512}}}, {{}, {{512, 256}}}}, {{{}, {{1, 4, 512}}}, {{}, {{512, 256}}}},
{{{}, {{1, 16, 32}}}, {{}, {{32, 64}}}}, {{{}, {{1, 16, 32}}}, {{}, {{32, 64}}}},
{{{}, {{2, 4, 32}}}, {{}, {{32, 65}}}},
{{{}, {{11, 339, 377}}}, {{}, {{377, 335}}}},
{{{}, {{3, 12, 768}}}, {{}, {{768, 1024}}}},
}; };
const std::vector<fusingSpecificParams> fusingParamsSet { const std::vector<fusingSpecificParams> fusingParamsSet {
emptyFusingSpec, emptyFusingSpec,

View File

@ -87,7 +87,7 @@ function(create_target_per_test_for_directory TEST_DIR TARGET_PREFIX)
endfunction() endfunction()
if(ENABLE_CPU_SPECIFIC_TARGET_PER_TEST) if(ENABLE_CPU_SPECIFIC_TARGET_PER_TEST)
create_target_per_test_for_directory(${CMAKE_CURRENT_SOURCE_DIR}/subgraph_tests/src/arm ov_cpu_func_subgraph) create_target_per_test_for_directory(${CMAKE_CURRENT_SOURCE_DIR}/subgraph_tests/src ov_cpu_func_subgraph)
create_target_per_test_for_directory(${CMAKE_CURRENT_SOURCE_DIR}/single_layer_tests ov_cpu_func_slt) create_target_per_test_for_directory(${CMAKE_CURRENT_SOURCE_DIR}/single_layer_tests ov_cpu_func_slt)
endif() endif()

@ -1 +1 @@
Subproject commit 2c0b3f2946185370b3ccbe4330398c15109e5ec4 Subproject commit ec7a051b3f4a9e65b22382f3d787e84ce74efe07