[CPU] Sparsity weights decompression (#13775)
Parent: cf8c78483d
Commit: 3caebffa5c
@@ -71,6 +71,9 @@ void regmodule_properties(py::module m) {

     // Submodule intel_cpu property
     wrap_property_RW(m_intel_cpu, ov::intel_cpu::denormals_optimization, "denormals_optimization");
+    wrap_property_RW(m_intel_cpu,
+                     ov::intel_cpu::sparse_weights_decompression_rate,
+                     "sparse_weights_decompression_rate");

     // Submodule device
     py::module m_device =
@@ -40,5 +40,7 @@ namespace CPUConfigParams {
  */
DECLARE_CPU_CONFIG_KEY(DENORMALS_OPTIMIZATION);

+DECLARE_CPU_CONFIG_KEY(SPARSE_WEIGHTS_DECOMPRESSION_RATE);
+
 } // namespace CPUConfigParams
 } // namespace InferenceEngine
@@ -47,5 +47,7 @@ namespace intel_cpu {
  */
 static constexpr Property<bool> denormals_optimization{"CPU_DENORMALS_OPTIMIZATION"};

+static constexpr Property<float> sparse_weights_decompression_rate{"SPARSE_WEIGHTS_DECOMPRESSION_RATE"};
+
 } // namespace intel_cpu
 } // namespace ov
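For reference, a minimal usage sketch (not part of this commit): with the property above, the decompression threshold can be passed when compiling a model on CPU. The header install path and the model file name are assumptions.

#include <openvino/runtime/core.hpp>
#include <openvino/runtime/intel_cpu/properties.hpp>  // assumed install location of the header above

int main() {
    ov::Core core;
    auto model = core.read_model("model.xml");  // hypothetical model path
    // Request sparse weights decompression for FullyConnected weights that are
    // at least 80% zeros; the default of 1.0 keeps the feature disabled.
    auto compiled = core.compile_model(model, "CPU",
                                       ov::intel_cpu::sparse_weights_decompression_rate(0.8f));
    return 0;
}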
@@ -74,6 +74,20 @@ void Config::readProperties(const std::map<std::string, std::string> &prop) {
             // zero and any negative value will be treated
             // as default batch size
             batchLimit = std::max(val_i, 0);
+        } else if (key == CPUConfigParams::KEY_CPU_SPARSE_WEIGHTS_DECOMPRESSION_RATE) {
+            float val_f = 0.0f;
+            try {
+                val_f = std::stof(val);
+            } catch (const std::exception&) {
+                IE_THROW() << "Wrong value for property key " << CPUConfigParams::KEY_CPU_SPARSE_WEIGHTS_DECOMPRESSION_RATE
+                           << ". Expected only float numbers";
+            }
+            if (val_f < 0.f || val_f > 1.f) {
+                IE_THROW() << "Wrong value for property key " << CPUConfigParams::KEY_CPU_SPARSE_WEIGHTS_DECOMPRESSION_RATE
+                           << ". Sparse rate must be in range [0.0f,1.0f]";
+            } else {
+                fcSparseWeiDecompressionRate = val_f;
+            }
         } else if (key == PluginConfigParams::KEY_PERF_COUNT) {
             if (val == PluginConfigParams::YES) collectPerfCounters = true;
             else if (val == PluginConfigParams::NO) collectPerfCounters = false;
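The same setting can presumably be applied through the legacy InferenceEngine configuration path that this hunk parses; a sketch, assuming the standard IE Core API (the include path of the CPU config header above is an assumption):

#include <ie_core.hpp>
#include <cpu/cpu_config.hpp>  // assumed path of the header declaring the CPU config keys above

void enable_sparse_decompression() {
    InferenceEngine::Core ie;
    // The value is passed as a string and validated by Config::readProperties();
    // values outside [0.0, 1.0] throw, as shown above.
    ie.SetConfig({{InferenceEngine::CPUConfigParams::KEY_CPU_SPARSE_WEIGHTS_DECOMPRESSION_RATE, "0.8"}}, "CPU");
}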
@@ -34,6 +34,7 @@ struct Config {
     bool enableDynamicBatch = false;
     std::string dumpToDot = "";
     int batchLimit = 0;
+    float fcSparseWeiDecompressionRate = 1.0f;
     size_t rtCacheCapacity = 5000ul;
     InferenceEngine::IStreamsExecutor::Config streamExecutorConfig;
     InferenceEngine::PerfHintsConfig perfHintsConfig;
@@ -26,6 +26,7 @@
 #include <nodes/reorder.h>
 #include "nodes/convert.h"
 #include "nodes/subgraph.h"
+#include "nodes/fullyconnected.h"

 #include <ie_algorithm.hpp>
 #include <blob_factory.hpp>
@@ -340,6 +341,9 @@ void Graph::Replicate(const CNNNetwork &network, const ExtensionManager::Ptr& ex
     if (config.enforceBF16)
         EnforceBF16();

+    if (config.fcSparseWeiDecompressionRate < 1.0f)
+        setMinSparseRate(config.fcSparseWeiDecompressionRate);
+
     auto hasSubgraphConsumers = [] (const NodePtr& node) -> bool {
         const auto & childEdges = node->getChildEdges();
         return std::any_of(childEdges.begin(), childEdges.end(),
@@ -1612,6 +1616,14 @@ void Graph::EnforceBF16() {
     }
 }

+void Graph::setMinSparseRate(float minSparseRate) {
+    for (const auto &node : graphNodes) {
+        if (auto fcNodePtr = std::dynamic_pointer_cast<node::FullyConnected>(node)) {
+            fcNodePtr->setMinSparseRate(minSparseRate);
+        }
+    }
+}
+
 std::shared_ptr<ngraph::Function> Graph::dump() const {
     return dump_graph_as_ie_ngraph_net(*this);
 }
@@ -269,6 +269,7 @@ private:
     std::unordered_map<Node*, size_t> syncNodesInds;

     void EnforceBF16();
+    void setMinSparseRate(float minSparseRate);
 };

 } // namespace intel_cpu
@@ -418,6 +418,7 @@ std::string Node::getPrimitiveDescriptorType() {
     SEARCH_TYPE(uni);

     SEARCH_TYPE(winograd);
+    SEARCH_TYPE(sparse);
     SEARCH_TYPE(_dw);
     SEARCH_TYPE(_1x1);

@@ -4,6 +4,7 @@

 #include "fullyconnected.h"
 #include "eltwise.h"
+#include "input.h"
 #include "fake_quantize.h"
 #include "input.h"
 #include "reorder.h"
@@ -22,6 +23,7 @@
 #include <common/primitive_desc.hpp>
 #include <common/primitive_desc_iface.hpp>
 #include "onednn/dnnl.h"
+#include "cpu/x64/cpu_isa_traits.hpp"

 using namespace dnnl;
 using namespace InferenceEngine;
@@ -172,6 +174,8 @@ void FullyConnected::getSupportedDescriptors() {
     if (getChildEdges().empty())
         IE_THROW()<< errorPrefix << " has incorrect number of output edges";

+    useSparseWeights = useSparseWeightsDecompression();
+
     auto inputDataType = DnnlExtensionUtils::IEPrecisionToDataType(getOriginalInputPrecisionAtPort(DATA_ID));
     outputDataType = DnnlExtensionUtils::IEPrecisionToDataType(getOriginalOutputPrecisionAtPort(DATA_ID));

@@ -360,6 +364,10 @@ void FullyConnected::prepareParams() {
     }
     // changed shapes may also cause the kernel type changed
     selected_pd->setImplementationType(execPtr->getImplementationType());
+    // WA: We update implType to know whether weights decompression was used inside the kernel
+    if (selected_pd->getImplementationType() == ov::intel_cpu::brgemm_avx512_amx && useSparseWeights) {
+        selected_pd->setImplementationType(ov::intel_cpu::brgemm_sparse_avx512_amx);
+    }
     // maybe expected 1x1 conv is not created, update the flag depends on the real type
     useConv1x1 = execPtr->getImplementationType() == brgconv_avx512_1x1;

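Because the selected implementation type is updated here, whether the sparse kernel was actually used can be observed from performance counters. A hedged sketch (not part of this commit), reusing the ov::Core usage sketch shown earlier and assuming the model was compiled with ov::enable_profiling(true):

#include <iostream>
#include <openvino/runtime/core.hpp>

void report_fc_impl_types(ov::CompiledModel& compiled) {
    auto request = compiled.create_infer_request();
    request.infer();  // assumes the default-initialized input tensors are acceptable here
    for (const auto& info : request.get_profiling_info()) {
        // For FullyConnected layers the reported exec_type is expected to mention
        // "brgemm_sparse_avx512_amx" when the sparse path above was taken.
        if (info.node_type == "FullyConnected")
            std::cout << info.node_name << ": " << info.exec_type << std::endl;
    }
}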
@@ -503,6 +511,7 @@ bool FullyConnected::created() const {
 const std::vector<impl_desc_type>& FullyConnected::getPrimitivesPriority() {
     std::vector<impl_desc_type> priorities = {
             impl_desc_type::unknown,
+            impl_desc_type::brgemm_sparse_avx512_amx,
             impl_desc_type::brgemm_avx512_amx,
             impl_desc_type::brgemm_avx512,
             impl_desc_type::gemm_blas,
@@ -578,9 +587,15 @@ void FullyConnected::createDescriptorInternal(const dnnl::memory::desc &inputDes
                                           DnnlExtensionUtils::GetPlainFormatByRank(normalizedOutDims.size()));
     }

-    dnnl::memory::desc wgh_candidate(DnnlExtensionUtils::convertToDnnlDims(getInputShapeAtPort(WEIGHTS_ID).getStaticDims()),
-                                     wdt, dnnl::memory::format_tag::any);
-
+    // We need to explicitly specify the memory descriptor to use sparse weights decompression
+    dnnl::memory::desc wgh_candidate;
+    if (useSparseWeights) {
+        wgh_candidate = { DnnlExtensionUtils::convertToDnnlDims(getInputShapeAtPort(WEIGHTS_ID).getStaticDims()),
+                          wdt, memory::desc::packed(nnzCount) };
+    } else {
+        wgh_candidate = { DnnlExtensionUtils::convertToDnnlDims(getInputShapeAtPort(WEIGHTS_ID).getStaticDims()),
+                          wdt, dnnl::memory::format_tag::any };
+    }
     if (withBiases) {
         dnnl::memory::desc bias_candidate(DnnlExtensionUtils::convertToDnnlDims(getInputShapeAtPort(BIAS_ID).getStaticDims()), bdt,
                                           dnnl::memory::format_tag::any);
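The packed descriptor above asks oneDNN to store only the non-zero weight values, which is why nnzCount must be known before the descriptor is created. As a purely conceptual illustration of that idea (not the actual oneDNN packed encoding, which is an internal detail of the library):

#include <cstdint>
#include <vector>

// Conceptual sketch: a sparse int8 weight buffer can be represented by the list
// of its non-zero values plus a per-element occupancy bitmask. The number of
// non-zero elements determines the size of the value buffer up front.
struct PackedWeights {
    std::vector<int8_t> values;    // nnzCount entries
    std::vector<uint8_t> bitmask;  // 1 bit per original element
};

PackedWeights packDense(const std::vector<int8_t>& dense) {
    PackedWeights packed;
    packed.bitmask.assign((dense.size() + 7) / 8, 0);
    for (size_t i = 0; i < dense.size(); ++i) {
        if (dense[i] != 0) {
            packed.values.push_back(dense[i]);
            packed.bitmask[i / 8] |= static_cast<uint8_t>(1u << (i % 8));
        }
    }
    return packed;
}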
@@ -634,7 +649,7 @@ void FullyConnected::initSupportedPrimitiveDescriptors() {
         portConfig.inPlace(-1);
         portConfig.constant(false);
         auto desc = getSrcMemDesc(itpd, i);
-        if (supportsUndefStridesAndOffset()) {
+        if (supportsUndefStridesAndOffset() && !(i == WEIGHTS_ID && useSparseWeights)) {
             portConfig.setMemDesc(std::dynamic_pointer_cast<BlockedMemoryDesc>(desc), BLOCKED_DESC_EMPTY_MASK);
         } else {
             portConfig.setMemDesc(desc);
@@ -868,6 +883,65 @@ MemoryPtr FullyConnected::prepareWeightMemory(DnnlMemoryDescPtr weightDesc) {
     return ptr;
 }

+bool FullyConnected::useSparseWeightsDecompression() {
+    // minSparseRate == 1 means that sparse feature is switched off
+    if (minSparseRate == 1.f) {
+        return false;
+    }
+
+    if (!impl::cpu::x64::mayiuse(impl::cpu::x64::avx512_core_amx))
+        return false;
+
+    auto weiDims = getInputShapeAtPort(WEIGHTS_ID).getStaticDims();
+    if (weiDims.size() != 2 || weiDims[0] % 64 != 0 || weiDims[1] % 64 != 0) {
+        return false;
+    }
+
+    auto inputPrecision = getOriginalInputPrecisionAtPort(DATA_ID);
+    auto weightsPrecision = getOriginalInputPrecisionAtPort(WEIGHTS_ID);
+    if (!one_of(inputPrecision , Precision::U8, Precision::I8) || weightsPrecision != Precision::I8) {
+        return false;
+    }
+
+    // calculate sparse rate
+    const auto constNode = std::dynamic_pointer_cast<Input>(getParentEdgeAt(WEIGHTS_ID)->getParent());
+    if (!constNode) {
+        return false;
+    }
+    auto blb = constNode->getMemoryPtr();
+    if (blb == nullptr)
+        IE_THROW() << "Cannot get const blob for node " << getName() << ".";
+
+    auto weightsData = reinterpret_cast<const int8_t*>(blb->GetPtr());
+    auto elementsCount = blb->GetDescWithType<BlockedMemoryDesc>()->getPaddedElementsCount();
+    size_t zerosCounts = 0;
+    for (int i = 0; i < elementsCount; i++) {
+        if (weightsData[i] == 0) {
+            zerosCounts++;
+        }
+    }
+    nnzCount = elementsCount - zerosCounts;
+
+    DEBUG_LOG(getName(), ", weightsData.size() = ", elementsCount, ", zerosCounts = ",
+        zerosCounts, ", nnzCount = ", nnzCount);
+
+    weiSparseRate = static_cast<float>(zerosCounts) / static_cast<float>(elementsCount);
+
+    // [av] WA: there is no point in using sparse decompression when the sparse rate is low
+    // todo: add heuristic
+    if (minSparseRate < 0.5)
+        minSparseRate = 0.5;
+
+    DEBUG_LOG(getName(), " | sparse rate = ", weiSparseRate * 100, "%, min sparse rate = ",
+        minSparseRate * 100, "%, use sparse weights = ", weiSparseRate >= minSparseRate);
+
+    if (weiSparseRate < minSparseRate) {
+        return false;
+    }
+
+    return true;
+}
+
 } // namespace node
 } // namespace intel_cpu
 } // namespace ov
@@ -42,6 +42,7 @@ public:

    void initSupportedPrimitiveDescriptors() override;
    void initOptimalPrimitiveDescriptor() override;
+   // void createPrimitive() override;
    std::shared_ptr<MemoryDesc> getSrcMemDesc(dnnl::primitive_desc_iterator &primitive_desc_it, size_t idx) override;
    std::shared_ptr<MemoryDesc> getDstMemDesc(dnnl::primitive_desc_iterator &primitive_desc_it, size_t idx) override;

@@ -58,6 +59,8 @@ public:

     void setDynamicBatchLim(int lim) override;

+    void setMinSparseRate(float sparseRate) { minSparseRate = sparseRate; }
+
 private:
     void createDescriptorInternal(const dnnl::memory::desc &inputDesc,
                                   const dnnl::memory::desc &outputDesc);
@@ -106,6 +109,13 @@ private:

     bool canBeExecutedInConv1x1() const;
     MemoryPtr prepareWeightMemory(const DnnlMemoryDescPtr weightDesc);
+
+    // sparse weights
+    bool useSparseWeights = false;
+    int nnzCount = -1;
+    float minSparseRate = 1.f;
+    float weiSparseRate = 0.f;
+    bool useSparseWeightsDecompression();
 };

 } // namespace node
@@ -37,6 +37,7 @@ impl_desc_type parse_impl_name(std::string impl_desc_name) {
     SEARCH_WORD(_1x1);
     SEARCH_WORD(_dw);
     SEARCH_WORD(reorder);
+    SEARCH_WORD(sparse);
     if ((res & impl_desc_type::avx2) != impl_desc_type::avx2 &&
         (res & impl_desc_type::avx512) != impl_desc_type::avx512)
         SEARCH_WORD(avx);
@@ -108,6 +109,7 @@ const char* impl_type_to_string(impl_desc_type type) {
     CASE(brgemm_sse42);
     CASE(brgemm_uni);
     CASE(brgemm_avx512_amx);
+    CASE(brgemm_sparse_avx512_amx);

 #undef CASE
     return "unknown";
@@ -35,6 +35,8 @@ enum impl_desc_type {
     reorder = 1<<22,
     // winograd
     winograd = 1<<23,
+    // sparse
+    sparse = 1<<24,

     // real types
     ref_any = ref | any,
@@ -90,6 +92,7 @@ enum impl_desc_type {
     brgemm_sse42 = brgemm | sse42,
     brgemm_uni = brgemm | uni,
     brgemm_avx512_amx = brgemm | avx512 | amx,
+    brgemm_sparse_avx512_amx = brgemm | sparse | avx512 | amx,
 };

 const char * impl_type_to_string(impl_desc_type type);
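Since the new value is a bitwise composition, the sparse capability can be tested the same way the surrounding code tests other flags; a small illustrative helper (not part of this commit):

// Returns true when an implementation type carries the sparse-decompression bit,
// e.g. usesSparseDecompression(brgemm_sparse_avx512_amx) == true.
inline bool usesSparseDecompression(impl_desc_type type) {
    return (type & impl_desc_type::sparse) == impl_desc_type::sparse;
}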
src/plugins/intel_cpu/thirdparty/onednn (vendored submodule)
@@ -1 +1 @@
-Subproject commit dbe732c9102fa61f5eae58028215cfc571904baf
+Subproject commit be8d5d21994bf8495a04ee58da3f0d566e695db5