[CPU] FullyConnected: sparse weights fix (#20117)
commit c417d15432
parent 67a62186ee
@@ -52,6 +52,7 @@ struct FCKey {
     dnnl::primitive_attr attr;
     impl_desc_type implType;
     bool useConv1x1;
+    bool useSparseWeights;
 
     size_t hash() const;
     bool operator==(const FCKey& rhs) const;
@@ -72,6 +73,7 @@ size_t FCKey::hash() const {
     seed = hash_combine(seed, get_attr_hash(*attr.get()));
     seed = hash_combine(seed, implType);
     seed = hash_combine(seed, useConv1x1);
+    seed = hash_combine(seed, useSparseWeights);
     return seed;
 }
 
@@ -90,7 +92,7 @@ bool FCKey::operator==(const FCKey &rhs) const {
         retVal = retVal && out && rhs.out && out->getDnnlDesc() == rhs.out->getDnnlDesc();
     }
     retVal = retVal && *attr.get() == *rhs.attr.get() &&
-             implType == rhs.implType && useConv1x1 == rhs.useConv1x1;
+             implType == rhs.implType && useConv1x1 == rhs.useConv1x1 && useSparseWeights == rhs.useSparseWeights;
     return retVal;
 }
 
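The new useSparseWeights flag is folded into both FCKey::hash() and FCKey::operator==, so a primitive cached for a dense-weights configuration is never handed back for a sparse one (or vice versa). The rule this follows: every field that influences primitive creation must appear consistently in both the hash and the equality check, otherwise distinct configurations can collide in the cache and the wrong primitive gets reused. A minimal, self-contained sketch of that rule (not the OpenVINO implementation; hash_combine here is the common boost-style combiner and may differ from the project's own helper):

#include <cstddef>
#include <functional>

// Boost-style combiner: mixes one more value into an existing seed.
template <typename T>
std::size_t hash_combine(std::size_t seed, const T& v) {
    return seed ^ (std::hash<T>{}(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2));
}

struct ToyKey {
    bool useConv1x1;
    bool useSparseWeights;  // any field used to build the primitive belongs here

    std::size_t hash() const {
        std::size_t seed = 0;
        seed = hash_combine(seed, useConv1x1);
        seed = hash_combine(seed, useSparseWeights);  // keep hash() ...
        return seed;
    }
    bool operator==(const ToyKey& rhs) const {
        return useConv1x1 == rhs.useConv1x1 &&
               useSparseWeights == rhs.useSparseWeights;  // ... and operator== in sync
    }
};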
@@ -416,15 +418,20 @@ static dnnl::primitive_desc createPrimitiveDesc(const FCKey& key, const dnnl::en
         auto normalizedOutDims = { outDims[0] * outDims[1], outDims[2] };
         outDesc = outDesc.reshape(normalizedOutDims);
     }
-    auto wghDescAny = dnnl::memory::desc(DnnlExtensionUtils::convertToDnnlDims(key.inp1->getShape().getStaticDims()),
-                                         key.inp1->getDataType(), memory::format_tag::any);
+    dnnl::memory::desc weiDesc;
+    if (key.useSparseWeights) {
+        weiDesc = key.inp1->getDnnlDesc();
+    } else {
+        weiDesc = dnnl::memory::desc(DnnlExtensionUtils::convertToDnnlDims(key.inp1->getShape().getStaticDims()),
+                                     key.inp1->getDataType(), memory::format_tag::any);
+    }
     dnnl::inner_product_forward::primitive_desc prim_desc;
     if (key.bias) {
         prim_desc = dnnl::inner_product_forward::primitive_desc(
             engine,
             dnnl::prop_kind::forward_inference,
             inDesc,
-            wghDescAny,
+            weiDesc,
             key.bias->getDnnlDesc(),
             outDesc,
             key.attr);
@@ -433,7 +440,7 @@ static dnnl::primitive_desc createPrimitiveDesc(const FCKey& key, const dnnl::en
             engine,
             dnnl::prop_kind::forward_inference,
             inDesc,
-            wghDescAny,
+            weiDesc,
             outDesc,
             key.attr);
     }
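The key point of the descriptor change above: on the dense path the weight descriptor is still created with memory::format_tag::any, which lets oneDNN pick whatever packed layout its kernel prefers, while on the sparse path the already-encoded descriptor carried by key.inp1 must be passed through untouched, since format_tag::any would discard the sparse encoding. A rough, self-contained illustration of the dense-path idea against the public oneDNN v3 API (dimensions and variable names are made up for the example; this is not the OpenVINO code):

#include <oneapi/dnnl/dnnl.hpp>

int main() {
    using namespace dnnl;
    engine eng(engine::kind::cpu, 0);

    memory::dims src_dims = {64, 1024};    // {batch, in_features}
    memory::dims wei_dims = {4096, 1024};  // {out_features, in_features}
    memory::dims dst_dims = {64, 4096};

    auto src_md = memory::desc(src_dims, memory::data_type::f32, memory::format_tag::ab);
    auto dst_md = memory::desc(dst_dims, memory::data_type::f32, memory::format_tag::ab);
    // format_tag::any: let the implementation choose the weight layout it wants.
    auto wei_md = memory::desc(wei_dims, memory::data_type::f32, memory::format_tag::any);

    auto pd = inner_product_forward::primitive_desc(
        eng, prop_kind::forward_inference, src_md, wei_md, dst_md);

    // Query the layout the implementation actually selected; the caller then
    // reorders its plain weights into this layout before executing.
    auto chosen_wei_md = pd.weights_desc();
    (void)chosen_wei_md;
    return 0;
}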
@@ -542,7 +549,8 @@ void FullyConnected::prepareParams() {
                  outDesc,
                  attr,
                  implementationTypeIP,
-                 useConv1x1};
+                 useConv1x1,
+                 useSparseWeights};
 
     auto& engine = getEngine();
 
@@ -597,7 +605,8 @@ void FullyConnected::prepareParams() {
         // changed shapes may also cause the kernel type changed
         selected_pd->setImplementationType(execPtr->getImplementationType());
         // WA: We update implType to know whether weights decompression was used inside the kernel
-        if (selected_pd->getImplementationType() == ov::intel_cpu::brgemm_avx512_amx && useSparseWeights) {
+        if (selected_pd->getImplementationType() == ov::intel_cpu::brgemm_avx512_amx &&
+            execPtr->getDnnlWeightDesc().get_format_kind() == memory::format_kind::sparsed) {
             selected_pd->setImplementationType(ov::intel_cpu::brgemm_sparse_avx512_amx);
         }
         // maybe expected 1x1 conv is not created, update the flag depends on the real type
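The workaround above stops trusting the useSparseWeights request flag and instead inspects the weight descriptor of the primitive that was actually built: the kernel is free to fall back to a dense layout even when sparse weights were requested, so the reported implementation type should follow what really ran. A tiny standalone sketch of that pattern (the enum and function names are invented for illustration and do not come from the OpenVINO sources):

enum class WeightsFormat { dense, sparse };
enum class ImplType { brgemm_avx512_amx, brgemm_sparse_avx512_amx };

// Upgrade the reported type only when the weight layout the kernel actually
// chose is sparse, not merely when sparse weights were requested.
ImplType reportedImplType(ImplType selected, WeightsFormat chosenByKernel) {
    if (selected == ImplType::brgemm_avx512_amx && chosenByKernel == WeightsFormat::sparse)
        return ImplType::brgemm_sparse_avx512_amx;
    return selected;
}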
@@ -960,7 +969,7 @@ std::shared_ptr<MemoryDesc> FullyConnected::getSrcMemDesc(const dnnl::primitive_
 
     if (getInputShapeAtPort(idx).getRank() == 3
         // report original plain layout for weight since it needs to be reordered dynamically at runtime
-        || idx == 1) {
+        || (idx == 1 && !useSparseWeights)) {
         return std::make_shared<CpuBlockedMemoryDesc>(
             DnnlExtensionUtils::DataTypeToIEPrecision(desc.get_data_type()), getInputShapeAtPort(idx));
     }