[CPU] FullyConnected: sparse weights fix (#20117)

This commit is contained in:
Anton Voronov 2023-10-10 10:24:38 +04:00 committed by GitHub
parent 67a62186ee
commit c417d15432
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -52,6 +52,7 @@ struct FCKey {
dnnl::primitive_attr attr; dnnl::primitive_attr attr;
impl_desc_type implType; impl_desc_type implType;
bool useConv1x1; bool useConv1x1;
bool useSparseWeights;
size_t hash() const; size_t hash() const;
bool operator==(const FCKey& rhs) const; bool operator==(const FCKey& rhs) const;
@ -72,6 +73,7 @@ size_t FCKey::hash() const {
seed = hash_combine(seed, get_attr_hash(*attr.get())); seed = hash_combine(seed, get_attr_hash(*attr.get()));
seed = hash_combine(seed, implType); seed = hash_combine(seed, implType);
seed = hash_combine(seed, useConv1x1); seed = hash_combine(seed, useConv1x1);
seed = hash_combine(seed, useSparseWeights);
return seed; return seed;
} }
@ -90,7 +92,7 @@ bool FCKey::operator==(const FCKey &rhs) const {
retVal = retVal && out && rhs.out && out->getDnnlDesc() == rhs.out->getDnnlDesc(); retVal = retVal && out && rhs.out && out->getDnnlDesc() == rhs.out->getDnnlDesc();
} }
retVal = retVal && *attr.get() == *rhs.attr.get() && retVal = retVal && *attr.get() == *rhs.attr.get() &&
implType == rhs.implType && useConv1x1 == rhs.useConv1x1; implType == rhs.implType && useConv1x1 == rhs.useConv1x1 && useSparseWeights == rhs.useSparseWeights;
return retVal; return retVal;
} }
@ -416,15 +418,20 @@ static dnnl::primitive_desc createPrimitiveDesc(const FCKey& key, const dnnl::en
auto normalizedOutDims = { outDims[0] * outDims[1], outDims[2] }; auto normalizedOutDims = { outDims[0] * outDims[1], outDims[2] };
outDesc = outDesc.reshape(normalizedOutDims); outDesc = outDesc.reshape(normalizedOutDims);
} }
auto wghDescAny = dnnl::memory::desc(DnnlExtensionUtils::convertToDnnlDims(key.inp1->getShape().getStaticDims()), dnnl::memory::desc weiDesc;
key.inp1->getDataType(), memory::format_tag::any); if (key.useSparseWeights) {
weiDesc = key.inp1->getDnnlDesc();
} else {
weiDesc = dnnl::memory::desc(DnnlExtensionUtils::convertToDnnlDims(key.inp1->getShape().getStaticDims()),
key.inp1->getDataType(), memory::format_tag::any);
}
dnnl::inner_product_forward::primitive_desc prim_desc; dnnl::inner_product_forward::primitive_desc prim_desc;
if (key.bias) { if (key.bias) {
prim_desc = dnnl::inner_product_forward::primitive_desc( prim_desc = dnnl::inner_product_forward::primitive_desc(
engine, engine,
dnnl::prop_kind::forward_inference, dnnl::prop_kind::forward_inference,
inDesc, inDesc,
wghDescAny, weiDesc,
key.bias->getDnnlDesc(), key.bias->getDnnlDesc(),
outDesc, outDesc,
key.attr); key.attr);
@ -433,7 +440,7 @@ static dnnl::primitive_desc createPrimitiveDesc(const FCKey& key, const dnnl::en
engine, engine,
dnnl::prop_kind::forward_inference, dnnl::prop_kind::forward_inference,
inDesc, inDesc,
wghDescAny, weiDesc,
outDesc, outDesc,
key.attr); key.attr);
} }
@ -542,7 +549,8 @@ void FullyConnected::prepareParams() {
outDesc, outDesc,
attr, attr,
implementationTypeIP, implementationTypeIP,
useConv1x1}; useConv1x1,
useSparseWeights};
auto& engine = getEngine(); auto& engine = getEngine();
@ -597,7 +605,8 @@ void FullyConnected::prepareParams() {
// changed shapes may also cause the kernel type changed // changed shapes may also cause the kernel type changed
selected_pd->setImplementationType(execPtr->getImplementationType()); selected_pd->setImplementationType(execPtr->getImplementationType());
// WA: We update implType to know whether weights decompression was used inside the kernel // WA: We update implType to know whether weights decompression was used inside the kernel
if (selected_pd->getImplementationType() == ov::intel_cpu::brgemm_avx512_amx && useSparseWeights) { if (selected_pd->getImplementationType() == ov::intel_cpu::brgemm_avx512_amx &&
execPtr->getDnnlWeightDesc().get_format_kind() == memory::format_kind::sparsed) {
selected_pd->setImplementationType(ov::intel_cpu::brgemm_sparse_avx512_amx); selected_pd->setImplementationType(ov::intel_cpu::brgemm_sparse_avx512_amx);
} }
// maybe expected 1x1 conv is not created, update the flag depends on the real type // maybe expected 1x1 conv is not created, update the flag depends on the real type
@ -960,7 +969,7 @@ std::shared_ptr<MemoryDesc> FullyConnected::getSrcMemDesc(const dnnl::primitive_
if (getInputShapeAtPort(idx).getRank() == 3 if (getInputShapeAtPort(idx).getRank() == 3
// report original plain layout for weight since it needs to be reordered dynamically at runtime // report original plain layout for weight since it needs to be reordered dynamically at runtime
|| idx == 1) { || (idx == 1 && !useSparseWeights)) {
return std::make_shared<CpuBlockedMemoryDesc>( return std::make_shared<CpuBlockedMemoryDesc>(
DnnlExtensionUtils::DataTypeToIEPrecision(desc.get_data_type()), getInputShapeAtPort(idx)); DnnlExtensionUtils::DataTypeToIEPrecision(desc.get_data_type()), getInputShapeAtPort(idx));
} }