[CPU] Enable scratchpad user mode for oneDNN based nodes (#12038)

Tingqian Li authored 2022-11-04 16:10:26 +08:00, committed by GitHub
parent de687650d2
commit f34f35bfc7
18 changed files with 132 additions and 14 deletions
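Background: a oneDNN primitive may need temporary "scratchpad" memory while it executes. In the default library mode every primitive allocates its own buffer; with dnnl::scratchpad_mode::user the caller supplies the buffer at execution time, which lets the plugin serve all primitives of a graph from one reusable allocation. A minimal sketch of the oneDNN-side pattern this commit rolls out (oneDNN v2.x C++ API; conv_desc, engine, strm, and the args map of type std::unordered_map<int, dnnl::memory> are assumed to be set up already):

    // Ask oneDNN not to allocate the scratchpad itself.
    dnnl::primitive_attr attr;
    attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);

    // The attribute must be attached when the primitive descriptor is created.
    dnnl::convolution_forward::primitive_desc pd(conv_desc, attr, engine);
    dnnl::convolution_forward conv(pd);

    // Query the layout the primitive expects, bind a caller-owned buffer to it,
    // and pass that buffer at execution time under DNNL_ARG_SCRATCHPAD.
    dnnl::memory scratchpad(pd.scratchpad_desc(), engine);
    args.insert({DNNL_ARG_SCRATCHPAD, scratchpad});
    conv.execute(strm, args);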


@@ -148,5 +148,13 @@ std::shared_ptr<DnnlBlockedMemoryDesc> DnnlExtensionUtils::makeUndefinedDesc(con
}
}
DnnlMemoryDescPtr DnnlExtensionUtils::query_md(const const_dnnl_primitive_desc_t& pd, const dnnl::query& what, int idx) {
auto query = dnnl::convert_to_c(what);
const dnnl_memory_desc_t* cdesc = dnnl_primitive_desc_query_md(pd, query, idx);
if (!cdesc)
IE_THROW() << "query_md failed for query=" << query << " idx=" << idx << ".";
return DnnlExtensionUtils::makeDescriptor(*cdesc);
}
} // namespace intel_cpu
} // namespace ov
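The new helper wraps dnnl_primitive_desc_query_md() and throws instead of returning null. In this commit it is reached through Node::getScratchPadMem() below; a hypothetical direct call looks like:

    // Fetch the scratchpad layout a compiled primitive expects (pd is a
    // const_dnnl_primitive_desc_t obtained from the primitive).
    DnnlMemoryDescPtr spDesc = DnnlExtensionUtils::query_md(pd, dnnl::query::scratchpad_md);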


@@ -47,6 +47,8 @@ public:
static std::shared_ptr<DnnlBlockedMemoryDesc> makeUndefinedDesc(const dnnl::memory::desc &desc, const Shape& shape);
static size_t getMemSizeForDnnlDesc(const dnnl::memory::desc& desc);
static std::shared_ptr<DnnlMemoryDesc> query_md(const const_dnnl_primitive_desc_t& pd, const dnnl::query& what, int idx = 0);
};
} // namespace intel_cpu


@@ -0,0 +1,35 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include "common/memory.hpp"
#include "cpu_memory.h"
#include "dnnl_extension_utils.h"
namespace ov {
namespace intel_cpu {
class DnnlScratchPad {
DnnlMemoryMngrPtr mgrPtr;
dnnl::engine eng;
public:
DnnlScratchPad(dnnl::engine eng) : eng(eng) {
mgrPtr = std::make_shared<DnnlMemoryMngr>(std::unique_ptr<MemoryMngrWithReuse>(new MemoryMngrWithReuse()));
}
MemoryPtr createScratchPadMem(const MemoryDescPtr& md) {
auto mem = std::make_shared<Memory>(eng);
mem->Create(md, mgrPtr);
return mem;
}
};
using DnnlScratchPadPtr = std::shared_ptr<DnnlScratchPad>;
} // namespace intel_cpu
} // namespace ov
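A sketch of the intended usage, mirroring the Graph changes below (engine and scratchpadDesc are placeholders): each graph creates one DnnlScratchPad and hands it to every node, so all scratchpad requests are backed by the same DnnlMemoryMngr and its single reusable allocation.

    // One scratch pad per graph, shared by all of its oneDNN-based nodes.
    DnnlScratchPadPtr pad = std::make_shared<DnnlScratchPad>(engine);
    // Each node binds the shared manager to its own scratchpad descriptor;
    // MemoryMngrWithReuse keeps the underlying allocation alive between calls.
    MemoryPtr spMem = pad->createScratchPadMem(scratchpadDesc);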


@@ -79,6 +79,7 @@ void Graph::CreateGraph(NET &net, const ExtensionManager::Ptr& extMgr,
rtParamsCache = std::make_shared<MultiCache>(config.rtCacheCapacity);
sharedMutex = mutex;
rtScratchPad = std::make_shared<DnnlScratchPad>(getEngine());
Replicate(net, extMgr);
InitGraph();
@@ -98,6 +99,7 @@ void Graph::CreateGraph(const std::vector<NodePtr> &graphNodes,
weightsCache = config.streamExecutorConfig._streams != 1 ? w_cache : nullptr;
rtParamsCache = std::make_shared<MultiCache>(config.rtCacheCapacity);
rtScratchPad = std::make_shared<DnnlScratchPad>(getEngine());
this->_name = std::move(name);
this->reuse_io_tensors = false;
@@ -158,6 +160,7 @@ void Graph::Replicate(const std::shared_ptr<const ov::Model> &subgraph, const Ex
node->setRuntimeCache(rtParamsCache);
node->setSharedMutex(sharedMutex);
node->setRuntimeScratchPad(rtScratchPad);
graphNodes.push_back(node);
@@ -272,6 +275,7 @@ void Graph::Replicate(const CNNNetwork &network, const ExtensionManager::Ptr& ex
node->setRuntimeCache(rtParamsCache);
node->setSharedMutex(sharedMutex);
node->setRuntimeScratchPad(rtScratchPad);
graphNodes.push_back(node);
@@ -1357,6 +1361,7 @@ bool Graph::InsertNode(NodePtr parent, NodePtr child, NodePtr node, int parentPo
node->setQuantizedGraphFlag(true);
}
node->setRuntimeCache(rtParamsCache);
node->setRuntimeScratchPad(rtScratchPad);
if (initNode) {
node->getSupportedDescriptors();


@@ -11,6 +11,7 @@
#include "node.h"
#include "edge.h"
#include "cache/multi_cache.h"
#include "dnnl_scratch_pad.h"
#include <map>
#include <string>
#include <vector>
@@ -264,6 +265,7 @@ private:
MultiCachePtr rtParamsCache;
std::shared_ptr<std::mutex> sharedMutex = nullptr;
DnnlScratchPadPtr rtScratchPad;
void EnforceBF16();
};


@@ -20,6 +20,7 @@
#include "extension_mngr.h"
#include "primitive.h"
#include "weights_cache.hpp"
#include "dnnl_scratch_pad.h"
#include <openvino/itt.hpp>
#include "utils/ngraph_utils.hpp"
#include <ngraph/ops.hpp>
@@ -28,6 +29,7 @@
#include <nodes/common/blocked_desc_creator.h>
#include "cpu_types.h"
#include "cpu_shape.h"
#include "config.h"
#include "nodes/node_config.h"
#include "cache/multi_cache.h"
@@ -573,6 +575,10 @@ public:
rtParamsCache = cache;
}
void setRuntimeScratchPad(DnnlScratchPadPtr scratchPad) {
rtScratchPad = scratchPad;
}
void setSharedMutex(const std::shared_ptr<std::mutex>& mutex) {
sharedMutex = mutex;
}
@@ -747,6 +753,16 @@ protected:
return rtParamsCache;
}
DnnlScratchPadPtr getRuntimeScratchPad() const {
return rtScratchPad;
}
MemoryPtr getScratchPadMem(const const_dnnl_primitive_desc_t& pd) {
auto scratchpadMemoryDesc = DnnlExtensionUtils::query_md(pd, dnnl::query::scratchpad_md);
scratchpadMem = getRuntimeScratchPad()->createScratchPadMem(scratchpadMemoryDesc);
return scratchpadMem;
}
std::vector<VectorDims> lastInputDims = {};
std::shared_ptr<IShapeInfer> shapeInference;
@@ -775,6 +791,8 @@ private:
PerfCounters profiling;
MultiCachePtr rtParamsCache;
DnnlScratchPadPtr rtScratchPad;
MemoryPtr scratchpadMem;
bool isEdgesEmpty(const std::vector<EdgeWeakPtr>& edges) const;
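With these pieces in place, every oneDNN-based node touched below follows the same three-step recipe (sketched here with the names used in this patch):

    // 1. Request user scratchpad mode before the primitive descriptor is built.
    attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
    // 2. Once the primitive exists, map its scratchpad layout onto the
    //    graph-wide reusable buffer.
    auto scratchpadMem = getScratchPadMem((*prim).get_primitive_desc());
    // 3. Hand the buffer to oneDNN on every execution.
    primArgs[DNNL_ARG_SCRATCHPAD] = scratchpadMem->GetPrimitive();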


@@ -50,5 +50,9 @@ bool DnnlExecutor::needReordering() const {
return !inputReorders.empty() || !outputReorders.empty();
}
Primitive DnnlExecutor::getExecPrim() const {
return execPrim;
}
} // namespace intel_cpu
} // namespace ov
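getExecPrim() exists so nodes that run through a DnnlExecutor (Convolution and Deconvolution below) can still reach the compiled primitive's descriptor for the scratchpad query:

    // As used in Convolution::prepareParams() / Deconvolution::prepareParams():
    auto pd = (*(execPtr->getExecPrim())).get_primitive_desc();
    auto scratchpadMem = getScratchPadMem(pd);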


@@ -29,6 +29,7 @@ class DnnlExecutor {
void exec(std::unordered_map<int, dnnl::memory> primArgs, dnnl::stream strm);
bool needReordering() const;
virtual ~DnnlExecutor() = default;
Primitive getExecPrim() const;
protected:
DnnlExecutor() = default;


@@ -1316,6 +1316,7 @@ void Convolution::prepareParams() {
else
addZeroPoints(attr);
setPostOps(attr, outMemoryDesc->getShape().getStaticDims(), preferLegacyPostOps, true);
attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
return std::make_shared<dnnl::primitive_attr>(std::move(attr));
};
@@ -1460,6 +1461,10 @@ void Convolution::prepareParams() {
appendZeroPointsArgs();
Node::appendPostOpArgs(*pAttrLocal, primArgs, convPostOpsArgs[preferLegacyPostOps]);
auto pd = (*(execPtr->getExecPrim())).get_primitive_desc();
auto scratchpadMem = getScratchPadMem(pd);
primArgs[DNNL_ARG_SCRATCHPAD] = scratchpadMem->GetPrimitive();
} else {
IE_THROW() << "Primitive descriptor was not found for node " << getName() << ".";
}
@@ -1489,6 +1494,7 @@ void Convolution::execute(dnnl::stream strm) {
if (!execPtr) {
IE_THROW() << "Can't execute Convolution node with name: " << getName() << ", because executor is not compiled";
}
execPtr->exec(primArgs, strm);
}


@@ -562,6 +562,7 @@ void Deconvolution::execute(dnnl::stream strm) {
if (!execPtr) {
IE_THROW() << "Can't execute Deconvolution node with name: " << getName() << ", because executor is not compiled";
}
execPtr->exec(primArgs, strm);
if (externOutShape) {
@@ -752,6 +753,7 @@ void Deconvolution::prepareParams() {
} else {
pAttrLocal = makePrimitiveAttr(dstMemPtr->getStaticDims());
}
(*pAttrLocal).set_scratchpad_mode(dnnl::scratchpad_mode::user);
DnnlMemoryDescCPtr wghDesc;
if (isInt8) {
@@ -872,6 +874,10 @@ void Deconvolution::prepareParams() {
primArgs[DNNL_ARG_DIFF_SRC] = dstMemPtr->GetPrimitive();
}
Node::appendPostOpArgs(*pAttrLocal, primArgs, postOpsArgs);
auto pd = (*(execPtr->getExecPrim())).get_primitive_desc();
auto scratchpadMem = getScratchPadMem(pd);
primArgs[DNNL_ARG_SCRATCHPAD] = scratchpadMem->GetPrimitive();
} else {
IE_THROW() << "Primitive descriptor was not found for node " << getName() << ".";
}


@@ -226,6 +226,7 @@ void FullyConnected::prepareParams() {
AttrPtr attr = std::make_shared<dnnl::primitive_attr>();
setPostOps(*attr, dstMemPtr->getStaticDims());
(*attr).set_scratchpad_mode(dnnl::scratchpad_mode::user);
DnnlMemoryDescCPtr weightDesc = wghMemPtr->GetDescWithType<DnnlMemoryDesc>();
DnnlMemoryDescCPtr biasDesc = nullptr;
@@ -311,6 +312,10 @@ void FullyConnected::prepareParams() {
appendPostOpArgs(*attr, primArgs, postOpsArgs);
auto pd = (*prim).get_primitive_desc();
auto scratchpadMem = getScratchPadMem(pd);
primArgs[DNNL_ARG_SCRATCHPAD] = scratchpadMem->GetPrimitive();
auto reshapeMemory = [this](int argType) {
auto param = primArgs.find(argType);
if (param != primArgs.end()) {


@@ -178,8 +178,10 @@ void Lrn::prepareParams() {
DnnlDesriptor desc(std::shared_ptr<dnnl::lrn_forward::desc>(
new dnnl::lrn_forward::desc(dnnl::prop_kind::forward_scoring, key.alg, key.inp0->getDnnlDesc(), key.size, key.alpha, key.beta, key.k)));
dnnl::primitive_attr attr;
attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
dnnl::lrn_forward::primitive_desc prim_desc;
dnnl::primitive_desc_iterator itpd = desc.createPrimitiveDescriptorIterator(engine);
dnnl::primitive_desc_iterator itpd = desc.createPrimitiveDescriptorIterator(engine, attr);
while (static_cast<bool>(itpd)) {
impl_desc_type impl_type = parse_impl_name(itpd.impl_info_str());
if (impl_type == key.implType) {
@@ -199,9 +201,12 @@ void Lrn::prepareParams() {
}
prim = result.first;
auto pd = (*prim).get_primitive_desc();
auto scratchpadMem = getScratchPadMem(pd);
auto src = srcMemPtr->GetPrimitive();
auto dst = dstMemPtr->GetPrimitive();
primArgs = { {DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst} };
primArgs = { {DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}, {DNNL_ARG_SCRATCHPAD, scratchpadMem->GetPrimitive()} };
}
bool Lrn::created() const {


@@ -196,6 +196,8 @@ Node::AttrPtr MatMul::initPrimitiveAttr(const VectorDims &dims) {
setPostOps(*attr, dims, true);
(*attr).set_scratchpad_mode(dnnl::scratchpad_mode::user);
return attr;
}
@@ -569,6 +571,10 @@ void MatMul::prepareParams() {
prim = result.first;
auto pd = (*prim).get_primitive_desc();
auto scratchpadMem = getScratchPadMem(pd);
primArgs[DNNL_ARG_SCRATCHPAD] = scratchpadMem->GetPrimitive();
primArgs[DNNL_ARG_SRC_0] = src0MemPtr->GetPrimitive();
primArgs[DNNL_ARG_WEIGHTS_0] = src1MemPtr->GetPrimitive();
primArgs[DNNL_ARG_DST] = dstMemPtr->GetPrimitive();
@@ -579,7 +585,7 @@ void MatMul::prepareParams() {
}
void MatMul::executeDynamicImpl(dnnl::stream strm) {
Node::execute(strm);
execute(strm);
}
const std::vector<impl_desc_type>& MatMul::getPrimitivesPriority() {
@@ -618,7 +624,6 @@ const std::vector<impl_desc_type>& MatMul::getPrimitivesPriority() {
}
return implPriorities;
}
} // namespace node
} // namespace intel_cpu
} // namespace ov


@@ -387,9 +387,11 @@ void Pooling::prepareParams() {
prim = result.first;
auto pd = (*prim).get_primitive_desc();
auto scratchpadMem = getScratchPadMem(pd);
auto src = getParentEdgesAtPort(0)[0]->getMemoryPtr()->GetPrimitive();
auto dst = getChildEdgesAtPort(0)[0]->getMemoryPtr()->GetPrimitive();
primArgs = {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}};
primArgs = {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}, {DNNL_ARG_SCRATCHPAD, scratchpadMem->GetPrimitive()}};
Node::appendPostOpArgs(*attr, primArgs, postOpsArgs);
}
@@ -616,6 +618,8 @@ Node::AttrPtr Pooling::initPrimitiveAttr() {
setPostOps(*attr);
(*attr).set_scratchpad_mode(dnnl::scratchpad_mode::user);
return attr;
}
@@ -635,6 +639,6 @@ void Pooling::setPostOps(dnnl::primitive_attr &attr) {
attr.set_post_ops(ops);
}
} // namespace node
} // namespace intel_cpu
} // namespace ov


@@ -825,19 +825,21 @@ void RNN::prepareParams() {
auto builder = [this](const RNNKey& key) -> std::shared_ptr<dnnl::primitive> {
fillDescs();
dnnl::primitive_attr attr;
attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
if (key.cellType == dnnl::algorithm::vanilla_rnn) {
std::shared_ptr<vanilla_rnn_forward::desc> desc = descs[0];
return std::make_shared<vanilla_rnn_forward>(vanilla_rnn_forward::primitive_desc(*desc, getEngine()));
return std::make_shared<vanilla_rnn_forward>(vanilla_rnn_forward::primitive_desc(*desc, attr, getEngine()));
} else if (key.cellType == dnnl::algorithm::vanilla_gru) {
std::shared_ptr<gru_forward::desc> desc = descs[0];
return std::make_shared<gru_forward>(gru_forward::primitive_desc(*desc, getEngine()));
return std::make_shared<gru_forward>(gru_forward::primitive_desc(*desc, attr, getEngine()));
} else if (key.cellType == dnnl::algorithm::lbr_gru) {
std::shared_ptr<lbr_gru_forward::desc> desc = descs[0];
return std::make_shared<lbr_gru_forward>(lbr_gru_forward::primitive_desc(*desc, getEngine()));
return std::make_shared<lbr_gru_forward>(lbr_gru_forward::primitive_desc(*desc, attr, getEngine()));
} else if (key.cellType == dnnl::algorithm::vanilla_lstm) {
std::shared_ptr<lstm_forward::desc> desc = descs[0];
return std::make_shared<lstm_forward>(lstm_forward::primitive_desc(*desc, getEngine()));
return std::make_shared<lstm_forward>(lstm_forward::primitive_desc(*desc, attr, getEngine()));
} else {
return nullptr;
}
@@ -852,6 +854,9 @@ void RNN::prepareParams() {
prim = result.first;
auto pd = (*prim).get_primitive_desc();
scratchpadMem = getScratchPadMem(pd);
if (!wasMemoryPrepared || wFormatWasChanged) {
auto pd = (*prim).get_primitive_desc();
auto query_weights_md = [&](int idx = 0) -> dnnl::memory::desc {
@@ -896,6 +901,7 @@ void RNN::execute(dnnl::stream strm) {
{DNNL_ARG_WEIGHTS_ITER, wgh_stat_mem->GetPrimitive()},
{DNNL_ARG_BIAS, wgh_bias_mem->GetPrimitive()},
{DNNL_ARG_DST_LAYER, dst_data_mem->GetPrimitive()},
{DNNL_ARG_SCRATCHPAD, scratchpadMem->GetPrimitive()}
};
int state_i_tags[] {DNNL_ARG_SRC_ITER, DNNL_ARG_SRC_ITER_C};


@@ -120,6 +120,7 @@ private:
static constexpr size_t batchDimDummyValue = 64lu;
bool wasMemoryPrepared = false;
MemoryPtr scratchpadMem;
};
} // namespace node
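Note that RNN, unlike the other nodes, keeps scratchpadMem as a class member (added above): prepareParams() creates it, while execute() rebuilds primArgs on every call and must still be able to hand the live memory to oneDNN under DNNL_ARG_SCRATCHPAD.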


@@ -152,7 +152,9 @@ void SoftMax::prepareParams() {
softmax_forward::primitive_desc prim_desc;
DnnlDesriptor desc(std::shared_ptr<softmax_forward::desc>(
new softmax_forward::desc(prop_kind::forward_scoring, key.inp0->getDnnlDesc(), key.axis)));
primitive_desc_iterator itpd = desc.createPrimitiveDescriptorIterator(engine);
dnnl::primitive_attr attr;
attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
primitive_desc_iterator itpd = desc.createPrimitiveDescriptorIterator(engine, attr);
while (itpd) {
impl_desc_type impl_type = parse_impl_name(itpd.impl_info_str());
@@ -180,9 +182,12 @@ void SoftMax::prepareParams() {
prim = result.first;
auto pd = (*prim).get_primitive_desc();
auto scratchpadMem = getScratchPadMem(pd);
auto src = getParentEdgesAtPort(0)[0]->getMemoryPtr()->GetPrimitive();
auto dst = getChildEdgesAtPort(0)[0]->getMemoryPtr()->GetPrimitive();
primArgs = {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}};
primArgs = {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}, {DNNL_ARG_SCRATCHPAD, scratchpadMem->GetPrimitive()}};
}
void SoftMax::executeDynamicImpl(dnnl::stream strm) {

@@ -1 +1 @@
Subproject commit fac2cfd709aa87f0c74b3bcfd3461e325cbc28d8
Subproject commit 6df930dab5ab0a7dfaea6100acd03b479e2fa0a8