[CPU] Sparsity weights decompression (#13775)

This commit is contained in:
Anton Voronov 2022-12-04 22:43:53 +04:00 committed by GitHub
parent cf8c78483d
commit 3caebffa5c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 130 additions and 5 deletions

View File

@@ -71,6 +71,9 @@ void regmodule_properties(py::module m) {
// Submodule intel_cpu property
wrap_property_RW(m_intel_cpu, ov::intel_cpu::denormals_optimization, "denormals_optimization");
wrap_property_RW(m_intel_cpu,
ov::intel_cpu::sparse_weights_decompression_rate,
"sparse_weights_decompression_rate");
// Submodule device
py::module m_device =

View File

@@ -40,5 +40,7 @@ namespace CPUConfigParams {
*/
DECLARE_CPU_CONFIG_KEY(DENORMALS_OPTIMIZATION);
DECLARE_CPU_CONFIG_KEY(SPARSE_WEIGHTS_DECOMPRESSION_RATE);
} // namespace CPUConfigParams
} // namespace InferenceEngine
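For reference, a minimal usage sketch of this key through the legacy InferenceEngine API (not part of the diff; the cpu_config.hpp include path is an assumption, while the KEY_CPU_* macro name matches the one used in Config::readProperties below):
// Hypothetical legacy-API usage sketch; the rate is passed as a string and parsed
// by the CPU plugin (see Config::readProperties further down).
#include <ie_core.hpp>
#include <cpu/cpu_config.hpp>  // assumed location of the CPUConfigParams keys

int main() {
    InferenceEngine::Core core;
    core.SetConfig({{InferenceEngine::CPUConfigParams::KEY_CPU_SPARSE_WEIGHTS_DECOMPRESSION_RATE, "0.8"}},
                   "CPU");
    return 0;
}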

View File

@@ -47,5 +47,7 @@ namespace intel_cpu {
*/
static constexpr Property<bool> denormals_optimization{"CPU_DENORMALS_OPTIMIZATION"};
static constexpr Property<float> sparse_weights_decompression_rate{"SPARSE_WEIGHTS_DECOMPRESSION_RATE"};
} // namespace intel_cpu
} // namespace ov
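And a matching sketch for the OpenVINO 2.0 API using the new ov::intel_cpu property (not part of the diff; the include path and model path are assumptions):
// Hypothetical OpenVINO 2.0 usage sketch. A rate of 0.8 means FullyConnected weights
// with at least 80% zero elements are executed with sparse decompression; the default
// of 1.0 keeps the feature disabled.
#include <openvino/openvino.hpp>
#include <openvino/runtime/intel_cpu/properties.hpp>  // assumed header for the intel_cpu properties

int main() {
    ov::Core core;
    core.set_property("CPU", ov::intel_cpu::sparse_weights_decompression_rate(0.8f));
    auto compiled_model = core.compile_model("model.xml", "CPU");  // placeholder model path
    return 0;
}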

View File

@@ -74,6 +74,20 @@ void Config::readProperties(const std::map<std::string, std::string> &prop) {
// zero and any negative value will be treated
// as default batch size
batchLimit = std::max(val_i, 0);
} else if (key == CPUConfigParams::KEY_CPU_SPARSE_WEIGHTS_DECOMPRESSION_RATE) {
float val_f = 0.0f;
try {
val_f = std::stof(val);
} catch (const std::exception&) {
IE_THROW() << "Wrong value for property key " << CPUConfigParams::KEY_CPU_SPARSE_WEIGHTS_DECOMPRESSION_RATE
<< ". Expected only float numbers";
}
if (val_f < 0.f || val_f > 1.f) {
IE_THROW() << "Wrong value for property key " << CPUConfigParams::KEY_CPU_SPARSE_WEIGHTS_DECOMPRESSION_RATE
<< ". Sparse rate must be in range [0.0f,1.0f]";
} else {
fcSparseWeiDecompressionRate = val_f;
}
} else if (key == PluginConfigParams::KEY_PERF_COUNT) {
if (val == PluginConfigParams::YES) collectPerfCounters = true;
else if (val == PluginConfigParams::NO) collectPerfCounters = false;

View File

@@ -34,6 +34,7 @@ struct Config {
bool enableDynamicBatch = false;
std::string dumpToDot = "";
int batchLimit = 0;
float fcSparseWeiDecompressionRate = 1.0f;
size_t rtCacheCapacity = 5000ul;
InferenceEngine::IStreamsExecutor::Config streamExecutorConfig;
InferenceEngine::PerfHintsConfig perfHintsConfig;

View File

@@ -26,6 +26,7 @@
#include <nodes/reorder.h>
#include "nodes/convert.h"
#include "nodes/subgraph.h"
#include "nodes/fullyconnected.h"
#include <ie_algorithm.hpp>
#include <blob_factory.hpp>
@@ -340,6 +341,9 @@ void Graph::Replicate(const CNNNetwork &network, const ExtensionManager::Ptr& ex
if (config.enforceBF16)
EnforceBF16();
if (config.fcSparseWeiDecompressionRate < 1.0f)
setMinSparseRate(config.fcSparseWeiDecompressionRate);
auto hasSubgraphConsumers = [] (const NodePtr& node) -> bool {
const auto & childEdges = node->getChildEdges();
return std::any_of(childEdges.begin(), childEdges.end(),
@@ -1612,6 +1616,14 @@ void Graph::EnforceBF16() {
}
}
void Graph::setMinSparseRate(float minSparseRate) {
for (const auto &node : graphNodes) {
if (auto fcNodePtr = std::dynamic_pointer_cast<node::FullyConnected>(node)) {
fcNodePtr->setMinSparseRate(minSparseRate);
}
}
}
std::shared_ptr<ngraph::Function> Graph::dump() const {
return dump_graph_as_ie_ngraph_net(*this);
}

View File

@@ -269,6 +269,7 @@ private:
std::unordered_map<Node*, size_t> syncNodesInds;
void EnforceBF16();
void setMinSparseRate(float minSparseRate);
};
} // namespace intel_cpu

View File

@@ -418,6 +418,7 @@ std::string Node::getPrimitiveDescriptorType() {
SEARCH_TYPE(uni);
SEARCH_TYPE(winograd);
SEARCH_TYPE(sparse);
SEARCH_TYPE(_dw);
SEARCH_TYPE(_1x1);

View File

@@ -4,6 +4,7 @@
#include "fullyconnected.h"
#include "eltwise.h"
#include "input.h"
#include "fake_quantize.h"
#include "input.h"
#include "reorder.h"
@@ -22,6 +23,7 @@
#include <common/primitive_desc.hpp>
#include <common/primitive_desc_iface.hpp>
#include "onednn/dnnl.h"
#include "cpu/x64/cpu_isa_traits.hpp"
using namespace dnnl;
using namespace InferenceEngine;
@@ -172,6 +174,8 @@ void FullyConnected::getSupportedDescriptors() {
if (getChildEdges().empty())
IE_THROW() << errorPrefix << " has incorrect number of output edges";
useSparseWeights = useSparseWeightsDecompression();
auto inputDataType = DnnlExtensionUtils::IEPrecisionToDataType(getOriginalInputPrecisionAtPort(DATA_ID));
outputDataType = DnnlExtensionUtils::IEPrecisionToDataType(getOriginalOutputPrecisionAtPort(DATA_ID));
@@ -360,6 +364,10 @@ void FullyConnected::prepareParams() {
}
// changed shapes may also cause the kernel type changed
selected_pd->setImplementationType(execPtr->getImplementationType());
// WA: We update implType to know whether weights decompression was used inside the kernel
if (selected_pd->getImplementationType() == ov::intel_cpu::brgemm_avx512_amx && useSparseWeights) {
selected_pd->setImplementationType(ov::intel_cpu::brgemm_sparse_avx512_amx);
}
// maybe expected 1x1 conv is not created, update the flag depends on the real type
useConv1x1 = execPtr->getImplementationType() == brgconv_avx512_1x1;
@@ -503,6 +511,7 @@ bool FullyConnected::created() const {
const std::vector<impl_desc_type>& FullyConnected::getPrimitivesPriority() {
std::vector<impl_desc_type> priorities = {
impl_desc_type::unknown,
impl_desc_type::brgemm_sparse_avx512_amx,
impl_desc_type::brgemm_avx512_amx,
impl_desc_type::brgemm_avx512,
impl_desc_type::gemm_blas,
@@ -578,9 +587,15 @@ void FullyConnected::createDescriptorInternal(const dnnl::memory::desc &inputDesc
DnnlExtensionUtils::GetPlainFormatByRank(normalizedOutDims.size()));
}
dnnl::memory::desc wgh_candidate(DnnlExtensionUtils::convertToDnnlDims(getInputShapeAtPort(WEIGHTS_ID).getStaticDims()),
wdt, dnnl::memory::format_tag::any);
// We need to explicitly specify the memory descriptor to use sparse weights decompression
dnnl::memory::desc wgh_candidate;
if (useSparseWeights) {
wgh_candidate = { DnnlExtensionUtils::convertToDnnlDims(getInputShapeAtPort(WEIGHTS_ID).getStaticDims()),
wdt, memory::desc::packed(nnzCount) };
} else {
wgh_candidate = { DnnlExtensionUtils::convertToDnnlDims(getInputShapeAtPort(WEIGHTS_ID).getStaticDims()),
wdt, dnnl::memory::format_tag::any };
}
if (withBiases) {
dnnl::memory::desc bias_candidate(DnnlExtensionUtils::convertToDnnlDims(getInputShapeAtPort(BIAS_ID).getStaticDims()), bdt,
dnnl::memory::format_tag::any);
@@ -634,7 +649,7 @@ void FullyConnected::initSupportedPrimitiveDescriptors() {
portConfig.inPlace(-1);
portConfig.constant(false);
auto desc = getSrcMemDesc(itpd, i);
if (supportsUndefStridesAndOffset()) {
if (supportsUndefStridesAndOffset() && !(i == WEIGHTS_ID && useSparseWeights)) {
portConfig.setMemDesc(std::dynamic_pointer_cast<BlockedMemoryDesc>(desc), BLOCKED_DESC_EMPTY_MASK);
} else {
portConfig.setMemDesc(desc);
@@ -868,6 +883,65 @@ MemoryPtr FullyConnected::prepareWeightMemory(DnnlMemoryDescPtr weightDesc) {
return ptr;
}
bool FullyConnected::useSparseWeightsDecompression() {
// minSparseRate == 1 means that the sparse feature is switched off
if (minSparseRate == 1.f) {
return false;
}
if (!impl::cpu::x64::mayiuse(impl::cpu::x64::avx512_core_amx))
return false;
auto weiDims = getInputShapeAtPort(WEIGHTS_ID).getStaticDims();
if (weiDims.size() != 2 || weiDims[0] % 64 != 0 || weiDims[1] % 64 != 0) {
return false;
}
auto inputPrecision = getOriginalInputPrecisionAtPort(DATA_ID);
auto weightsPrecision = getOriginalInputPrecisionAtPort(WEIGHTS_ID);
if (!one_of(inputPrecision, Precision::U8, Precision::I8) || weightsPrecision != Precision::I8) {
return false;
}
// calculate sparse rate
const auto constNode = std::dynamic_pointer_cast<Input>(getParentEdgeAt(WEIGHTS_ID)->getParent());
if (!constNode) {
return false;
}
auto blb = constNode->getMemoryPtr();
if (blb == nullptr)
IE_THROW() << "Cannot get const blob for node " << getName() << ".";
auto weightsData = reinterpret_cast<const int8_t*>(blb->GetPtr());
auto elementsCount = blb->GetDescWithType<BlockedMemoryDesc>()->getPaddedElementsCount();
size_t zerosCounts = 0;
for (size_t i = 0; i < elementsCount; i++) {
if (weightsData[i] == 0) {
zerosCounts++;
}
}
nnzCount = elementsCount - zerosCounts;
DEBUG_LOG(getName(), ", weightsData.size() = ", elementsCount, ", zerosCounts = ",
zerosCounts, ", nnzCount = ", nnzCount);
weiSparseRate = static_cast<float>(zerosCounts) / static_cast<float>(elementsCount);
// [av] WA: there is no point in using sparse decompression when the sparse rate is low
// todo: add heuristic
if (minSparseRate < 0.5)
minSparseRate = 0.5;
DEBUG_LOG(getName(), " | sparse rate = ", weiSparseRate * 100, "%, min sparse rate = ",
minSparseRate * 100, "%, use sparse weights = ", weiSparseRate >= minSparseRate);
if (weiSparseRate < minSparseRate) {
return false;
}
return true;
}
} // namespace node
} // namespace intel_cpu
} // namespace ov
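To summarize the check above, a standalone sketch of the rate computation with a hypothetical helper name (the real useSparseWeightsDecompression() additionally requires avx512_core_amx, a 2D weight shape with both dimensions divisible by 64, and u8/i8 activations with i8 weights):
#include <algorithm>
#include <cstddef>
#include <cstdint>

// Simplified model of the sparse-rate decision made per FullyConnected node.
bool shouldUseSparseDecompression(const int8_t* weights, size_t elementsCount, float minSparseRate) {
    if (minSparseRate == 1.f)                       // 1.0 (the default) switches the feature off
        return false;
    minSparseRate = std::max(minSparseRate, 0.5f);  // WA from the commit: rates below 0.5 are not worthwhile
    size_t zerosCount = 0;
    for (size_t i = 0; i < elementsCount; i++)
        zerosCount += (weights[i] == 0) ? 1 : 0;
    const float weiSparseRate = static_cast<float>(zerosCount) / static_cast<float>(elementsCount);
    return weiSparseRate >= minSparseRate;          // decompress only when the weights are sparse enough
}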

View File

@@ -42,6 +42,7 @@ public:
void initSupportedPrimitiveDescriptors() override;
void initOptimalPrimitiveDescriptor() override;
// void createPrimitive() override;
std::shared_ptr<MemoryDesc> getSrcMemDesc(dnnl::primitive_desc_iterator &primitive_desc_it, size_t idx) override;
std::shared_ptr<MemoryDesc> getDstMemDesc(dnnl::primitive_desc_iterator &primitive_desc_it, size_t idx) override;
@@ -58,6 +59,8 @@ public:
void setDynamicBatchLim(int lim) override;
void setMinSparseRate(float sparseRate) { minSparseRate = sparseRate; }
private:
void createDescriptorInternal(const dnnl::memory::desc &inputDesc,
const dnnl::memory::desc &outputDesc);
@@ -106,6 +109,13 @@ private:
bool canBeExecutedInConv1x1() const;
MemoryPtr prepareWeightMemory(const DnnlMemoryDescPtr weightDesc);
// sparse weights
bool useSparseWeights = false;
int nnzCount = -1;
float minSparseRate = 1.f;
float weiSparseRate = 0.f;
bool useSparseWeightsDecompression();
};
} // namespace node

View File

@@ -37,6 +37,7 @@ impl_desc_type parse_impl_name(std::string impl_desc_name) {
SEARCH_WORD(_1x1);
SEARCH_WORD(_dw);
SEARCH_WORD(reorder);
SEARCH_WORD(sparse);
if ((res & impl_desc_type::avx2) != impl_desc_type::avx2 &&
(res & impl_desc_type::avx512) != impl_desc_type::avx512)
SEARCH_WORD(avx);
@@ -108,6 +109,7 @@ const char* impl_type_to_string(impl_desc_type type) {
CASE(brgemm_sse42);
CASE(brgemm_uni);
CASE(brgemm_avx512_amx);
CASE(brgemm_sparse_avx512_amx);
#undef CASE
return "unknown";

View File

@@ -35,6 +35,8 @@ enum impl_desc_type {
reorder = 1<<22,
// winograd
winograd = 1<<23,
// sparse
sparse = 1<<24,
// real types
ref_any = ref | any,
@@ -90,6 +92,7 @@ enum impl_desc_type {
brgemm_sse42 = brgemm | sse42,
brgemm_uni = brgemm | uni,
brgemm_avx512_amx = brgemm | avx512 | amx,
brgemm_sparse_avx512_amx = brgemm | sparse | avx512 | amx,
};
const char * impl_type_to_string(impl_desc_type type);
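As a sanity check on the new enum value, a small sketch (the include path is an assumption) showing that it is the existing brgemm_avx512_amx mask with the new sparse bit added, which is what lets SEARCH_WORD(sparse) / SEARCH_TYPE(sparse) report it:
// Hypothetical compile-time check; iml_type_mapper.h is assumed to be the header
// declaring impl_desc_type inside namespace ov::intel_cpu.
#include "onednn/iml_type_mapper.h"

static_assert(ov::intel_cpu::brgemm_sparse_avx512_amx ==
                  (ov::intel_cpu::brgemm_avx512_amx | ov::intel_cpu::sparse),
              "brgemm_sparse_avx512_amx = brgemm | sparse | avx512 | amx");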

@@ -1 +1 @@
Subproject commit dbe732c9102fa61f5eae58028215cfc571904baf
Subproject commit be8d5d21994bf8495a04ee58da3f0d566e695db5