[CPU] Group & NF4 decompression transformation support (#20039)

Author: Vladislav Golubev, 2023-10-11 13:25:00 +02:00 (committed by GitHub)
Parent: 0bb6450398
Commit: 5894fbe69d
10 changed files with 271 additions and 151 deletions

View File

@ -101,6 +101,8 @@ INFERENCE_ENGINE_1_0_DEPRECATED inline Precision convertPrecision(const ::ngraph
return Precision(Precision::BIN);
case ::ngraph::element::Type_t::boolean:
return Precision(Precision::BOOL);
case ::ngraph::element::Type_t::nf4:
return Precision(Precision::NF4);
case ::ngraph::element::Type_t::dynamic:
return Precision(Precision::UNSPECIFIED);
default:

View File

@ -41,6 +41,7 @@ public:
FP16 = 11, /**< 16bit floating point value, 5 bit for exponent, 10 bit for mantisa */
BF16 = 12, /**< 16bit floating point value, 8 bit for exponent, 7 bit for mantisa*/
FP64 = 13, /**< 64bit floating point value */
NF4 = 14, /**< 4bit normalized float value */
Q78 = 20, /**< 16bit specific signed fixed point precision */
I16 = 30, /**< 16bit signed integer value */
U4 = 39, /**< 4bit unsigned integer value */
@ -131,6 +132,7 @@ public:
CASE(FP64, double);
CASE2(FP16, int16_t, uint16_t);
CASE2(BF16, int16_t, uint16_t);
CASE(NF4, int8_t);
CASE2(I4, int8_t, uint8_t);
CASE(I8, int8_t);
CASE(I16, int16_t);
@ -249,24 +251,11 @@ public:
static Precision FromStr(const std::string& str) {
static const std::unordered_map<std::string, ePrecision> names = {
#define PRECISION_NAME(s) {#s, s}
PRECISION_NAME(Q78),
PRECISION_NAME(BOOL),
PRECISION_NAME(BF16),
PRECISION_NAME(I4),
PRECISION_NAME(I8),
PRECISION_NAME(I16),
PRECISION_NAME(I32),
PRECISION_NAME(I64),
PRECISION_NAME(U4),
PRECISION_NAME(U8),
PRECISION_NAME(U16),
PRECISION_NAME(U32),
PRECISION_NAME(U64),
PRECISION_NAME(FP32),
PRECISION_NAME(FP64),
PRECISION_NAME(FP16),
PRECISION_NAME(MIXED),
PRECISION_NAME(BIN),
PRECISION_NAME(Q78), PRECISION_NAME(BOOL), PRECISION_NAME(BF16), PRECISION_NAME(I4),
PRECISION_NAME(I8), PRECISION_NAME(I16), PRECISION_NAME(I32), PRECISION_NAME(I64),
PRECISION_NAME(U4), PRECISION_NAME(U8), PRECISION_NAME(U16), PRECISION_NAME(U32),
PRECISION_NAME(U64), PRECISION_NAME(FP32), PRECISION_NAME(FP64), PRECISION_NAME(FP16),
PRECISION_NAME(MIXED), PRECISION_NAME(NF4), PRECISION_NAME(BIN),
#undef PRECISION_NAME
};
auto i = names.find(str);
@ -311,7 +300,8 @@ public:
(precisionInfo.value == Precision::I16) || (precisionInfo.value == Precision::I8) ||
(precisionInfo.value == Precision::I32) || (precisionInfo.value == Precision::I64) ||
(precisionInfo.value == Precision::BIN) || (precisionInfo.value == Precision::BF16) ||
(precisionInfo.value == Precision::CUSTOM) || (precisionInfo.value == Precision::I4);
(precisionInfo.value == Precision::CUSTOM) || (precisionInfo.value == Precision::I4) ||
(precisionInfo.value == Precision::NF4);
}
protected:
@ -359,6 +349,7 @@ protected:
CASE(FP64);
CASE(FP16);
CASE(BF16);
CASE(NF4);
CASE(I4);
CASE(I8);
CASE(I16);
@ -475,6 +466,12 @@ struct INFERENCE_ENGINE_1_0_DEPRECATED PrecisionTrait<Precision::BIN> {
enum { is_float = false };
};
template <>
struct INFERENCE_ENGINE_1_0_DEPRECATED PrecisionTrait<Precision::NF4> {
using value_type = int8_t;
enum { is_float = false };
};
template <class T>
INFERENCE_ENGINE_1_0_DEPRECATED inline uint8_t type_size_or_zero() {
return sizeof(T);
@ -499,7 +496,7 @@ INFERENCE_ENGINE_1_0_DEPRECATED inline Precision::PrecisionInfo Precision::makeP
Precision::PrecisionInfo info;
info.name = name;
size_t nBits = precision == BIN ? 1 : (precision == U4 || precision == I4) ? 4 : 8;
size_t nBits = precision == BIN ? 1 : (precision == U4 || precision == I4 || precision == NF4) ? 4 : 8;
info.bitsSize = nBits * type_size_or_zero<typename PrecisionTrait<precision>::value_type>();
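// e.g. for NF4: nBits == 4 and value_type is int8_t, so bitsSize == 4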
info.isFloat = PrecisionTrait<precision>::is_float;
info.value = precision;

View File

@ -251,41 +251,48 @@ void DnnlPostOpsComposer::appendClip(const std::vector<float>& low, const std::v
}
}
MemoryPtr DnnlPostOpsComposer::prepackDecompressionParams(const std::vector<float>& params, size_t icBlock) {
MemoryPtr DnnlPostOpsComposer::prepackDecompressionParams(const MemoryCPtr& params_ptr, size_t icBlock) {
// Prepacking params from the [oc] layout to [oc, icBlock], where each parameter is duplicated icBlock times
DnnlBlockedMemoryDesc memoryDesc(InferenceEngine::Precision::FP32, Shape({icBlock * params.size()}));
const auto shape = params_ptr->getShape().getStaticDims();
const size_t elements_count = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<size_t>());
DnnlBlockedMemoryDesc memoryDesc(InferenceEngine::Precision::FP32, Shape({icBlock * elements_count}));
auto mem = std::make_shared<Memory>(engine, memoryDesc);
size_t dstIdx = 0;
auto decomp_scales_data = static_cast<float*>(params_ptr->getData());
auto decomp_scales_buf = static_cast<float*>(mem->getData());
for (size_t oc = 0; oc < params.size(); oc++) {
for (size_t oc = 0; oc < elements_count; oc++) {
for (size_t intIdx = 0; intIdx < icBlock; intIdx++) {
decomp_scales_buf[dstIdx] = params[oc];
decomp_scales_buf[dstIdx] = decomp_scales_data[oc];
dstIdx++;
}
}
return mem;
}
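For illustration only (not part of the commit): a minimal standalone sketch of the prepacking performed above, assuming the per-output-channel parameters sit in a plain std::vector (the real code reads them from a MemoryCPtr); the hypothetical prepack_per_oc_params repeats each value icBlock times so oneDNN reads one value per [oc, icBlock] block.

#include <cstddef>
#include <vector>

// Hypothetical helper mirroring the loop above.
std::vector<float> prepack_per_oc_params(const std::vector<float>& per_oc, std::size_t icBlock) {
    std::vector<float> packed;
    packed.reserve(per_oc.size() * icBlock);
    for (float v : per_oc)
        packed.insert(packed.end(), icBlock, v);  // {s0, s1} with icBlock == 2 -> {s0, s0, s1, s1}
    return packed;
}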
void DnnlPostOpsComposer::appendDecompressionScales(const std::vector<float>& scales, size_t icBlock) {
if (scales.empty())
void DnnlPostOpsComposer::appendDecompressionScales(const MemoryCPtr& scales_ptr, size_t icBlock) {
if (scales_ptr == nullptr)
return;
int mask = scales.size() > 1 ? weightScaleMaskPerChannel : 0;
const auto shape = scales_ptr->getShape().getStaticDims();
const auto elements_count = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<size_t>());
int mask = elements_count > 1 ? weightScaleMaskPerChannel : 0;
DEBUG_LOG("Set weights scales mask ", "DNNL_ARG: ", DNNL_ARG_WEIGHTS, " mask: ", mask);
attr.set_scales_mask(DNNL_ARG_WEIGHTS, mask);
args[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] = prepackDecompressionParams(scales, icBlock);
args[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] = prepackDecompressionParams(scales_ptr, icBlock);
}
void DnnlPostOpsComposer::appendDecompressionZeroPoints(const std::vector<float>& zero_points, size_t icBlock) {
if (zero_points.empty())
void DnnlPostOpsComposer::appendDecompressionZeroPoints(const MemoryCPtr& zero_points_ptr, size_t icBlock) {
if (zero_points_ptr == nullptr)
return;
int mask = zero_points.size() > 1 ? weightScaleMaskPerChannel : 0;
const auto shape = zero_points_ptr->getShape().getStaticDims();
const auto elements_count = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<size_t>());
int mask = elements_count > 1 ? weightScaleMaskPerChannel : 0;
DEBUG_LOG("Set weights zero points mask ", "DNNL_ARG: ", DNNL_ARG_WEIGHTS, " mask: ", mask);
attr.set_zero_points_mask(DNNL_ARG_WEIGHTS, mask);
args[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS] = prepackDecompressionParams(zero_points, icBlock);
args[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS] = prepackDecompressionParams(zero_points_ptr, icBlock);
}
} // namespace intel_cpu

View File

@ -42,8 +42,8 @@ public:
bool appendLinear(const std::vector<float>& scale, const std::vector<float>& shift, bool isLastPostOp, bool allowBinary = true);
void appendClip(const std::vector<float>& low, const std::vector<float>& high);
void appendDecompressionScales(const std::vector<float>& scales, size_t icBlock);
void appendDecompressionZeroPoints(const std::vector<float>& zero_points, size_t icBlock);
void appendDecompressionScales(const MemoryCPtr& scales_ptr, size_t icBlock);
void appendDecompressionZeroPoints(const MemoryCPtr& zero_points_ptr, size_t icBlock);
const VectorDims& getOutputDims() {
return outputDims;
@ -69,7 +69,7 @@ private:
void updateWeiScales();
void updateDestScales();
MemoryPtr prepackDecompressionParams(const std::vector<float>& params, size_t icBlock);
MemoryPtr prepackDecompressionParams(const MemoryCPtr& params_ptr, size_t icBlock);
};
} // namespace intel_cpu

View File

@ -286,7 +286,7 @@ void GraphOptimizer::FuseConvMatmulFCDeconvAndDQScales(Graph &graph) {
}
void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) {
const std::set<InferenceEngine::Precision> supportedWeightsPrecisions{InferenceEngine::Precision::U8};
std::set<InferenceEngine::Precision> supportedWeightsPrecisions{InferenceEngine::Precision::U8, InferenceEngine::Precision::NF4};
const std::set<InferenceEngine::Precision> supportedDataPrecisions{InferenceEngine::Precision::FP32, InferenceEngine::Precision::BF16};
auto expectedNode = [](NodePtr node, Type expectedType) {
return node->getType() == expectedType && node->getChildEdges().size() == 1;
@ -301,11 +301,19 @@ void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) {
if (fcNode == nullptr)
continue;
const auto parent = fcNode->getParentEdgesAtPort(1)[0]->getParent();
auto parent = fcNode->getParentEdgesAtPort(1)[0]->getParent();
const bool withTranspose = parent->getType() == Type::Transpose;
const NodePtr transposeNode = withTranspose ? parent : nullptr;
if (transposeNode)
parent = transposeNode->getParentEdgesAtPort(0)[0]->getParent();
const auto multiplyNode = withTranspose ? parent->getParentEdgesAtPort(0)[0]->getParent() : parent;
const bool withReshape = parent->getType() == Type::Reshape;
const auto reshapeNode = withReshape ? parent : nullptr;
if (reshapeNode) {
parent = reshapeNode->getParentEdgesAtPort(0)[0]->getParent();
}
const auto multiplyNode = parent;
if (!expectedNode(multiplyNode, Type::Eltwise) || multiplyNode->getAlgorithm() != Algorithm::EltwiseMultiply ||
!multiplyNode->isConstant())
continue;
@ -346,23 +354,41 @@ void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) {
// Shape limitations
const auto weightsShape = weightsNode->getOutputShapeAtPort(0);
const auto fcInputWeightsShape = multiplyNode->getOutputShapeAtPort(0);
if (weightsShape != fcInputWeightsShape)
if (weightsShape != multiplyNode->getOutputShapeAtPort(0))
continue;
if (reshapeNode && (reshapeNode->getInputShapeAtPort(0).getRank() != 3 || reshapeNode->getOutputShapeAtPort(0).getRank() != 2))
continue;
const auto expectedDims = withTranspose ? VectorDims{1, weightsShape.getDims()[1]}
: VectorDims{weightsShape.getDims()[0], 1};
if (multiplyConstNode->getOutputShapeAtPort(0).getDims() != expectedDims)
VectorDims decompressionConstShape;
const auto fcInputWeightsShape = fcNode->getInputShapeAtPort(1);
// Ordinary case: one decompression group
if (fcInputWeightsShape.getRank() == weightsShape.getRank()) {
const auto& out_channels = fcInputWeightsShape.getDims()[0];
decompressionConstShape = withTranspose ? VectorDims{1, out_channels} : VectorDims{out_channels, 1};
} else {
// Group decompression case: the last 3 dimensions of the weights shape (there may also be leading 1's) must be:
// [N, G, O], if transpose = true
// [O, N, G], otherwise.
// O - output channels
// N - number of groups
// G - group size
const auto& weights_dims = weightsShape.getStaticDims();
const auto& N = withTranspose ? *(weights_dims.rbegin() + 2) : *(weights_dims.rbegin() + 1);
const auto& O = withTranspose ? *weights_dims.rbegin() : *(weights_dims.rbegin() + 2);
// Group decompression is applied by O and N dims
decompressionConstShape = withTranspose ? VectorDims{N, 1, O} : VectorDims{O, N, 1};
}
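// Illustrative example with hypothetical sizes: weights {8, 64, 256} with transpose -> decompressionConstShape = {8, 1, 256};
// weights {256, 8, 64} without transpose -> decompressionConstShape = {256, 8, 1}.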
if (multiplyConstNode->getOutputShapeAtPort(0).getDims() != decompressionConstShape)
continue;
if (withSubtract && subtractConstNode->getOutputShapeAtPort(0).getDims() != expectedDims)
if (withSubtract && subtractConstNode->getOutputShapeAtPort(0).getDims() != decompressionConstShape)
continue;
// HW specific shape limitations
if (impl::cpu::x64::mayiuse(impl::cpu::x64::avx512_core_amx)) {
// The oneDNN AMX inner-product implementation has limited shape support due to performance considerations. As a current solution, the conditions below are copied
// from oneDNN to make sure the correct IP implementation is used, since the fallback one doesn't support the weights decompression feature.
size_t OC = withTranspose ? weightsShape.getDims()[1] : weightsShape.getDims()[0];
size_t IC = withTranspose ? weightsShape.getDims()[0] : weightsShape.getDims()[1];
size_t OC = fcInputWeightsShape.getDims()[0];
size_t IC = fcInputWeightsShape.getDims()[1];
size_t simdWidth = 16;
size_t vnniFactor = 2;
size_t maxSize = 512;
@ -398,6 +424,10 @@ void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) {
transposeNode->setOriginalInputPrecisionAtPort(0, weightsPrecision);
transposeNode->setOriginalOutputPrecisionAtPort(0, weightsPrecision);
}
if (withReshape) {
reshapeNode->setOriginalInputPrecisionAtPort(0, weightsPrecision);
reshapeNode->setOriginalOutputPrecisionAtPort(0, weightsPrecision);
}
fcNode->setOriginalInputPrecisionAtPort(1, weightsPrecision);
}
}

View File

@ -729,10 +729,10 @@ void FullyConnected::setPostOps(dnnl::primitive_attr& attr, const VectorDims& di
// and prepack runtime attributes accordingly for better performance
bool withAMX = selected_pd->getImplementationType() & impl_desc_type::amx;
int icBlock = withAMX ? 2 : 1;
if (!decompressionMultiply.empty())
dnnlpoc.appendDecompressionScales(decompressionMultiply, icBlock);
if (!decompressionSubtract.empty())
dnnlpoc.appendDecompressionZeroPoints(decompressionSubtract, icBlock);
if (decompressionMultiplyPtr)
dnnlpoc.appendDecompressionScales(decompressionMultiplyPtr, icBlock);
if (decompressionSubtractPtr)
dnnlpoc.appendDecompressionZeroPoints(decompressionSubtractPtr, icBlock);
for (size_t i = 0; i < fusedWith.size(); ++i) {
auto& node = fusedWith[i];
@ -1133,26 +1133,32 @@ bool FullyConnected::useSparseWeightsDecompression() {
}
void FullyConnected::fuseDecompressionMultiply(const NodePtr& constData) {
fuseDecompressionConstant(constData, decompressionMultiply);
fuseDecompressionConstant(constData, decompressionMultiplyPtr);
}
void FullyConnected::fuseDecompressionSubtract(const NodePtr& constData) {
fuseDecompressionConstant(constData, decompressionSubtract);
fuseDecompressionConstant(constData, decompressionSubtractPtr);
}
void FullyConnected::fuseDecompressionConstant(const NodePtr& constData, std::vector<float>& decompressionValues) {
void FullyConnected::fuseDecompressionConstant(const NodePtr& constData, MemoryCPtr& decompressionValuesPtr) {
auto *constInputNode = dynamic_cast<node::Input *>(constData.get());
if (!constInputNode) {
IE_THROW() << "Cannot cast " << constData->getName() << " to Input";
}
auto constBlob = constInputNode->getMemoryPtr();
const auto elementsCount = constBlob->getDescWithType<BlockedMemoryDesc>()->getPaddedElementsCount();
decompressionValues.resize(elementsCount);
cpu_convert(constBlob->getData(),
&decompressionValues[0],
DnnlExtensionUtils::DataTypeToIEPrecision(constBlob->getDataType()),
Precision::FP32,
elementsCount);
const auto decompression_prc = InferenceEngine::Precision::FP32;
if (constInputNode->getOriginalOutputPrecisionAtPort(0) == decompression_prc) {
decompressionValuesPtr = constInputNode->getMemoryPtr();
} else {
const auto constBlob = constInputNode->getMemoryPtr();
DnnlBlockedMemoryDesc memoryDesc(decompression_prc, constBlob->getShape());
decompressionValuesPtr = std::make_shared<Memory>(getEngine(), memoryDesc, nullptr, false);
const auto elementsCount = constBlob->getDescWithType<BlockedMemoryDesc>()->getPaddedElementsCount();
cpu_convert(constBlob->getData(),
decompressionValuesPtr->getData(),
DnnlExtensionUtils::DataTypeToIEPrecision(constBlob->getDataType()),
Precision::FP32,
elementsCount);
}
}
DnnlMemoryDescPtr FullyConnected::makeTransposedWeightDescriptor(DnnlMemoryDescPtr desc) {

View File

@ -61,10 +61,7 @@ public:
}
void fuseDecompressionMultiply(const NodePtr& constData);
const std::vector<float>& getDecompressionMultiply() const { return decompressionMultiply; }
void fuseDecompressionSubtract(const NodePtr& constData);
const std::vector<float>& getDecompressionSubtract() const { return decompressionSubtract; }
private:
void createDescriptorInternal(const dnnl::memory::desc &inputDesc,
@ -102,7 +99,7 @@ private:
const dnnl::engine& engine);
bool canBeExecutedInConv1x1() const;
void fuseDecompressionConstant(const NodePtr& constData, std::vector<float>& decompressionValues);
void fuseDecompressionConstant(const NodePtr& constData, MemoryCPtr& decompressionValuesPtr);
// sparse weights
bool useSparseWeights = false;
@ -121,8 +118,8 @@ private:
void prepareWeightsUsingDummyShape();
#endif
bool useWeightsDecompressionImpl = false;
std::vector<float> decompressionSubtract;
std::vector<float> decompressionMultiply;
MemoryCPtr decompressionSubtractPtr = nullptr;
MemoryCPtr decompressionMultiplyPtr = nullptr;
// FC with transposed weights
bool weightsNonTransposed = false;

View File

@ -207,9 +207,16 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
if (useLpt) {
CPU_REGISTER_PASS_COMMON(manager, ov::pass::MarkDequantizationSubgraph, defaultPrecisions);
} else {
// We need to fuse Transpose to MatMul to have a simpler callback for the next transformation
CPU_REGISTER_PASS_COMMON(manager, ov::pass::TransposeMatMul);
const ov::element::TypeVector decompression_precisions{
ov::element::u8,
// TODO: Uncomment when group decompression is supported
// ov::element::nf4
};
// MarkDequantizationSubgraph is used even in non-LPT pipeline on X64 platforms
// in order to keep compressed u8 MatMul weights with decompression operations as is
CPU_REGISTER_PASS_X64(manager, ov::pass::MarkDequantizationSubgraph, ov::element::TypeVector{ov::element::u8}, true);
// in order to keep compressed MatMul weights with decompression operations as is
CPU_REGISTER_PASS_X64(manager, ov::pass::MarkDequantizationSubgraph, decompression_precisions, true);
CPU_SET_CALLBACK_X64(manager, [](const_node_ptr &node) -> bool {
auto get_single_consumer = [](const_node_ptr &node) -> std::shared_ptr<ov::Node> {
const auto consumers = node->get_output_target_inputs(0);
@ -224,12 +231,14 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
if (ov::is_type<ov::opset1::MatMul>(consumer)) {
return false;
} else if (ov::is_type<ov::opset1::Transpose>(consumer)) {
consumer = get_single_consumer(consumer);
if (consumer != nullptr && ov::is_type<ov::opset1::MatMul>(consumer)) {
return false;
}
}
// TODO: Uncomment when group decompression is supported
// else if (ov::is_type<ov::opset1::Reshape>(consumer)) {
// consumer = get_single_consumer(consumer);
// if (consumer != nullptr && ov::is_type<ov::opset1::MatMul>(consumer)) {
// return false;
// }
// }
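// In short: the decompression subgraph is kept (the callback returns false) only when the weights
// Convert ultimately feeds a MatMul, directly or through a single Transpose; in every other case the
// marking is skipped so ordinary constant folding can collapse the subgraph.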
return true;
}, ov::pass::MarkDequantizationSubgraph);
}

View File

@ -14,9 +14,9 @@ using namespace ov::test;
namespace SubgraphTestsDefinitions {
/*
* Subtract_const(U8)
* Subtract_const(U8/NF4)
* /
* Weights(U8) Convert(F32)
* Weights(U8/NF4) Convert(F32)
* | /
* Convert(F32) Reshape
* \ / Multiply_const(F32)
@ -31,21 +31,34 @@ namespace SubgraphTestsDefinitions {
* |
* Bias
*/
using MatmulWeightsDecompressionParams = std::tuple<std::vector<InputShape>, // input shapes
ov::test::ElementType, // weights precision
bool, // transpose on weights
bool, // decompression subtract
bool, // reshape on decompression constants
struct ShapeParams {
ShapeParams() = default;
ShapeParams(InputShape data_shape, ov::Shape weights_shape, int weights_group_size = -1)
: data_shape(std::move(data_shape)),
weights_shape(std::move(weights_shape)),
weights_group_size(weights_group_size) {}
InputShape data_shape;
ov::Shape weights_shape;
// Decompression group size. If the value is equal to -1, ordinary decompression is used
int weights_group_size;
};
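// E.g. {{{}, {{1, 4, 16}}}, {16, 32}, 2ul} (used in input_shapes_basic below) requests {16, 32} weights
// decompressed with a group size of 2 along the input channels.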
using MatmulWeightsDecompressionParams = std::tuple<ShapeParams,
ov::test::ElementType, // weights precision
bool, // transpose on weights
bool, // decompression subtract
bool, // reshape on decompression constants
std::map<std::string, std::string>, // additional config
fusingSpecificParams,
bool>; // should use decompression implementation
bool>; // should use decompression implementation
class MatmulWeightsDecompression : public testing::WithParamInterface<MatmulWeightsDecompressionParams>,
virtual public SubgraphBaseTest,
public CpuTestWithFusing {
public:
static std::string getTestCaseName(testing::TestParamInfo<MatmulWeightsDecompressionParams> obj) {
std::vector<InputShape> inputShapes;
ShapeParams shape_params;
ov::test::ElementType weights_precision;
bool transpose;
bool decompression_sub;
@ -54,7 +67,7 @@ public:
fusingSpecificParams fusing_params;
bool should_fuse;
std::tie(inputShapes,
std::tie(shape_params,
weights_precision,
transpose,
decompression_sub,
@ -64,20 +77,9 @@ public:
should_fuse) = obj.param;
std::ostringstream result;
for (const auto& shape : inputShapes) {
result << ov::test::utils::partialShape2str({shape.first}) << "_";
}
result << "TS=";
for (const auto& shape : inputShapes) {
result << "(";
if (!shape.second.empty()) {
auto itr = shape.second.begin();
do {
result << ov::test::utils::vec2str(*itr);
} while (++itr != shape.second.end() && result << "_");
}
result << ")_";
}
result << "data_shape=" << shape_params.data_shape << "_";
result << "weights_shape=" << shape_params.weights_shape << "_";
result << "group_size=" << shape_params.weights_group_size << "_";
result << "weights_precision=" << weights_precision << "_";
result << "transpose_weights=" << transpose << "_";
result << "decompression_subtract=" << decompression_sub << "_";
@ -94,34 +96,64 @@ public:
}
protected:
std::shared_ptr<ov::Model> initSubgraph(std::vector<ov::PartialShape>& inputShapes,
const ov::element::Type data_precision,
const ov::element::Type weights_precision,
const bool transpose_weights,
const bool add_subtract,
const bool reshape_on_decompression) {
ov::ParameterVector params{std::make_shared<ov::op::v0::Parameter>(data_precision, inputShapes[0])};
std::shared_ptr<ov::Node> initDecompressionWeights(const ov::Shape& weights_shape,
const int group_size,
const ov::element::Type data_precision,
const ov::element::Type weights_precision,
const bool transpose_weights,
const bool add_subtract,
const bool reshape_on_decompression_constant) {
auto transpose_if_necessary = [&](const ov::Shape& shape) {
if (!transpose_weights)
return shape;
auto transposed_shape = shape;
std::swap(*transposed_shape.rbegin(), *(transposed_shape.rbegin() + 1));
return transposed_shape;
auto result_shape = shape;
if (transpose_weights)
std::swap(*result_shape.rbegin(), *(result_shape.rbegin() + 1));
return result_shape;
};
auto weights_shape = transpose_if_necessary(inputShapes[1].to_shape());
auto weights = ngraph::builder::makeConstant<uint8_t>(weights_precision, weights_shape, {}, true);
const bool group_decompression = group_size != -1;
// Weights have shape [I, O], where
// I - input channels
// O - output channels
// In case of group decompression, the input channels dimension is split in two: I -> [N, G], where
// N - number of groups
// G - group size
auto transformed_weights_shape = transpose_if_necessary(weights_shape);
if (group_decompression) {
OPENVINO_ASSERT(weights_shape[0] % group_size == 0,
"Weights output channels count (",
weights_shape[0],
") must be divisible by decompression group size (",
group_size,
").");
auto in_channel_idx = transpose_weights ? transformed_weights_shape.size() - 1 : transformed_weights_shape.size() - 2;
transformed_weights_shape[in_channel_idx] = weights_shape[0] / group_size;
transformed_weights_shape.insert(transformed_weights_shape.begin() + in_channel_idx + 1, group_size);
}
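// E.g. with hypothetical sizes weights_shape = {64, 256} and group_size = 16, transformed_weights_shape
// becomes {4, 16, 256} (transpose_weights == false) or {256, 4, 16} (transpose_weights == true).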
auto weights = ngraph::builder::makeConstant<uint8_t>(weights_precision, transformed_weights_shape, {}, true);
weights->set_friendly_name("Compressed_weights");
auto weights_convert = std::make_shared<ngraph::opset1::Convert>(weights, data_precision);
std::shared_ptr<ov::Node> mul_parent = weights_convert;
auto output_channels = transpose_weights ? *(weights_shape.rbegin() + 1) : *weights_shape.rbegin();
auto scaleshift_target_shape = transpose_if_necessary(ov::Shape{1, output_channels});
auto scaleshift_const_shape = reshape_on_decompression ? ov::Shape{output_channels} : scaleshift_target_shape;
auto output_channels = *weights_shape.rbegin();
// Decompression constants shape:
// Ordinary decompression: [O, 1]
// Group decompression: [O, N, 1]
ov::Shape scaleshift_target_shape{output_channels};
scaleshift_target_shape.insert(scaleshift_target_shape.begin(), group_decompression ? weights_shape[0] / group_size : 1);
scaleshift_target_shape = transpose_if_necessary(scaleshift_target_shape);
if (group_decompression) {
auto in_channel_idx = transpose_weights ? scaleshift_target_shape.size() - 1 : scaleshift_target_shape.size() - 2;
scaleshift_target_shape.insert(scaleshift_target_shape.begin() + in_channel_idx + 1, 1);
}
auto scaleshift_const_shape = scaleshift_target_shape;
if (reshape_on_decompression_constant)
scaleshift_const_shape.erase(std::remove(scaleshift_const_shape.begin(), scaleshift_const_shape.end(), 1), scaleshift_const_shape.end());
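// E.g. (same hypothetical sizes, transpose_weights == true): scaleshift_target_shape = {256, 4, 1};
// with reshape_on_decompression_constant the constant is created as {256, 4} and reshaped back to {256, 4, 1}.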
if (add_subtract) {
auto shift_const = ngraph::builder::makeConstant<uint8_t>(weights_precision, scaleshift_const_shape, {}, true);
std::shared_ptr<ov::Node> shift_convert = std::make_shared<ngraph::opset1::Convert>(shift_const, data_precision);
if (reshape_on_decompression) {
if (reshape_on_decompression_constant) {
auto shift_reshape_const = ov::opset10::Constant::create(ov::element::i32, {scaleshift_target_shape.size()}, scaleshift_target_shape);
auto shift_reshape = std::make_shared<ov::opset10::Reshape>(shift_convert, shift_reshape_const, false);
shift_convert = shift_reshape;
@ -130,31 +162,54 @@ protected:
}
std::shared_ptr<ov::Node> scale_const = ngraph::builder::makeConstant<float>(data_precision, scaleshift_const_shape, {}, true);
if (reshape_on_decompression) {
if (reshape_on_decompression_constant) {
auto scale_reshape_const = ov::opset10::Constant::create(ov::element::i32, {scaleshift_target_shape.size()}, scaleshift_target_shape);
auto scale_reshape = std::make_shared<ov::opset10::Reshape>(scale_const, scale_reshape_const, false);
scale_const = scale_reshape;
}
auto multiply = std::make_shared<ov::opset10::Multiply>(mul_parent, scale_const);
std::shared_ptr<ov::Node> last_node = std::make_shared<ov::opset10::Multiply>(mul_parent, scale_const);
std::shared_ptr<ov::Node> matmul_weights = multiply;
if (group_decompression) {
auto reshape_target_shape = transpose_weights ? std::vector<int>{-1, static_cast<int>(weights_shape[0])}
: std::vector<int>{static_cast<int>(weights_shape[0]), -1};
auto target_shape_node = ov::opset10::Constant::create(ov::element::i32, {reshape_target_shape.size()}, reshape_target_shape);
last_node = std::make_shared<ov::opset10::Reshape>(last_node, target_shape_node, false);
}
if (transpose_weights) {
const size_t rank = matmul_weights->get_output_partial_shape(0).size();
const size_t rank = last_node->get_output_partial_shape(0).size();
std::vector<int> order(rank);
std::iota(order.begin(), order.end(), 0);
std::swap(*order.rbegin(), *(order.rbegin() + 1));
auto transpose_constant = ov::opset10::Constant::create(ov::element::i32, {rank}, order);
auto transpose = std::make_shared<ov::opset10::Transpose>(matmul_weights, transpose_constant);
matmul_weights = transpose;
last_node = std::make_shared<ov::opset10::Transpose>(last_node, transpose_constant);
}
auto matMul = builder::makeMatMul(params[0], matmul_weights);
return last_node;
}
std::shared_ptr<ov::Model> initSubgraph(const ov::PartialShape& data_shape,
const ov::Shape& weights_shape,
const int group_size,
const ov::element::Type data_precision,
const ov::element::Type weights_precision,
const bool transpose_weights,
const bool add_subtract,
const bool reshape_on_decompression) {
ov::ParameterVector params{std::make_shared<ov::op::v0::Parameter>(data_precision, data_shape)};
const auto weights_subgraph = initDecompressionWeights(weights_shape,
group_size,
data_precision,
weights_precision,
transpose_weights,
add_subtract,
reshape_on_decompression);
auto matMul = builder::makeMatMul(params[0], weights_subgraph);
return makeNgraphFunction(data_precision, params, matMul, "MatmulWeightsDecompression");
}
void SetUp() override {
targetDevice = ov::test::utils::DEVICE_CPU;
std::vector<InputShape> inputShapes;
ShapeParams shape_params;
ov::test::ElementType weights_precision;
bool transpose_weights;
bool decompression_sub;
@ -163,7 +218,7 @@ protected:
fusingSpecificParams fusing_params;
bool should_fuse;
std::tie(inputShapes,
std::tie(shape_params,
weights_precision,
transpose_weights,
decompression_sub,
@ -174,25 +229,38 @@ protected:
configuration.insert(additional_config.begin(), additional_config.end());
std::tie(postOpMgrPtr, fusedOps) = fusing_params;
init_input_shapes(inputShapes);
init_input_shapes({shape_params.data_shape, {{}, {{shape_params.weights_shape}}}});
ElementType netType = element::f32;
inType = outType = netType;
function = initSubgraph(inputDynamicShapes, netType, weights_precision, transpose_weights, decompression_sub, reshape_on_decompression);
function = initSubgraph(inputDynamicShapes[0],
shape_params.weights_shape,
shape_params.weights_group_size,
netType,
weights_precision,
transpose_weights,
decompression_sub,
reshape_on_decompression);
}
void checkResults() {
const auto& test_param = GetParam();
ov::test::ElementType weights_precision = std::get<1>(test_param);
bool should_fuse = std::get<7>(test_param);
const auto& weights_precision = std::get<1>(test_param);
// TODO: remove this condition when group decompression is supported
if (weights_precision == ov::element::nf4 || std::get<0>(test_param).weights_group_size != -1) {
return;
}
bool weights_found = false;
for (const auto& n : compiledModel.get_runtime_model()->get_ordered_ops()) {
if (n->get_friendly_name() == "Compressed_weights") {
ASSERT_EQ(n->get_output_element_type(0), weights_precision);
weights_found = true;
}
}
ASSERT_TRUE(weights_found);
std::map<std::string, std::string> additional_config = std::get<5>(test_param);
const bool should_fuse = std::get<7>(test_param);
const size_t expected_count = should_fuse ? 0 : 1;
CheckNumberOfNodesWithType(compiledModel, "Convert", expected_count);
CheckNumberOfNodesWithType(compiledModel, "Eltwise", expected_count);
@ -235,22 +303,24 @@ bool shouldUseDecompressionKernelBasic() {
return shouldUseDecompressionKernelBig();
}
const std::vector<ov::test::ElementType> weights_precisions = {ov::element::u8};
const std::vector<std::vector<InputShape>> input_shapes_basic = {
{{{-1, -1, -1}, {{1, 4, 16}, {10, 16, 16}}}, {{}, {{16, 32}}}},
{{{}, {{1, 4, 16}}}, {{}, {{1, 16, 32}}}},
{{{}, {{10, 40, 496}}}, {{}, {{1, 496, 240}}}},
{{{}, {{1, 4, 48}}}, {{}, {{48, 256}}}},
{{{}, {{11, 339, 377}}}, {{}, {{377, 335}}}},
const std::vector<ov::test::ElementType> weights_precisions = {ov::element::u8, ov::element::nf4};
const std::vector<ShapeParams> input_shapes_basic = {
{{{-1, -1, -1}, {{1, 4, 16}, {10, 16, 16}}}, {16, 32}},
{{{}, {{1, 4, 16}}}, {16, 32}, 2ul},
{{{}, {{1, 4, 16}}}, {1, 16, 32}},
{{{}, {{10, 40, 496}}}, {1, 496, 240}},
{{{}, {{1, 4, 48}}}, {48, 256}},
{{{}, {{11, 339, 377}}}, {377, 335}},
};
const std::vector<std::vector<InputShape>> input_shapes_big = {
{{{-1, -1, -1}, {{10, 40, 480}, {11, 40, 480}}}, {{}, {{1, 480, 256}}}},
{{{}, {{1, 4, 32}}}, {{}, {{32, 256}}}},
{{{}, {{1, 4, 512}}}, {{}, {{512, 256}}}},
{{{}, {{1, 16, 32}}}, {{}, {{32, 64}}}},
{{{}, {{2, 4, 32}}}, {{}, {{32, 65}}}},
{{{}, {{3, 12, 768}}}, {{}, {{768, 1024}}}},
{{{}, {{11, 339, 577}}}, {{}, {{577, 335}}}},
const std::vector<ShapeParams> input_shapes_big = {
{{{-1, -1, -1}, {{10, 40, 480}, {11, 40, 480}}}, {1, 480, 256}},
{{{-1, 1, 4096}, {{1, 1, 4096}}}, {4096, 3840}, 128ul},
{{{}, {{1, 4, 32}}}, {32, 256}},
{{{}, {{1, 4, 512}}}, {512, 256}},
{{{}, {{1, 16, 32}}}, {32, 64}},
{{{}, {{2, 4, 32}}}, {32, 65}},
{{{}, {{3, 12, 768}}}, {768, 1024}},
{{{}, {{11, 339, 577}}}, {577, 335}},
};
const std::vector<fusingSpecificParams> fusingParamsSet {
emptyFusingSpec,
@ -281,12 +351,14 @@ INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_big,
::testing::Values(shouldUseDecompressionKernelBig())),
MatmulWeightsDecompression::getTestCaseName);
const std::vector<std::vector<InputShape>> input_shapes_corner_cases_basic = {
{{{-1, -1, -1}, {{1, 4, 16}}}, {{}, {{1, 16, 32}}}},
{{{-1, -1, -1}, {{1, 4, 16}}}, {{}, {{16, 32}}}},
const std::vector<ShapeParams> input_shapes_corner_cases_basic = {
{{{-1, -1, -1}, {{1, 4, 16}}}, {1, 16, 32}},
{{{-1, -1, -1}, {{1, 4, 16}}}, {16, 32}},
{{{-1, -1, -1}, {{1, 4, 16}}}, {16, 32}, 4ul},
};
const std::vector<std::vector<InputShape>> input_shapes_corner_cases_big = {
{{{-1, -1, -1}, {{10, 40, 480}, {11, 40, 480}}}, {{}, {{1, 480, 256}}}},
const std::vector<ShapeParams> input_shapes_corner_cases_big = {
{{{-1, -1, -1}, {{10, 40, 480}, {11, 40, 480}}}, {1, 480, 256}},
{{{-1, -1, -1}, {{1, 1, 4096}}}, {4096, 4096}, 128ul},
};
const std::vector<bool> transpose_weights = {true, false};
@ -317,5 +389,4 @@ INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_corner_cases_big,
::testing::Values(shouldUseDecompressionKernelBig())),
MatmulWeightsDecompression::getTestCaseName);
} // namespace
} // namespace SubgraphTestsDefinitions

View File

@ -70,6 +70,7 @@ std::shared_ptr<ov::Node> makeConstant(const ov::element::Type& type,
makeNode(ov::element::Type_t::u32);
makeNode(ov::element::Type_t::u64);
makeNode(ov::element::Type_t::boolean);
makeNode(ov::element::Type_t::nf4);
#undef makeNode
default:
throw std::runtime_error("Unhandled precision");