[CPU] FullyConnected acceleration with 4bit weights decompression (#20607)
parent 00e2381d04
commit 63299ec217
@@ -53,6 +53,8 @@ INFERENCE_ENGINE_1_0_DEPRECATED inline ::ngraph::element::Type convertPrecision(
        return ::ngraph::element::Type(::ngraph::element::Type_t::boolean);
    case Precision::BIN:
        return ::ngraph::element::Type(::ngraph::element::Type_t::u1);
    case Precision::NF4:
        return ::ngraph::element::Type(::ngraph::element::Type_t::nf4);
    case Precision::Q78:
    case Precision::MIXED:
    case Precision::CUSTOM:
@@ -21,19 +21,18 @@ namespace intel_cpu {
uint8_t DnnlExtensionUtils::sizeOfDataType(dnnl::memory::data_type dataType) {
    switch (dataType) {
    case dnnl::memory::data_type::f32:
        return 4;
    case dnnl::memory::data_type::s32:
        return 4;
    case dnnl::memory::data_type::bf16:
        return 2;
    case dnnl::memory::data_type::s8:
        return 1;
    case dnnl::memory::data_type::u8:
        return 1;
    case dnnl::memory::data_type::bin:
        return 1;
    case dnnl::memory::data_type::f16:
        return 2;
    case dnnl::memory::data_type::s8:
    case dnnl::memory::data_type::u8:
    case dnnl::memory::data_type::bin:
    case dnnl::memory::data_type::nf4:
    case dnnl::memory::data_type::s4:
    case dnnl::memory::data_type::u4:
        return 1;
    case dnnl::memory::data_type::undef:
        return 0;
    default:
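For reference, a minimal usage sketch of sizeOfDataType after this change; the helper name below is illustrative and not part of the patch. Note that the 4-bit types (s4, u4, nf4) report 1 byte here, the same per-element size as s8/u8/bin.

// Illustrative only: size a dense temporary buffer from an element count and a dnnl data type.
size_t denseBufferSizeInBytes(dnnl::memory::data_type dt, size_t numElements) {
    return numElements * DnnlExtensionUtils::sizeOfDataType(dt);  // 4-bit types count as 1 byte per element
}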
@@ -58,6 +57,12 @@ memory::data_type DnnlExtensionUtils::IEPrecisionToDataType(const InferenceEngin
        return memory::data_type::bin;
    case InferenceEngine::Precision::FP16:
        return memory::data_type::f16;
    case InferenceEngine::Precision::NF4:
        return memory::data_type::nf4;
    case InferenceEngine::Precision::I4:
        return memory::data_type::s4;
    case InferenceEngine::Precision::U4:
        return memory::data_type::u4;
    case InferenceEngine::Precision::UNSPECIFIED:
        return memory::data_type::undef;
    default: {
@@ -82,6 +87,12 @@ InferenceEngine::Precision DnnlExtensionUtils::DataTypeToIEPrecision(memory::dat
        return InferenceEngine::Precision::BIN;
    case memory::data_type::f16:
        return InferenceEngine::Precision::FP16;
    case memory::data_type::nf4:
        return InferenceEngine::Precision::NF4;
    case memory::data_type::s4:
        return InferenceEngine::Precision::I4;
    case memory::data_type::u4:
        return InferenceEngine::Precision::U4;
    case memory::data_type::undef:
        return InferenceEngine::Precision::UNSPECIFIED;
    default: {
@@ -251,48 +251,58 @@ void DnnlPostOpsComposer::appendClip(const std::vector<float>& low, const std::v
    }
}

MemoryPtr DnnlPostOpsComposer::prepackDecompressionParams(const MemoryCPtr& params_ptr, size_t icBlock) {
    // Prepacking params from [oc] to [oc, icBlock] layout, where for each icBlock corresponding parameter is duplicated
MemoryPtr DnnlPostOpsComposer::prepackDecompressionParams(const MemoryCPtr& params_ptr, bool needTranspose) {
    const auto shape = params_ptr->getShape().getStaticDims();
    const size_t elements_count = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<size_t>());
    DnnlBlockedMemoryDesc memoryDesc(InferenceEngine::Precision::FP32, Shape({icBlock * elements_count}));
    auto mem = std::make_shared<Memory>(engine, memoryDesc);
    size_t dstIdx = 0;
    auto decomp_scales_data = static_cast<float*>(params_ptr->getData());
    auto decomp_scales_buf = static_cast<float*>(mem->getData());
    for (size_t oc = 0; oc < elements_count; oc++) {
        for (size_t intIdx = 0; intIdx < icBlock; intIdx++) {
            decomp_scales_buf[dstIdx] = decomp_scales_data[oc];
    MemoryPtr mem;

    auto params_data = static_cast<float*>(params_ptr->getData());

    if (needTranspose) {
        VectorDims dnnlShape = {shape[0], shape[1]};
        DnnlBlockedMemoryDesc memoryDesc(InferenceEngine::Precision::FP32, Shape(dnnlShape));
        mem = std::make_shared<Memory>(engine, memoryDesc);
        auto memory_buf = static_cast<float*>(mem->getData());

        // oi -> io
        for (size_t oc = 0; oc < dnnlShape[0]; oc++) {
            for (size_t ic = 0; ic < dnnlShape[1]; ic++) {
                memory_buf[ic * dnnlShape[0] + oc] = params_data[oc * dnnlShape[1] + ic];
            }
        }
    } else {
        VectorDims dnnlShape = {shape[shape.size() - 1], shape[0]};
        DnnlBlockedMemoryDesc memoryDesc(InferenceEngine::Precision::FP32, Shape(dnnlShape));
        mem = std::make_shared<Memory>(engine, memoryDesc);
        auto memory_buf = static_cast<float*>(mem->getData());
        const size_t elements_count = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<size_t>());

        // io -> io
        size_t dstIdx = 0;
        for (size_t oc = 0; oc < elements_count; oc++) {
            memory_buf[dstIdx] = params_data[oc];
            dstIdx++;
        }
    }

    return mem;
}

void DnnlPostOpsComposer::appendDecompressionScales(const MemoryCPtr& scales_ptr, size_t icBlock) {
void DnnlPostOpsComposer::appendDecompressionScales(const MemoryCPtr& scales_ptr, bool needTranspose) {
    if (scales_ptr == nullptr)
        return;

    const auto shape = scales_ptr->getShape().getStaticDims();
    const auto elements_count = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<size_t>());
    int mask = elements_count > 1 ? weightScaleMaskPerChannel : 0;
    DEBUG_LOG("Set weights scales mask ", "DNNL_ARG: ", DNNL_ARG_WEIGHTS, " mask: ", mask);
    attr.set_scales_mask(DNNL_ARG_WEIGHTS, mask);

    args[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] = prepackDecompressionParams(scales_ptr, icBlock);
    auto scalesMem = prepackDecompressionParams(scales_ptr, needTranspose);
    attr.set_scales_dims(DNNL_ARG_WEIGHTS, DnnlExtensionUtils::convertToDnnlDims(scalesMem->getStaticDims()));
    args[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] = scalesMem;
}

void DnnlPostOpsComposer::appendDecompressionZeroPoints(const MemoryCPtr& zero_points_ptr, size_t icBlock) {
void DnnlPostOpsComposer::appendDecompressionZeroPoints(const MemoryCPtr& zero_points_ptr, bool needTranspose) {
    if (zero_points_ptr == nullptr)
        return;

    const auto shape = zero_points_ptr->getShape().getStaticDims();
    const auto elements_count = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<size_t>());
    int mask = elements_count > 1 ? weightScaleMaskPerChannel : 0;
    DEBUG_LOG("Set weights zero points mask ", "DNNL_ARG: ", DNNL_ARG_WEIGHTS, " mask: ", mask);
    attr.set_zero_points_mask(DNNL_ARG_WEIGHTS, mask);

    args[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS] = prepackDecompressionParams(zero_points_ptr, icBlock);
    auto zeroPointsMem = prepackDecompressionParams(zero_points_ptr, needTranspose);
    attr.set_zero_points_dims(DNNL_ARG_WEIGHTS, DnnlExtensionUtils::convertToDnnlDims(zeroPointsMem->getStaticDims()));
    args[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS] = zeroPointsMem;
}

} // namespace intel_cpu
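For illustration, a self-contained sketch of the oi -> io repacking performed by the needTranspose branch of prepackDecompressionParams above; shapes and values are made up for the example.

#include <cstddef>
#include <vector>

// Scales/zero points laid out per [O, I] pair in oi order are repacked to io order,
// mirroring: memory_buf[ic * O + oc] = params_data[oc * I + ic];
void demoRepackOiToIo() {
    const size_t O = 2, I = 3;
    std::vector<float> oi = {0, 1, 2, 10, 11, 12};   // rows: oc = 0..1, cols: ic = 0..2
    std::vector<float> io(O * I);
    for (size_t oc = 0; oc < O; ++oc)
        for (size_t ic = 0; ic < I; ++ic)
            io[ic * O + oc] = oi[oc * I + ic];
    // io == {0, 10, 1, 11, 2, 12}
}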
@@ -42,8 +42,8 @@ public:
    bool appendLinear(const std::vector<float>& scale, const std::vector<float>& shift, bool isLastPostOp, bool allowBinary = true);
    void appendClip(const std::vector<float>& low, const std::vector<float>& high);

    void appendDecompressionScales(const MemoryCPtr& scales_ptr, size_t icBlock);
    void appendDecompressionZeroPoints(const MemoryCPtr& zero_points_ptr, size_t icBlock);
    void appendDecompressionScales(const MemoryCPtr& scales_ptr, bool needTranspose);
    void appendDecompressionZeroPoints(const MemoryCPtr& zero_points_ptr, bool needTranspose);

    const VectorDims& getOutputDims() {
        return outputDims;
@@ -69,7 +69,7 @@ private:

    void updateWeiScales();
    void updateDestScales();
    MemoryPtr prepackDecompressionParams(const MemoryCPtr& params_ptr, size_t icBlock);
    MemoryPtr prepackDecompressionParams(const MemoryCPtr& params_ptr, bool needTranspose);
};

} // namespace intel_cpu
@@ -286,7 +286,8 @@ void GraphOptimizer::FuseConvMatmulFCDeconvAndDQScales(Graph &graph) {
}

void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) {
    std::set<InferenceEngine::Precision> supportedWeightsPrecisions{InferenceEngine::Precision::U8, InferenceEngine::Precision::NF4};
    std::set<InferenceEngine::Precision> supportedWeightsPrecisions{InferenceEngine::Precision::U8, InferenceEngine::Precision::NF4,
                                                                    InferenceEngine::Precision::U4, InferenceEngine::Precision::I4};
    const std::set<InferenceEngine::Precision> supportedDataPrecisions{InferenceEngine::Precision::FP32, InferenceEngine::Precision::BF16};
    auto expectedNode = [](NodePtr node, Type expectedType) {
        return node->getType() == expectedType && node->getChildEdges().size() == 1;
@@ -335,7 +336,28 @@ void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) {
            continue;
        }

        const auto convertNode = withSubtract ? subtractNode->getParentEdgesAtPort(0)[0]->getParent() : mulParent;
        const bool withPowerStatic = mulParent->getAlgorithm() == Algorithm::EltwisePowerStatic;
        NodePtr powerStaticNode;
        if (withPowerStatic) {
            powerStaticNode = mulParent;
            if (auto *eltwiseNode = dynamic_cast<node::Eltwise *>(powerStaticNode.get())) {
                if (eltwiseNode->getAlpha() != 1 || eltwiseNode->getBeta() != 1)
                    continue;
            } else {
                continue;
            }
        }

        // Both operations fall back on the IP zero-point attribute and cannot be combined
        if (withSubtract && withPowerStatic)
            continue;

        auto convertNode = mulParent;
        if (withSubtract)
            convertNode = subtractNode->getParentEdgesAtPort(0)[0]->getParent();
        if (withPowerStatic)
            convertNode = powerStaticNode->getParentEdgesAtPort(0)[0]->getParent();

        if (!expectedNode(convertNode, Type::Convert))
            continue;
        const auto weightsNode = convertNode->getParentEdgesAtPort(0)[0]->getParent();
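In effect, the subgraph matched above is a per-element weight decompression chain: a low-precision weights Constant, a Convert to FP32, an optional Subtract (zero point) or PowerStatic shift (alpha == beta == 1, as checked above), and a Multiply by the scales. A minimal sketch of the scalar math this pattern represents (function and parameter names are illustrative, not from the patch):

// Illustrative scalar form of the matched decompression chain.
float decompressWeight(float convertedQuantizedWeight, float zeroPoint, float scale) {
    return (convertedQuantizedWeight - zeroPoint) * scale;  // Subtract (optional) followed by Multiply
}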
@@ -347,6 +369,8 @@ void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) {
            continue;
        if (withSubtract && subtractConstNode->getOriginalOutputPrecisionAtPort(0) != Precision::FP32)
            continue;
        if (withPowerStatic && powerStaticNode->getOriginalOutputPrecisionAtPort(0) != Precision::FP32)
            continue;
        if (supportedDataPrecisions.find(fcNode->getOriginalInputPrecisionAtPort(0)) == supportedDataPrecisions.end())
            continue;
        if (supportedWeightsPrecisions.find(weightsNode->getOriginalOutputPrecisionAtPort(0)) == supportedWeightsPrecisions.end())
@@ -361,6 +385,7 @@ void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) {

        VectorDims decompressionConstShape;
        const auto fcInputWeightsShape = fcNode->getInputShapeAtPort(1);
        int groupNum = 1;
        // Ordinary case: one decompression group
        if (fcInputWeightsShape.getRank() == weightsShape.getRank()) {
            const auto& out_channels = fcInputWeightsShape.getDims()[0];
@@ -377,6 +402,7 @@ void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) {
            const auto& O = withTranspose ? *weights_dims.rbegin() : *(weights_dims.rbegin() + 2);
            // Group decompression is applied by O and N dims
            decompressionConstShape = withTranspose ? VectorDims{N, 1, O} : VectorDims{O, N, 1};
            groupNum = N;
        }
        if (multiplyConstNode->getOutputShapeAtPort(0).getDims() != decompressionConstShape)
            continue;
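A small numeric illustration of the group-decompression shape check above, assuming the weights constant carries one extra grouping dimension (e.g. O output channels split into N groups of G input channels); the concrete numbers are made up:

// Non-transposed weights [O, N, G] = [64, 8, 16]  -> expected scales/zero-point shape {O, N, 1} = {64, 8, 1}
// Transposed weights     [N, G, O] = [8, 16, 64]  -> expected scales/zero-point shape {N, 1, O} = {8, 1, 64}
const bool withTranspose = false;
const size_t O = 64, N = 8;
VectorDims decompressionConstShape = withTranspose ? VectorDims{N, 1, O} : VectorDims{O, N, 1};  // groupNum = N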
@@ -384,7 +410,8 @@ void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) {
            continue;

        // HW specific shape limitations
        if (impl::cpu::x64::mayiuse(impl::cpu::x64::avx512_core_amx)) {
        if (impl::cpu::x64::mayiuse(impl::cpu::x64::avx512_core_amx) &&
            fcNode->getOriginalInputPrecisionAtPort(0) == InferenceEngine::Precision::BF16) {
            // OneDNN AMX IP implementation has limited shape support due to performance considerations. As a current solution, the conditions below are copied
            // from OneDNN to make sure the correct IP impl will be used, since the fallback one doesn't support the weights decompression feature.
            size_t OC = fcInputWeightsShape.getDims()[0];
@@ -398,10 +425,38 @@ void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) {
                continue;
            }

        size_t IC = fcInputWeightsShape.getDims()[1];
        // OneDNN IP primitive provides limited decompression params support
        if (IC % groupNum != 0 || IC / groupNum < 4) {
            continue;
        }

        // Fusion processing
        fcNode->fuseDecompressionMultiply(multiplyConstNode);
        if (withSubtract)
            fcNode->fuseDecompressionSubtract(subtractConstNode);
        auto *multiplyInputNode = dynamic_cast<node::Input *>(multiplyConstNode.get());
        if (!multiplyInputNode) {
            IE_THROW() << "Cannot cast " << multiplyConstNode->getName() << " to Input node";
        }
        fcNode->fuseDecompressionMultiply(multiplyInputNode->getMemoryPtr());

        if (withSubtract) {
            auto *subtractInputNode = dynamic_cast<node::Input *>(subtractConstNode.get());
            if (!subtractInputNode) {
                IE_THROW() << "Cannot cast " << subtractConstNode->getName() << " to Input node";
            }
            fcNode->fuseDecompressionSubtract(subtractInputNode->getMemoryPtr());
        }
        if (withPowerStatic) {
            auto *eltwiseNode = dynamic_cast<node::Eltwise *>(powerStaticNode.get());
            if (!eltwiseNode) {
                IE_THROW() << "Cannot cast " << powerStaticNode->getName() << " to Eltwise node";
            }

            VectorDims memoryDims(decompressionConstShape.size(), 1);
            CpuBlockedMemoryDesc memoryDesc(Precision::FP32, Shape(memoryDims));
            auto memory = std::make_shared<Memory>(graph.getEngine(), memoryDesc, nullptr, false);
            (static_cast<float *>(memory->getData()))[0] = -1.f * eltwiseNode->getGamma();
            fcNode->fuseDecompressionSubtract(memory);
        }

        fcNode->addOriginalLayer(multiplyNode->getOriginalLayers());
        fcNode->addOriginalLayer(convertNode->getOriginalLayers());
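The PowerStatic branch above reuses the zero-point path by writing -gamma into a one-element FP32 memory; a quick check of the identity it relies on (values are arbitrary, the function name is illustrative):

#include <cassert>

// (q + gamma) * scale  ==  (q - (-gamma)) * scale, so a static shift by gamma
// can be expressed as a decompression zero point of -gamma.
void demoPowerStaticAsZeroPoint() {
    const float q = 7.0f, gamma = 0.5f, scale = 0.25f;
    const float viaPowerStatic = (q + gamma) * scale;          // 1.875f
    const float viaZeroPoint   = (q - (-1.f * gamma)) * scale; // 1.875f
    assert(viaPowerStatic == viaZeroPoint);
}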
@@ -411,12 +466,18 @@ void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) {
            auto subtractConstEdge = subtractConstNode->getChildEdges()[0].lock();
            graph.RemoveEdge(subtractConstEdge);
        }
        if (withPowerStatic) {
            fcNode->addOriginalLayer(powerStaticNode->getOriginalLayers());
        }

        auto multiplyConstEdge = multiplyConstNode->getChildEdges()[0].lock();
        graph.RemoveEdge(multiplyConstEdge);

        graph.DropNode(convertNode);
        if (withSubtract)
            graph.DropNode(subtractNode);
        if (withPowerStatic)
            graph.DropNode(powerStaticNode);
        graph.DropNode(multiplyNode);

        const auto& weightsPrecision = weightsNode->getOriginalOutputPrecisionAtPort(0);
@@ -208,7 +208,8 @@ void FullyConnected::getSupportedDescriptors() {
    useSparseWeights = useSparseWeightsDecompression();
    useWeightsDecompressionImpl = dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2) &&
                                  one_of(inputDataType, memory::data_type::f32, memory::data_type::bf16) &&
                                  weightsDataType == memory::data_type::u8;
                                  one_of(weightsDataType, memory::data_type::u8, memory::data_type::nf4,
                                         memory::data_type::u4, memory::data_type::s4);

    // revert back outputDataType on special cases
    if (inputDataType == memory::data_type::f32) {
@@ -724,15 +725,10 @@ void FullyConnected::setPostOps(dnnl::primitive_attr& attr, const VectorDims& di
    NodeDesc *selected_pd = getSelectedPrimitiveDescriptor();
    if (selected_pd == nullptr)
        IE_THROW() << "Preferable primitive descriptor is not set for node " << getName() << ".";
    // OneDNN API doesn't provide an ability to query optimal layout for runtime attributes
    // As a workaround we assume that all AMX IP implementations use equal internal IC block size for weights layout
    // and prepack runtime attributes accordingly for better performance
    bool withAMX = selected_pd->getImplementationType() & impl_desc_type::amx;
    int icBlock = withAMX ? 2 : 1;
    if (decompressionMultiplyPtr)
        dnnlpoc.appendDecompressionScales(decompressionMultiplyPtr, icBlock);
        dnnlpoc.appendDecompressionScales(decompressionMultiplyPtr, !weightsNonTransposed);
    if (decompressionSubtractPtr)
        dnnlpoc.appendDecompressionZeroPoints(decompressionSubtractPtr, icBlock);
        dnnlpoc.appendDecompressionZeroPoints(decompressionSubtractPtr, !weightsNonTransposed);

    for (size_t i = 0; i < fusedWith.size(); ++i) {
        auto& node = fusedWith[i];
@@ -1132,30 +1128,25 @@ bool FullyConnected::useSparseWeightsDecompression() {
    return true;
}

void FullyConnected::fuseDecompressionMultiply(const NodePtr& constData) {
    fuseDecompressionConstant(constData, decompressionMultiplyPtr);
void FullyConnected::fuseDecompressionMultiply(const MemoryCPtr& memory) {
    fuseDecompressionConstant(memory, decompressionMultiplyPtr);
}

void FullyConnected::fuseDecompressionSubtract(const NodePtr& constData) {
    fuseDecompressionConstant(constData, decompressionSubtractPtr);
void FullyConnected::fuseDecompressionSubtract(const MemoryCPtr& memory) {
    fuseDecompressionConstant(memory, decompressionSubtractPtr);
}

void FullyConnected::fuseDecompressionConstant(const NodePtr& constData, MemoryCPtr& decompressionValuesPtr) {
    auto *constInputNode = dynamic_cast<node::Input *>(constData.get());
    if (!constInputNode) {
        IE_THROW() << "Cannot cast " << constData->getName() << " to Input";
    }
void FullyConnected::fuseDecompressionConstant(const MemoryCPtr& memory, MemoryCPtr& decompressionValuesPtr) {
    const auto decompression_prc = InferenceEngine::Precision::FP32;
    if (constInputNode->getOriginalOutputPrecisionAtPort(0) == decompression_prc) {
        decompressionValuesPtr = constInputNode->getMemoryPtr();
    if (memory->getDesc().getPrecision() == decompression_prc) {
        decompressionValuesPtr = memory;
    } else {
        const auto constBlob = constInputNode->getMemoryPtr();
        DnnlBlockedMemoryDesc memoryDesc(decompression_prc, constBlob->getShape());
        DnnlBlockedMemoryDesc memoryDesc(decompression_prc, memory->getShape());
        decompressionValuesPtr = std::make_shared<Memory>(getEngine(), memoryDesc, nullptr, false);
        const auto elementsCount = constBlob->getDescWithType<BlockedMemoryDesc>()->getPaddedElementsCount();
        cpu_convert(constBlob->getData(),
        const auto elementsCount = memory->getDescWithType<BlockedMemoryDesc>()->getPaddedElementsCount();
        cpu_convert(memory->getData(),
                    decompressionValuesPtr->getData(),
                    DnnlExtensionUtils::DataTypeToIEPrecision(constBlob->getDataType()),
                    DnnlExtensionUtils::DataTypeToIEPrecision(memory->getDataType()),
                    Precision::FP32,
                    elementsCount);
    }
@@ -60,8 +60,8 @@ public:
        this->weightsNonTransposed = weightsNonTransposed;
    }

    void fuseDecompressionMultiply(const NodePtr& constData);
    void fuseDecompressionSubtract(const NodePtr& constData);
    void fuseDecompressionMultiply(const MemoryCPtr& memory);
    void fuseDecompressionSubtract(const MemoryCPtr& memory);

private:
    void createDescriptorInternal(const dnnl::memory::desc &inputDesc,
@@ -99,7 +99,7 @@ private:
                                   const dnnl::engine& engine);

    bool canBeExecutedInConv1x1() const;
    void fuseDecompressionConstant(const NodePtr& constData, MemoryCPtr& decompressionValuesPtr);
    void fuseDecompressionConstant(const MemoryCPtr& memory, MemoryCPtr& decompressionValuesPtr);

    // sparse weights
    bool useSparseWeights = false;
@@ -201,11 +201,16 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
    } else {
        // We need to fuse Transpose to MatMul to have a simpler callback for the next transformation
        CPU_REGISTER_PASS_COMMON(manager, ov::pass::TransposeMatMul);
        const ov::element::TypeVector decompression_precisions{
            ov::element::u8,
            // TODO: Uncomment when group decompression is supported
            // ov::element::nf4
        ov::element::TypeVector decompression_precisions{
            ov::element::u8
        };
        // We don't have BF16/FP16 FullyConnected kernels to work with 4bits compressed weights
        // Convert node doesn't support 4bit precisions -> fallback on constant folding
        if (inferencePrecision == ov::element::f32) {
            decompression_precisions.push_back(ov::element::u4);
            decompression_precisions.push_back(ov::element::i4);
            decompression_precisions.push_back(ov::element::nf4);
        }
        // MarkDequantizationSubgraph is used even in non-LPT pipeline on X64 platforms
        // in order to keep compressed MatMul weights with decompression operations as is
        CPU_REGISTER_PASS_X64(manager, ov::pass::MarkDequantizationSubgraph, decompression_precisions, true);
@@ -223,15 +228,13 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis

            if (ov::is_type<ov::opset1::MatMul>(consumer)) {
                return false;
            } else if (ov::is_type<ov::opset1::Reshape>(consumer)) {
                consumer = get_single_consumer(consumer);
                if (consumer != nullptr && ov::is_type<ov::opset1::MatMul>(consumer)) {
                    return false;
                }
            }
            // TODO: Uncomment when group decompression is supported
            // if (ov::is_type<ov::opset1::Reshape>(consumer)) {
            //     consumer = get_single_consumer(consumer);
            //     if (consumer != nullptr && ov::is_type<ov::opset1::MatMul>(consumer)) {
            //         return false;
            //     }
            // }
            if (ov::is_type<ov::opset1::Convert>(consumer)) {
            if (consumer != nullptr && ov::is_type<ov::opset1::Convert>(consumer)) {
                consumer = get_single_consumer(consumer);
                if (consumer != nullptr && ov::is_type<ov::opset1::MatMul>(consumer)) {
                    return false;
@@ -142,7 +142,8 @@ protected:
            transformed_weights_shape[in_channel_idx] = weights_shape[0] / group_size;
            transformed_weights_shape.insert(transformed_weights_shape.begin() + in_channel_idx + 1, group_size);
        }
        auto weights = ngraph::builder::makeConstant<uint8_t>(weights_precision, transformed_weights_shape, {}, true);

        auto weights = ngraph::builder::makeConstant<int8_t>(weights_precision, transformed_weights_shape, {}, true, 7);
        weights->set_friendly_name("Compressed_weights");
        auto weights_convert = std::make_shared<ngraph::opset1::Convert>(weights, decompression_precision);
@@ -164,7 +165,7 @@ protected:
        if (reshape_on_decompression_constant)
            scaleshift_const_shape.erase(std::remove(scaleshift_const_shape.begin(), scaleshift_const_shape.end(), 1), scaleshift_const_shape.end());
        if (add_subtract) {
            auto shift_const = ngraph::builder::makeConstant<uint8_t>(weights_precision, scaleshift_const_shape, {}, true);
            auto shift_const = ngraph::builder::makeConstant<uint8_t>(weights_precision, scaleshift_const_shape, {}, true, 7);
            std::shared_ptr<ov::Node> shift_convert = std::make_shared<ngraph::opset1::Convert>(shift_const, decompression_precision);
            if (reshape_on_decompression_constant) {
                auto shift_reshape_const = ov::opset10::Constant::create(ov::element::i32, {scaleshift_target_shape.size()}, scaleshift_target_shape);
@@ -268,10 +269,7 @@ protected:
    void checkResults() {
        const auto& test_param = GetParam();
        const auto& weights_precision = std::get<1>(test_param);
        // TODO: remove this condition when group decompression is supported
        if (weights_precision == ov::element::nf4 || std::get<0>(test_param).weights_group_size != -1) {
            return;
        }

        bool weights_found = false;
        for (const auto& n : compiledModel.get_runtime_model()->get_ordered_ops()) {
            if (n->get_friendly_name() == "Compressed_weights") {
@@ -301,48 +299,37 @@ std::vector<std::map<std::string, std::string>> filterAdditionalConfigBasic() {
    std::vector<std::map<std::string, std::string>> additional_config = {CPUTestUtils::cpuEmptyPluginConfig};
    return additional_config;
}
std::vector<std::map<std::string, std::string>> filterAdditionalConfigBig() {
    std::vector<std::map<std::string, std::string>> additional_config = {CPUTestUtils::cpuEmptyPluginConfig};
std::vector<std::map<std::string, std::string>> filterAdditionalConfigAMX() {
    std::vector<std::map<std::string, std::string>> additional_config = {};
    if (with_cpu_x86_avx512_core_amx())
        additional_config.push_back({{PluginConfigParams::KEY_ENFORCE_BF16, PluginConfigParams::YES}});
    return additional_config;
}

bool shouldUseDecompressionKernelBig() {
    // No decompression support on non-avx systems
    if (!with_cpu_x86_avx2())
        return false;

    return true;
}

bool shouldUseDecompressionKernelBasic() {
    // AMX decompression support has shape limitations
    if (with_cpu_x86_avx512_core_amx())
        return false;

    return shouldUseDecompressionKernelBig();
}

const std::vector<ov::test::ElementType> weights_precisions = {ov::element::u8, ov::element::nf4};
const std::vector<ov::test::ElementType> decompression_precisions = {ov::element::f32};
const std::vector<ov::test::ElementType> weights_precisions_basic = {ov::element::u8,
                                                                     ov::element::u4,
                                                                     ov::element::i4,
                                                                     ov::element::nf4};
const std::vector<ov::test::ElementType> weights_precisions_amx = {ov::element::u8};

const std::vector<ShapeParams> input_shapes_basic = {
    {{{-1, -1, -1}, {{1, 4, 16}, {10, 16, 16}}}, {16, 32}},
    {{{}, {{1, 4, 16}}}, {16, 32}, 2ul},
    {{{}, {{1, 8, 16}}}, {16, 32}, 4ul},
    {{{}, {{1, 4, 16}}}, {1, 16, 32}},
    {{{}, {{10, 40, 496}}}, {1, 496, 240}},
    {{{}, {{1, 4, 48}}}, {48, 256}},
    {{{}, {{11, 339, 377}}}, {377, 335}},
};
const std::vector<ShapeParams> input_shapes_big = {
    {{{}, {{1, 11, 154}}}, {154, 77}, 154ul},
    {{{-1, -1, -1}, {{10, 40, 480}, {11, 40, 480}}}, {1, 480, 256}},
};
const std::vector<ShapeParams> input_shapes_amx = {
    {{{-1, -1, -1}, {{10, 40, 480}, {11, 40, 480}}}, {1, 480, 256}},
    {{{-1, 1, 4096}, {{1, 1, 4096}}}, {4096, 3840}, 128ul},
    {{{}, {{1, 4, 32}}}, {32, 256}},
    {{{}, {{1, 4, 512}}}, {512, 256}},
    {{{}, {{1, 16, 32}}}, {32, 64}},
    {{{}, {{2, 4, 32}}}, {32, 65}},
    {{{}, {{3, 12, 768}}}, {768, 1024}},
    {{{}, {{11, 339, 577}}}, {577, 335}},
    {{{}, {{1, 1, 256}}}, {256, 128}, 64ul},
};
const std::vector<fusingSpecificParams> fusingParamsSet {
    emptyFusingSpec,
@@ -352,35 +339,36 @@ const std::vector<fusingSpecificParams> fusingParamsSet {
INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_basic,
                         MatmulWeightsDecompression,
                         ::testing::Combine(::testing::ValuesIn(input_shapes_basic),
                                            ::testing::ValuesIn(weights_precisions),
                                            ::testing::ValuesIn(weights_precisions_basic),
                                            ::testing::ValuesIn(decompression_precisions),
                                            ::testing::Values(true),
                                            ::testing::Values(true),
                                            ::testing::Values(true),
                                            ::testing::ValuesIn(filterAdditionalConfigBasic()),
                                            ::testing::ValuesIn(fusingParamsSet),
                                            ::testing::Values(shouldUseDecompressionKernelBasic())),
                                            ::testing::Values(true)),
                         MatmulWeightsDecompression::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_big,
INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_amx,
                         MatmulWeightsDecompression,
                         ::testing::Combine(::testing::ValuesIn(input_shapes_big),
                                            ::testing::ValuesIn(weights_precisions),
                         ::testing::Combine(::testing::ValuesIn(input_shapes_amx),
                                            ::testing::ValuesIn(weights_precisions_amx),
                                            ::testing::ValuesIn(decompression_precisions),
                                            ::testing::Values(true),
                                            ::testing::Values(true),
                                            ::testing::Values(true),
                                            ::testing::ValuesIn(filterAdditionalConfigBig()),
                                            ::testing::ValuesIn(filterAdditionalConfigAMX()),
                                            ::testing::ValuesIn(fusingParamsSet),
                                            ::testing::Values(shouldUseDecompressionKernelBig())),
                                            ::testing::Values(true)),
                         MatmulWeightsDecompression::getTestCaseName);

const std::vector<ShapeParams> input_shapes_corner_cases_basic = {
    {{{-1, -1, -1}, {{1, 4, 16}}}, {1, 16, 32}},
    {{{-1, -1, -1}, {{1, 4, 16}}}, {16, 32}},
    {{{-1, -1, -1}, {{1, 4, 16}}}, {16, 32}, 4ul},
    {{{-1, -1, -1}, {{1, 1, 4096}}}, {4096, 4096}, 128ul},
};
const std::vector<ShapeParams> input_shapes_corner_cases_big = {
const std::vector<ShapeParams> input_shapes_corner_cases_amx = {
    {{{-1, -1, -1}, {{10, 40, 480}, {11, 40, 480}}}, {1, 480, 256}},
    {{{-1, -1, -1}, {{1, 1, 4096}}}, {4096, 4096}, 128ul},
};
@@ -393,27 +381,27 @@ const std::vector<ov::test::ElementType> decompression_precisions_corner_cases =
INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_corner_cases_basic,
                         MatmulWeightsDecompression,
                         ::testing::Combine(::testing::ValuesIn(input_shapes_corner_cases_basic),
                                            ::testing::ValuesIn(weights_precisions),
                                            ::testing::ValuesIn(weights_precisions_basic),
                                            ::testing::ValuesIn(decompression_precisions_corner_cases),
                                            ::testing::ValuesIn(transpose_weights),
                                            ::testing::ValuesIn(add_decompression_sub),
                                            ::testing::ValuesIn(reshape_on_decompression),
                                            ::testing::ValuesIn(filterAdditionalConfigBasic()),
                                            ::testing::Values(emptyFusingSpec),
                                            ::testing::Values(shouldUseDecompressionKernelBasic())),
                                            ::testing::Values(true)),
                         MatmulWeightsDecompression::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_corner_cases_big,
INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_corner_cases_amx,
                         MatmulWeightsDecompression,
                         ::testing::Combine(::testing::ValuesIn(input_shapes_corner_cases_big),
                                            ::testing::ValuesIn(weights_precisions),
                         ::testing::Combine(::testing::ValuesIn(input_shapes_corner_cases_amx),
                                            ::testing::ValuesIn(weights_precisions_amx),
                                            ::testing::ValuesIn(decompression_precisions_corner_cases),
                                            ::testing::ValuesIn(transpose_weights),
                                            ::testing::ValuesIn(add_decompression_sub),
                                            ::testing::ValuesIn(reshape_on_decompression),
                                            ::testing::ValuesIn(filterAdditionalConfigBig()),
                                            ::testing::ValuesIn(filterAdditionalConfigAMX()),
                                            ::testing::Values(emptyFusingSpec),
                                            ::testing::Values(shouldUseDecompressionKernelBig())),
                                            ::testing::Values(true)),
                         MatmulWeightsDecompression::getTestCaseName);
} // namespace
} // namespace SubgraphTestsDefinitions
src/plugins/intel_cpu/thirdparty/onednn (vendored)
@@ -1 +1 @@
Subproject commit 36c2060a0dc85b4def72ea30823936c2ef861b82
Subproject commit ff9205a8b42238e1fba992fad2429b722c4cfed0
@@ -71,6 +71,8 @@ std::shared_ptr<ov::Node> makeConstant(const ov::element::Type& type,
        makeNode(ov::element::Type_t::u64);
        makeNode(ov::element::Type_t::boolean);
        makeNode(ov::element::Type_t::nf4);
        makeNode(ov::element::Type_t::u4);
        makeNode(ov::element::Type_t::i4);
#undef makeNode
    default:
        throw std::runtime_error("Unhandled precision");