[CPU] FullyConnected acceleration with 4bit weights decompression (#20607)

Gorokhov Dmitriy 2023-10-26 01:08:07 +04:00 committed by GitHub
parent 00e2381d04
commit 63299ec217
11 changed files with 198 additions and 130 deletions
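This change teaches the CPU plugin to keep 4-bit (u4/i4/nf4) compressed FullyConnected weights as-is and decompress them at runtime instead of constant-folding them to f32. A rough sketch of the decompression subgraph the fusion below targets (not taken from the PR; shapes and values are illustrative, and it assumes Constant::create accepts integer vectors for 4-bit types; the Subtract branch is optional and may also appear as a PowerStatic shift):

#include <memory>
#include <vector>
#include <openvino/openvino.hpp>
#include <openvino/opsets/opset10.hpp>

std::shared_ptr<ov::Model> make_compressed_fc() {
    using namespace ov::opset10;
    // [batch, IC] f32 activations
    auto activations = std::make_shared<Parameter>(ov::element::f32, ov::PartialShape{ov::Dimension::dynamic(), 16});
    // [OC, IC] 4-bit compressed weights (placeholder values)
    auto weights = Constant::create(ov::element::u4, ov::Shape{32, 16}, std::vector<uint8_t>(32 * 16, 1));
    auto convert = std::make_shared<Convert>(weights, ov::element::f32);
    // per-output-channel zero point and scale
    auto zero_point = Constant::create(ov::element::f32, ov::Shape{32, 1}, std::vector<float>(32, 8.0f));
    auto scale = Constant::create(ov::element::f32, ov::Shape{32, 1}, std::vector<float>(32, 0.1f));
    auto shifted = std::make_shared<Subtract>(convert, zero_point);
    auto decompressed = std::make_shared<Multiply>(shifted, scale);
    auto fc = std::make_shared<MatMul>(activations, decompressed, false, /*transpose_b=*/true);
    return std::make_shared<ov::Model>(ov::OutputVector{fc}, ov::ParameterVector{activations});
}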


@@ -53,6 +53,8 @@ INFERENCE_ENGINE_1_0_DEPRECATED inline ::ngraph::element::Type convertPrecision(
return ::ngraph::element::Type(::ngraph::element::Type_t::boolean);
case Precision::BIN:
return ::ngraph::element::Type(::ngraph::element::Type_t::u1);
case Precision::NF4:
return ::ngraph::element::Type(::ngraph::element::Type_t::nf4);
case Precision::Q78:
case Precision::MIXED:
case Precision::CUSTOM:


@@ -21,19 +21,18 @@ namespace intel_cpu {
uint8_t DnnlExtensionUtils::sizeOfDataType(dnnl::memory::data_type dataType) {
switch (dataType) {
case dnnl::memory::data_type::f32:
return 4;
case dnnl::memory::data_type::s32:
return 4;
case dnnl::memory::data_type::bf16:
return 2;
case dnnl::memory::data_type::s8:
return 1;
case dnnl::memory::data_type::u8:
return 1;
case dnnl::memory::data_type::bin:
return 1;
case dnnl::memory::data_type::f16:
return 2;
case dnnl::memory::data_type::s8:
case dnnl::memory::data_type::u8:
case dnnl::memory::data_type::bin:
case dnnl::memory::data_type::nf4:
case dnnl::memory::data_type::s4:
case dnnl::memory::data_type::u4:
return 1;
case dnnl::memory::data_type::undef:
return 0;
default:
@@ -58,6 +57,12 @@ memory::data_type DnnlExtensionUtils::IEPrecisionToDataType(const InferenceEngin
return memory::data_type::bin;
case InferenceEngine::Precision::FP16:
return memory::data_type::f16;
case InferenceEngine::Precision::NF4:
return memory::data_type::nf4;
case InferenceEngine::Precision::I4:
return memory::data_type::s4;
case InferenceEngine::Precision::U4:
return memory::data_type::u4;
case InferenceEngine::Precision::UNSPECIFIED:
return memory::data_type::undef;
default: {
@@ -82,6 +87,12 @@ InferenceEngine::Precision DnnlExtensionUtils::DataTypeToIEPrecision(memory::dat
return InferenceEngine::Precision::BIN;
case memory::data_type::f16:
return InferenceEngine::Precision::FP16;
case memory::data_type::nf4:
return InferenceEngine::Precision::NF4;
case memory::data_type::s4:
return InferenceEngine::Precision::I4;
case memory::data_type::u4:
return InferenceEngine::Precision::U4;
case memory::data_type::undef:
return InferenceEngine::Precision::UNSPECIFIED;
default: {


@@ -251,48 +251,58 @@ void DnnlPostOpsComposer::appendClip(const std::vector<float>& low, const std::v
}
}
MemoryPtr DnnlPostOpsComposer::prepackDecompressionParams(const MemoryCPtr& params_ptr, size_t icBlock) {
// Prepacking params from [oc] to [oc, icBlock] layout, where for each icBlock corresponding parameter is duplicated
MemoryPtr DnnlPostOpsComposer::prepackDecompressionParams(const MemoryCPtr& params_ptr, bool needTranspose) {
const auto shape = params_ptr->getShape().getStaticDims();
const size_t elements_count = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<size_t>());
DnnlBlockedMemoryDesc memoryDesc(InferenceEngine::Precision::FP32, Shape({icBlock * elements_count}));
auto mem = std::make_shared<Memory>(engine, memoryDesc);
size_t dstIdx = 0;
auto decomp_scales_data = static_cast<float*>(params_ptr->getData());
auto decomp_scales_buf = static_cast<float*>(mem->getData());
for (size_t oc = 0; oc < elements_count; oc++) {
for (size_t intIdx = 0; intIdx < icBlock; intIdx++) {
decomp_scales_buf[dstIdx] = decomp_scales_data[oc];
MemoryPtr mem;
auto params_data = static_cast<float*>(params_ptr->getData());
if (needTranspose) {
VectorDims dnnlShape = {shape[0], shape[1]};
DnnlBlockedMemoryDesc memoryDesc(InferenceEngine::Precision::FP32, Shape(dnnlShape));
mem = std::make_shared<Memory>(engine, memoryDesc);
auto memory_buf = static_cast<float*>(mem->getData());
// oi -> io
for (size_t oc = 0; oc < dnnlShape[0]; oc++) {
for (size_t ic = 0; ic < dnnlShape[1]; ic++) {
memory_buf[ic * dnnlShape[0] + oc] = params_data[oc * dnnlShape[1] + ic];
}
}
} else {
VectorDims dnnlShape = {shape[shape.size() - 1], shape[0]};
DnnlBlockedMemoryDesc memoryDesc(InferenceEngine::Precision::FP32, Shape(dnnlShape));
mem = std::make_shared<Memory>(engine, memoryDesc);
auto memory_buf = static_cast<float*>(mem->getData());
const size_t elements_count = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<size_t>());
// io -> io
size_t dstIdx = 0;
for (size_t oc = 0; oc < elements_count; oc++) {
memory_buf[dstIdx] = params_data[oc];
dstIdx++;
}
}
return mem;
}
void DnnlPostOpsComposer::appendDecompressionScales(const MemoryCPtr& scales_ptr, size_t icBlock) {
void DnnlPostOpsComposer::appendDecompressionScales(const MemoryCPtr& scales_ptr, bool needTranspose) {
if (scales_ptr == nullptr)
return;
const auto shape = scales_ptr->getShape().getStaticDims();
const auto elements_count = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<size_t>());
int mask = elements_count > 1 ? weightScaleMaskPerChannel : 0;
DEBUG_LOG("Set weights scales mask ", "DNNL_ARG: ", DNNL_ARG_WEIGHTS, " mask: ", mask);
attr.set_scales_mask(DNNL_ARG_WEIGHTS, mask);
args[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] = prepackDecompressionParams(scales_ptr, icBlock);
auto scalesMem = prepackDecompressionParams(scales_ptr, needTranspose);
attr.set_scales_dims(DNNL_ARG_WEIGHTS, DnnlExtensionUtils::convertToDnnlDims(scalesMem->getStaticDims()));
args[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] = scalesMem;
}
void DnnlPostOpsComposer::appendDecompressionZeroPoints(const MemoryCPtr& zero_points_ptr, size_t icBlock) {
void DnnlPostOpsComposer::appendDecompressionZeroPoints(const MemoryCPtr& zero_points_ptr, bool needTranspose) {
if (zero_points_ptr == nullptr)
return;
const auto shape = zero_points_ptr->getShape().getStaticDims();
const auto elements_count = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<size_t>());
int mask = elements_count > 1 ? weightScaleMaskPerChannel : 0;
DEBUG_LOG("Set weights zero points mask ", "DNNL_ARG: ", DNNL_ARG_WEIGHTS, " mask: ", mask);
attr.set_zero_points_mask(DNNL_ARG_WEIGHTS, mask);
args[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS] = prepackDecompressionParams(zero_points_ptr, icBlock);
auto zeroPointsMem = prepackDecompressionParams(zero_points_ptr, needTranspose);
attr.set_zero_points_dims(DNNL_ARG_WEIGHTS, DnnlExtensionUtils::convertToDnnlDims(zeroPointsMem->getStaticDims()));
args[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS] = zeroPointsMem;
}
} // namespace intel_cpu
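For reference, a minimal standalone sketch of the "oi -> io" repack performed above (hypothetical helper, not part of this change): per-output-channel/per-group decompression params arriving as [OC, G] are laid out as [G, OC] before being handed to oneDNN via DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS together with their dims.

#include <cstddef>
#include <vector>

// Hypothetical helper: transpose a row-major [OC, G] parameter array into [G, OC].
std::vector<float> repack_oi_to_io(const std::vector<float>& src, std::size_t OC, std::size_t G) {
    std::vector<float> dst(OC * G);
    for (std::size_t oc = 0; oc < OC; ++oc) {
        for (std::size_t g = 0; g < G; ++g) {
            dst[g * OC + oc] = src[oc * G + g];
        }
    }
    return dst;
}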


@@ -42,8 +42,8 @@ public:
bool appendLinear(const std::vector<float>& scale, const std::vector<float>& shift, bool isLastPostOp, bool allowBinary = true);
void appendClip(const std::vector<float>& low, const std::vector<float>& high);
void appendDecompressionScales(const MemoryCPtr& scales_ptr, size_t icBlock);
void appendDecompressionZeroPoints(const MemoryCPtr& zero_points_ptr, size_t icBlock);
void appendDecompressionScales(const MemoryCPtr& scales_ptr, bool needTranspose);
void appendDecompressionZeroPoints(const MemoryCPtr& zero_points_ptr, bool needTranspose);
const VectorDims& getOutputDims() {
return outputDims;
@@ -69,7 +69,7 @@ private:
void updateWeiScales();
void updateDestScales();
MemoryPtr prepackDecompressionParams(const MemoryCPtr& params_ptr, size_t icBlock);
MemoryPtr prepackDecompressionParams(const MemoryCPtr& params_ptr, bool needTranspose);
};
} // namespace intel_cpu


@@ -286,7 +286,8 @@ void GraphOptimizer::FuseConvMatmulFCDeconvAndDQScales(Graph &graph) {
}
void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) {
std::set<InferenceEngine::Precision> supportedWeightsPrecisions{InferenceEngine::Precision::U8, InferenceEngine::Precision::NF4};
std::set<InferenceEngine::Precision> supportedWeightsPrecisions{InferenceEngine::Precision::U8, InferenceEngine::Precision::NF4,
InferenceEngine::Precision::U4, InferenceEngine::Precision::I4};
const std::set<InferenceEngine::Precision> supportedDataPrecisions{InferenceEngine::Precision::FP32, InferenceEngine::Precision::BF16};
auto expectedNode = [](NodePtr node, Type expectedType) {
return node->getType() == expectedType && node->getChildEdges().size() == 1;
@@ -335,7 +336,28 @@ void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) {
continue;
}
const auto convertNode = withSubtract ? subtractNode->getParentEdgesAtPort(0)[0]->getParent() : mulParent;
const bool withPowerStatic = mulParent->getAlgorithm() == Algorithm::EltwisePowerStatic;
NodePtr powerStaticNode;
if (withPowerStatic) {
powerStaticNode = mulParent;
if (auto *eltwiseNode = dynamic_cast<node::Eltwise *>(powerStaticNode.get())) {
if (eltwiseNode->getAlpha() != 1 || eltwiseNode->getBeta() != 1)
continue;
} else {
continue;
}
}
// Both operations fall back on the IP zero-point attribute and cannot be combined
if (withSubtract && withPowerStatic)
continue;
auto convertNode = mulParent;
if (withSubtract)
convertNode = subtractNode->getParentEdgesAtPort(0)[0]->getParent();
if (withPowerStatic)
convertNode = powerStaticNode->getParentEdgesAtPort(0)[0]->getParent();
if (!expectedNode(convertNode, Type::Convert))
continue;
const auto weightsNode = convertNode->getParentEdgesAtPort(0)[0]->getParent();
@@ -347,6 +369,8 @@ void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) {
continue;
if (withSubtract && subtractConstNode->getOriginalOutputPrecisionAtPort(0) != Precision::FP32)
continue;
if (withPowerStatic && powerStaticNode->getOriginalOutputPrecisionAtPort(0) != Precision::FP32)
continue;
if (supportedDataPrecisions.find(fcNode->getOriginalInputPrecisionAtPort(0)) == supportedDataPrecisions.end())
continue;
if (supportedWeightsPrecisions.find(weightsNode->getOriginalOutputPrecisionAtPort(0)) == supportedWeightsPrecisions.end())
@@ -361,6 +385,7 @@ void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) {
VectorDims decompressionConstShape;
const auto fcInputWeightsShape = fcNode->getInputShapeAtPort(1);
int groupNum = 1;
// Ordinary case: one decompression group
if (fcInputWeightsShape.getRank() == weightsShape.getRank()) {
const auto& out_channels = fcInputWeightsShape.getDims()[0];
@@ -377,6 +402,7 @@ void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) {
const auto& O = withTranspose ? *weights_dims.rbegin() : *(weights_dims.rbegin() + 2);
// Group decompression is applied by O and N dims
decompressionConstShape = withTranspose ? VectorDims{N, 1, O} : VectorDims{O, N, 1};
groupNum = N;
}
if (multiplyConstNode->getOutputShapeAtPort(0).getDims() != decompressionConstShape)
continue;
@@ -384,7 +410,8 @@ void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) {
continue;
// HW specific shape limitations
if (impl::cpu::x64::mayiuse(impl::cpu::x64::avx512_core_amx)) {
if (impl::cpu::x64::mayiuse(impl::cpu::x64::avx512_core_amx) &&
fcNode->getOriginalInputPrecisionAtPort(0) == InferenceEngine::Precision::BF16) {
// OneDNN AMX IP implementation has limited shapes support due to performance considerations. As a current solution, the conditions below are copied
// from OneDNN to make sure the correct IP impl will be used, since the fallback one doesn't support the weights decompression feature.
size_t OC = fcInputWeightsShape.getDims()[0];
@@ -398,10 +425,38 @@ void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) {
continue;
}
size_t IC = fcInputWeightsShape.getDims()[1];
// OneDNN IP primitive provides limited decompression params support
if (IC % groupNum != 0 || IC / groupNum < 4) {
continue;
}
// Fusion processing
fcNode->fuseDecompressionMultiply(multiplyConstNode);
if (withSubtract)
fcNode->fuseDecompressionSubtract(subtractConstNode);
auto *multiplyInputNode = dynamic_cast<node::Input *>(multiplyConstNode.get());
if (!multiplyInputNode) {
IE_THROW() << "Cannot cast " << multiplyConstNode->getName() << " to Input node";
}
fcNode->fuseDecompressionMultiply(multiplyInputNode->getMemoryPtr());
if (withSubtract) {
auto *subtractInputNode = dynamic_cast<node::Input *>(subtractConstNode.get());
if (!subtractInputNode) {
IE_THROW() << "Cannot cast " << subtractConstNode->getName() << " to Input node";
}
fcNode->fuseDecompressionSubtract(subtractInputNode->getMemoryPtr());
}
if (withPowerStatic) {
auto *eltwiseNode = dynamic_cast<node::Eltwise *>(powerStaticNode.get());
if (!eltwiseNode) {
IE_THROW() << "Cannot cast " << powerStaticNode->getName() << " to Eltwise node";
}
VectorDims memoryDims(decompressionConstShape.size(), 1);
CpuBlockedMemoryDesc memoryDesc(Precision::FP32, Shape(memoryDims));
auto memory = std::make_shared<Memory>(graph.getEngine(), memoryDesc, nullptr, false);
(static_cast<float *>(memory->getData()))[0] = -1.f * eltwiseNode->getGamma();
fcNode->fuseDecompressionSubtract(memory);
}
fcNode->addOriginalLayer(multiplyNode->getOriginalLayers());
fcNode->addOriginalLayer(convertNode->getOriginalLayers());
@@ -411,12 +466,18 @@ void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) {
auto subtractConstEdge = subtractConstNode->getChildEdges()[0].lock();
graph.RemoveEdge(subtractConstEdge);
}
if (withPowerStatic) {
fcNode->addOriginalLayer(powerStaticNode->getOriginalLayers());
}
auto multiplyConstEdge = multiplyConstNode->getChildEdges()[0].lock();
graph.RemoveEdge(multiplyConstEdge);
graph.DropNode(convertNode);
if (withSubtract)
graph.DropNode(subtractNode);
if (withPowerStatic)
graph.DropNode(powerStaticNode);
graph.DropNode(multiplyNode);
const auto& weightsPrecision = weightsNode->getOriginalOutputPrecisionAtPort(0);
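
A small self-check of the equivalence the PowerStatic branch above relies on (illustrative values only): PowerStatic with alpha == 1 and beta == 1 computes x + gamma, while oneDNN weights decompression computes (w - zero_point) * scale, so the additive shift maps onto zero_point = -gamma, which is why the fusion fills the zero-point memory with -1.f * getGamma().

#include <cassert>

int main() {
    const float gamma = 0.5f, w = 3.0f, scale = 2.0f;
    const float via_power_static = (w + gamma) * scale;   // Convert -> PowerStatic(+gamma) -> Multiply(scale)
    const float via_zero_point = (w - (-gamma)) * scale;  // oneDNN: (weight - zero_point) * scale with zero_point = -gamma
    assert(via_power_static == via_zero_point);
    return 0;
}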


@@ -208,7 +208,8 @@ void FullyConnected::getSupportedDescriptors() {
useSparseWeights = useSparseWeightsDecompression();
useWeightsDecompressionImpl = dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2) &&
one_of(inputDataType, memory::data_type::f32, memory::data_type::bf16) &&
weightsDataType == memory::data_type::u8;
one_of(weightsDataType, memory::data_type::u8, memory::data_type::nf4,
memory::data_type::u4, memory::data_type::s4);
// revert back outputDataType on special cases
if (inputDataType == memory::data_type::f32) {
@@ -724,15 +725,10 @@ void FullyConnected::setPostOps(dnnl::primitive_attr& attr, const VectorDims& di
NodeDesc *selected_pd = getSelectedPrimitiveDescriptor();
if (selected_pd == nullptr)
IE_THROW() << "Preferable primitive descriptor is not set for node " << getName() << ".";
// OneDNN API doesn't provide an ability to query optimal layout for runtime attributes
// As workaround we assume that all AMX IP implementations use equal internal IC block size for weights layout
// and prepack runtime attributes accordingly for better performance
bool withAMX = selected_pd->getImplementationType() & impl_desc_type::amx;
int icBlock = withAMX ? 2 : 1;
if (decompressionMultiplyPtr)
dnnlpoc.appendDecompressionScales(decompressionMultiplyPtr, icBlock);
dnnlpoc.appendDecompressionScales(decompressionMultiplyPtr, !weightsNonTransposed);
if (decompressionSubtractPtr)
dnnlpoc.appendDecompressionZeroPoints(decompressionSubtractPtr, icBlock);
dnnlpoc.appendDecompressionZeroPoints(decompressionSubtractPtr, !weightsNonTransposed);
for (size_t i = 0; i < fusedWith.size(); ++i) {
auto& node = fusedWith[i];
@@ -1132,30 +1128,25 @@ bool FullyConnected::useSparseWeightsDecompression() {
return true;
}
void FullyConnected::fuseDecompressionMultiply(const NodePtr& constData) {
fuseDecompressionConstant(constData, decompressionMultiplyPtr);
void FullyConnected::fuseDecompressionMultiply(const MemoryCPtr& memory) {
fuseDecompressionConstant(memory, decompressionMultiplyPtr);
}
void FullyConnected::fuseDecompressionSubtract(const NodePtr& constData) {
fuseDecompressionConstant(constData, decompressionSubtractPtr);
void FullyConnected::fuseDecompressionSubtract(const MemoryCPtr& memory) {
fuseDecompressionConstant(memory, decompressionSubtractPtr);
}
void FullyConnected::fuseDecompressionConstant(const NodePtr& constData, MemoryCPtr& decompressionValuesPtr) {
auto *constInputNode = dynamic_cast<node::Input *>(constData.get());
if (!constInputNode) {
IE_THROW() << "Cannot cast " << constData->getName() << " to Input";
}
void FullyConnected::fuseDecompressionConstant(const MemoryCPtr& memory, MemoryCPtr& decompressionValuesPtr) {
const auto decompression_prc = InferenceEngine::Precision::FP32;
if (constInputNode->getOriginalOutputPrecisionAtPort(0) == decompression_prc) {
decompressionValuesPtr = constInputNode->getMemoryPtr();
if (memory->getDesc().getPrecision() == decompression_prc) {
decompressionValuesPtr = memory;
} else {
const auto constBlob = constInputNode->getMemoryPtr();
DnnlBlockedMemoryDesc memoryDesc(decompression_prc, constBlob->getShape());
DnnlBlockedMemoryDesc memoryDesc(decompression_prc, memory->getShape());
decompressionValuesPtr = std::make_shared<Memory>(getEngine(), memoryDesc, nullptr, false);
const auto elementsCount = constBlob->getDescWithType<BlockedMemoryDesc>()->getPaddedElementsCount();
cpu_convert(constBlob->getData(),
const auto elementsCount = memory->getDescWithType<BlockedMemoryDesc>()->getPaddedElementsCount();
cpu_convert(memory->getData(),
decompressionValuesPtr->getData(),
DnnlExtensionUtils::DataTypeToIEPrecision(constBlob->getDataType()),
DnnlExtensionUtils::DataTypeToIEPrecision(memory->getDataType()),
Precision::FP32,
elementsCount);
}


@@ -60,8 +60,8 @@ public:
this->weightsNonTransposed = weightsNonTransposed;
}
void fuseDecompressionMultiply(const NodePtr& constData);
void fuseDecompressionSubtract(const NodePtr& constData);
void fuseDecompressionMultiply(const MemoryCPtr& memory);
void fuseDecompressionSubtract(const MemoryCPtr& memory);
private:
void createDescriptorInternal(const dnnl::memory::desc &inputDesc,
@@ -99,7 +99,7 @@ private:
const dnnl::engine& engine);
bool canBeExecutedInConv1x1() const;
void fuseDecompressionConstant(const NodePtr& constData, MemoryCPtr& decompressionValuesPtr);
void fuseDecompressionConstant(const MemoryCPtr& memory, MemoryCPtr& decompressionValuesPtr);
// sparse weights
bool useSparseWeights = false;


@@ -201,11 +201,16 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
} else {
// We need to fuse Transpose to MatMul to have a simpler callback for the next transformation
CPU_REGISTER_PASS_COMMON(manager, ov::pass::TransposeMatMul);
const ov::element::TypeVector decompression_precisions{
ov::element::u8,
// TODO: Uncomment when group decompression is supported
// ov::element::nf4
ov::element::TypeVector decompression_precisions{
ov::element::u8
};
// We don't have BF16/FP16 FullyConnected kernels that work with 4-bit compressed weights
// Convert node doesn't support 4-bit precisions -> fall back on constant folding
if (inferencePrecision == ov::element::f32) {
decompression_precisions.push_back(ov::element::u4);
decompression_precisions.push_back(ov::element::i4);
decompression_precisions.push_back(ov::element::nf4);
}
// MarkDequantizationSubgraph is used even in non-LPT pipeline on X64 platforms
// in order to keep compressed MatMul weights with decompression operations as is
CPU_REGISTER_PASS_X64(manager, ov::pass::MarkDequantizationSubgraph, decompression_precisions, true);
@@ -223,15 +228,13 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
if (ov::is_type<ov::opset1::MatMul>(consumer)) {
return false;
} else if (ov::is_type<ov::opset1::Reshape>(consumer)) {
consumer = get_single_consumer(consumer);
if (consumer != nullptr && ov::is_type<ov::opset1::MatMul>(consumer)) {
return false;
}
}
// TODO: Uncomment when group decompression is supported
// if (ov::is_type<ov::opset1::Reshape>(consumer)) {
// consumer = get_single_consumer(consumer);
// if (consumer != nullptr && ov::is_type<ov::opset1::MatMul>(consumer)) {
// return false;
// }
// }
if (ov::is_type<ov::opset1::Convert>(consumer)) {
if (consumer != nullptr && ov::is_type<ov::opset1::Convert>(consumer)) {
consumer = get_single_consumer(consumer);
if (consumer != nullptr && ov::is_type<ov::opset1::MatMul>(consumer)) {
return false;
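
A hedged usage note tied to the inferencePrecision condition above (the model path below is a placeholder): since the 4-bit precisions are only added to decompression_precisions when the effective inference precision is f32, forcing the f32 hint on CPU is one way to keep 4-bit MatMul weights compressed at runtime rather than constant-folded.

#include <openvino/openvino.hpp>

int main() {
    ov::Core core;
    auto model = core.read_model("model_with_4bit_weights.xml");  // placeholder path
    auto compiled = core.compile_model(model, "CPU",
                                       ov::hint::inference_precision(ov::element::f32));
    return 0;
}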


@@ -142,7 +142,8 @@ protected:
transformed_weights_shape[in_channel_idx] = weights_shape[0] / group_size;
transformed_weights_shape.insert(transformed_weights_shape.begin() + in_channel_idx + 1, group_size);
}
auto weights = ngraph::builder::makeConstant<uint8_t>(weights_precision, transformed_weights_shape, {}, true);
auto weights = ngraph::builder::makeConstant<int8_t>(weights_precision, transformed_weights_shape, {}, true, 7);
weights->set_friendly_name("Compressed_weights");
auto weights_convert = std::make_shared<ngraph::opset1::Convert>(weights, decompression_precision);
@@ -164,7 +165,7 @@
if (reshape_on_decompression_constant)
scaleshift_const_shape.erase(std::remove(scaleshift_const_shape.begin(), scaleshift_const_shape.end(), 1), scaleshift_const_shape.end());
if (add_subtract) {
auto shift_const = ngraph::builder::makeConstant<uint8_t>(weights_precision, scaleshift_const_shape, {}, true);
auto shift_const = ngraph::builder::makeConstant<uint8_t>(weights_precision, scaleshift_const_shape, {}, true, 7);
std::shared_ptr<ov::Node> shift_convert = std::make_shared<ngraph::opset1::Convert>(shift_const, decompression_precision);
if (reshape_on_decompression_constant) {
auto shift_reshape_const = ov::opset10::Constant::create(ov::element::i32, {scaleshift_target_shape.size()}, scaleshift_target_shape);
@@ -268,10 +269,7 @@
void checkResults() {
const auto& test_param = GetParam();
const auto& weights_precision = std::get<1>(test_param);
// TODO: remove this condition when group decompression is supported
if (weights_precision == ov::element::nf4 || std::get<0>(test_param).weights_group_size != -1) {
return;
}
bool weights_found = false;
for (const auto& n : compiledModel.get_runtime_model()->get_ordered_ops()) {
if (n->get_friendly_name() == "Compressed_weights") {
@@ -301,48 +299,37 @@ std::vector<std::map<std::string, std::string>> filterAdditionalConfigBasic() {
std::vector<std::map<std::string, std::string>> additional_config = {CPUTestUtils::cpuEmptyPluginConfig};
return additional_config;
}
std::vector<std::map<std::string, std::string>> filterAdditionalConfigBig() {
std::vector<std::map<std::string, std::string>> additional_config = {CPUTestUtils::cpuEmptyPluginConfig};
std::vector<std::map<std::string, std::string>> filterAdditionalConfigAMX() {
std::vector<std::map<std::string, std::string>> additional_config = {};
if (with_cpu_x86_avx512_core_amx())
additional_config.push_back({{PluginConfigParams::KEY_ENFORCE_BF16, PluginConfigParams::YES}});
return additional_config;
}
bool shouldUseDecompressionKernelBig() {
// No decompression support on non-avx systems
if (!with_cpu_x86_avx2())
return false;
return true;
}
bool shouldUseDecompressionKernelBasic() {
// AMX decompression support has shape limitations
if (with_cpu_x86_avx512_core_amx())
return false;
return shouldUseDecompressionKernelBig();
}
const std::vector<ov::test::ElementType> weights_precisions = {ov::element::u8, ov::element::nf4};
const std::vector<ov::test::ElementType> decompression_precisions = {ov::element::f32};
const std::vector<ov::test::ElementType> weights_precisions_basic = {ov::element::u8,
ov::element::u4,
ov::element::i4,
ov::element::nf4};
const std::vector<ov::test::ElementType> weights_precisions_amx = {ov::element::u8};
const std::vector<ShapeParams> input_shapes_basic = {
{{{-1, -1, -1}, {{1, 4, 16}, {10, 16, 16}}}, {16, 32}},
{{{}, {{1, 4, 16}}}, {16, 32}, 2ul},
{{{}, {{1, 8, 16}}}, {16, 32}, 4ul},
{{{}, {{1, 4, 16}}}, {1, 16, 32}},
{{{}, {{10, 40, 496}}}, {1, 496, 240}},
{{{}, {{1, 4, 48}}}, {48, 256}},
{{{}, {{11, 339, 377}}}, {377, 335}},
};
const std::vector<ShapeParams> input_shapes_big = {
{{{}, {{1, 11, 154}}}, {154, 77}, 154ul},
{{{-1, -1, -1}, {{10, 40, 480}, {11, 40, 480}}}, {1, 480, 256}},
};
const std::vector<ShapeParams> input_shapes_amx = {
{{{-1, -1, -1}, {{10, 40, 480}, {11, 40, 480}}}, {1, 480, 256}},
{{{-1, 1, 4096}, {{1, 1, 4096}}}, {4096, 3840}, 128ul},
{{{}, {{1, 4, 32}}}, {32, 256}},
{{{}, {{1, 4, 512}}}, {512, 256}},
{{{}, {{1, 16, 32}}}, {32, 64}},
{{{}, {{2, 4, 32}}}, {32, 65}},
{{{}, {{3, 12, 768}}}, {768, 1024}},
{{{}, {{11, 339, 577}}}, {577, 335}},
{{{}, {{1, 1, 256}}}, {256, 128}, 64ul},
};
const std::vector<fusingSpecificParams> fusingParamsSet {
emptyFusingSpec,
@@ -352,35 +339,36 @@ const std::vector<fusingSpecificParams> fusingParamsSet {
INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_basic,
MatmulWeightsDecompression,
::testing::Combine(::testing::ValuesIn(input_shapes_basic),
::testing::ValuesIn(weights_precisions),
::testing::ValuesIn(weights_precisions_basic),
::testing::ValuesIn(decompression_precisions),
::testing::Values(true),
::testing::Values(true),
::testing::Values(true),
::testing::ValuesIn(filterAdditionalConfigBasic()),
::testing::ValuesIn(fusingParamsSet),
::testing::Values(shouldUseDecompressionKernelBasic())),
::testing::Values(true)),
MatmulWeightsDecompression::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_big,
INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_amx,
MatmulWeightsDecompression,
::testing::Combine(::testing::ValuesIn(input_shapes_big),
::testing::ValuesIn(weights_precisions),
::testing::Combine(::testing::ValuesIn(input_shapes_amx),
::testing::ValuesIn(weights_precisions_amx),
::testing::ValuesIn(decompression_precisions),
::testing::Values(true),
::testing::Values(true),
::testing::Values(true),
::testing::ValuesIn(filterAdditionalConfigBig()),
::testing::ValuesIn(filterAdditionalConfigAMX()),
::testing::ValuesIn(fusingParamsSet),
::testing::Values(shouldUseDecompressionKernelBig())),
::testing::Values(true)),
MatmulWeightsDecompression::getTestCaseName);
const std::vector<ShapeParams> input_shapes_corner_cases_basic = {
{{{-1, -1, -1}, {{1, 4, 16}}}, {1, 16, 32}},
{{{-1, -1, -1}, {{1, 4, 16}}}, {16, 32}},
{{{-1, -1, -1}, {{1, 4, 16}}}, {16, 32}, 4ul},
{{{-1, -1, -1}, {{1, 1, 4096}}}, {4096, 4096}, 128ul},
};
const std::vector<ShapeParams> input_shapes_corner_cases_big = {
const std::vector<ShapeParams> input_shapes_corner_cases_amx = {
{{{-1, -1, -1}, {{10, 40, 480}, {11, 40, 480}}}, {1, 480, 256}},
{{{-1, -1, -1}, {{1, 1, 4096}}}, {4096, 4096}, 128ul},
};
@@ -393,27 +381,27 @@ const std::vector<ov::test::ElementType> decompression_precisions_corner_cases =
INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_corner_cases_basic,
MatmulWeightsDecompression,
::testing::Combine(::testing::ValuesIn(input_shapes_corner_cases_basic),
::testing::ValuesIn(weights_precisions),
::testing::ValuesIn(weights_precisions_basic),
::testing::ValuesIn(decompression_precisions_corner_cases),
::testing::ValuesIn(transpose_weights),
::testing::ValuesIn(add_decompression_sub),
::testing::ValuesIn(reshape_on_decompression),
::testing::ValuesIn(filterAdditionalConfigBasic()),
::testing::Values(emptyFusingSpec),
::testing::Values(shouldUseDecompressionKernelBasic())),
::testing::Values(true)),
MatmulWeightsDecompression::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_corner_cases_big,
INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_corner_cases_amx,
MatmulWeightsDecompression,
::testing::Combine(::testing::ValuesIn(input_shapes_corner_cases_big),
::testing::ValuesIn(weights_precisions),
::testing::Combine(::testing::ValuesIn(input_shapes_corner_cases_amx),
::testing::ValuesIn(weights_precisions_amx),
::testing::ValuesIn(decompression_precisions_corner_cases),
::testing::ValuesIn(transpose_weights),
::testing::ValuesIn(add_decompression_sub),
::testing::ValuesIn(reshape_on_decompression),
::testing::ValuesIn(filterAdditionalConfigBig()),
::testing::ValuesIn(filterAdditionalConfigAMX()),
::testing::Values(emptyFusingSpec),
::testing::Values(shouldUseDecompressionKernelBig())),
::testing::Values(true)),
MatmulWeightsDecompression::getTestCaseName);
} // namespace
} // namespace SubgraphTestsDefinitions

@@ -1 +1 @@
Subproject commit 36c2060a0dc85b4def72ea30823936c2ef861b82
Subproject commit ff9205a8b42238e1fba992fad2429b722c4cfed0


@@ -71,6 +71,8 @@ std::shared_ptr<ov::Node> makeConstant(const ov::element::Type& type,
makeNode(ov::element::Type_t::u64);
makeNode(ov::element::Type_t::boolean);
makeNode(ov::element::Type_t::nf4);
makeNode(ov::element::Type_t::u4);
makeNode(ov::element::Type_t::i4);
#undef makeNode
default:
throw std::runtime_error("Unhandled precision");