[CPU] Enable scratchpad user mode for oneDNN based nodes (#12038)
This commit is contained in:
parent
de687650d2
commit
f34f35bfc7
@ -148,5 +148,13 @@ std::shared_ptr<DnnlBlockedMemoryDesc> DnnlExtensionUtils::makeUndefinedDesc(con
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * @brief Queries a primitive descriptor for one of its memory descriptors.
 *
 * Converts @p what to its C enum value, calls dnnl_primitive_desc_query_md()
 * and wraps the returned C descriptor with makeDescriptor().
 *
 * @param pd    primitive descriptor to query
 * @param what  kind of memory descriptor requested
 * @param idx   index forwarded to the query call
 * @return descriptor built from the queried dnnl_memory_desc_t
 * @throws via IE_THROW() when the query returns a null pointer
 */
DnnlMemoryDescPtr DnnlExtensionUtils::query_md(const const_dnnl_primitive_desc_t& pd, const dnnl::query& what, int idx) {
    const auto cQuery = dnnl::convert_to_c(what);
    const dnnl_memory_desc_t* const rawDesc = dnnl_primitive_desc_query_md(pd, cQuery, idx);

    // A null result means the query could not be answered; do not dereference it.
    if (rawDesc == nullptr) {
        IE_THROW() << "query_md failed for query=" << cQuery << " idx=" << idx << ".";
    }

    return DnnlExtensionUtils::makeDescriptor(*rawDesc);
}
|
||||
|
||||
} // namespace intel_cpu
|
||||
} // namespace ov
|
||||
|
@ -47,6 +47,8 @@ public:
|
||||
|
||||
static std::shared_ptr<DnnlBlockedMemoryDesc> makeUndefinedDesc(const dnnl::memory::desc &desc, const Shape& shape);
|
||||
static size_t getMemSizeForDnnlDesc(const dnnl::memory::desc& desc);
|
||||
|
||||
static std::shared_ptr<DnnlMemoryDesc> query_md(const const_dnnl_primitive_desc_t& pd, const dnnl::query& what, int idx = 0);
|
||||
};
|
||||
|
||||
} // namespace intel_cpu
|
||||
|
35
src/plugins/intel_cpu/src/dnnl_scratch_pad.h
Normal file
35
src/plugins/intel_cpu/src/dnnl_scratch_pad.h
Normal file
@ -0,0 +1,35 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "common/memory.hpp"
|
||||
#include "cpu_memory.h"
|
||||
#include "dnnl_extension_utils.h"
|
||||
|
||||
namespace ov {
|
||||
namespace intel_cpu {
|
||||
|
||||
class DnnlScratchPad {
|
||||
DnnlMemoryMngrPtr mgrPtr;
|
||||
dnnl::engine eng;
|
||||
|
||||
public:
|
||||
DnnlScratchPad(dnnl::engine eng) : eng(eng) {
|
||||
mgrPtr = std::make_shared<DnnlMemoryMngr>(std::unique_ptr<MemoryMngrWithReuse>(new MemoryMngrWithReuse()));
|
||||
}
|
||||
|
||||
MemoryPtr createScratchPadMem(const MemoryDescPtr& md) {
|
||||
auto mem = std::make_shared<Memory>(eng);
|
||||
mem->Create(md, mgrPtr);
|
||||
return mem;
|
||||
}
|
||||
};
|
||||
|
||||
using DnnlScratchPadPtr = std::shared_ptr<DnnlScratchPad>;
|
||||
|
||||
} // namespace intel_cpu
|
||||
} // namespace ov
|
@ -79,6 +79,7 @@ void Graph::CreateGraph(NET &net, const ExtensionManager::Ptr& extMgr,
|
||||
|
||||
rtParamsCache = std::make_shared<MultiCache>(config.rtCacheCapacity);
|
||||
sharedMutex = mutex;
|
||||
rtScratchPad = std::make_shared<DnnlScratchPad>(getEngine());
|
||||
|
||||
Replicate(net, extMgr);
|
||||
InitGraph();
|
||||
@ -98,6 +99,7 @@ void Graph::CreateGraph(const std::vector<NodePtr> &graphNodes,
|
||||
weightsCache = config.streamExecutorConfig._streams != 1 ? w_cache : nullptr;
|
||||
|
||||
rtParamsCache = std::make_shared<MultiCache>(config.rtCacheCapacity);
|
||||
rtScratchPad = std::make_shared<DnnlScratchPad>(getEngine());
|
||||
|
||||
this->_name = std::move(name);
|
||||
this->reuse_io_tensors = false;
|
||||
@ -158,6 +160,7 @@ void Graph::Replicate(const std::shared_ptr<const ov::Model> &subgraph, const Ex
|
||||
|
||||
node->setRuntimeCache(rtParamsCache);
|
||||
node->setSharedMutex(sharedMutex);
|
||||
node->setRuntimeScratchPad(rtScratchPad);
|
||||
|
||||
graphNodes.push_back(node);
|
||||
|
||||
@ -272,6 +275,7 @@ void Graph::Replicate(const CNNNetwork &network, const ExtensionManager::Ptr& ex
|
||||
|
||||
node->setRuntimeCache(rtParamsCache);
|
||||
node->setSharedMutex(sharedMutex);
|
||||
node->setRuntimeScratchPad(rtScratchPad);
|
||||
|
||||
graphNodes.push_back(node);
|
||||
|
||||
@ -1357,6 +1361,7 @@ bool Graph::InsertNode(NodePtr parent, NodePtr child, NodePtr node, int parentPo
|
||||
node->setQuantizedGraphFlag(true);
|
||||
}
|
||||
node->setRuntimeCache(rtParamsCache);
|
||||
node->setRuntimeScratchPad(rtScratchPad);
|
||||
|
||||
if (initNode) {
|
||||
node->getSupportedDescriptors();
|
||||
|
@ -11,6 +11,7 @@
|
||||
#include "node.h"
|
||||
#include "edge.h"
|
||||
#include "cache/multi_cache.h"
|
||||
#include "dnnl_scratch_pad.h"
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
@ -264,6 +265,7 @@ private:
|
||||
|
||||
MultiCachePtr rtParamsCache;
|
||||
std::shared_ptr<std::mutex> sharedMutex = nullptr;
|
||||
DnnlScratchPadPtr rtScratchPad;
|
||||
|
||||
void EnforceBF16();
|
||||
};
|
||||
|
@ -20,6 +20,7 @@
|
||||
#include "extension_mngr.h"
|
||||
#include "primitive.h"
|
||||
#include "weights_cache.hpp"
|
||||
#include "dnnl_scratch_pad.h"
|
||||
#include <openvino/itt.hpp>
|
||||
#include "utils/ngraph_utils.hpp"
|
||||
#include <ngraph/ops.hpp>
|
||||
@ -28,6 +29,7 @@
|
||||
#include <nodes/common/blocked_desc_creator.h>
|
||||
#include "cpu_types.h"
|
||||
#include "cpu_shape.h"
|
||||
#include "config.h"
|
||||
#include "nodes/node_config.h"
|
||||
#include "cache/multi_cache.h"
|
||||
|
||||
@ -573,6 +575,10 @@ public:
|
||||
rtParamsCache = cache;
|
||||
}
|
||||
|
||||
/**
 * @brief Stores the scratchpad that getScratchPadMem() later uses to create memory.
 * @param scratchPad shared scratchpad instance to keep in this node
 */
void setRuntimeScratchPad(DnnlScratchPadPtr scratchPad) {
    this->rtScratchPad = scratchPad;
}
|
||||
|
||||
void setSharedMutex(const std::shared_ptr<std::mutex>& mutex) {
|
||||
sharedMutex = mutex;
|
||||
}
|
||||
@ -747,6 +753,16 @@ protected:
|
||||
return rtParamsCache;
|
||||
}
|
||||
|
||||
/**
 * @brief Returns the scratchpad previously stored with setRuntimeScratchPad().
 * @return the stored shared pointer (null if none has been set)
 */
DnnlScratchPadPtr getRuntimeScratchPad() const {
    return this->rtScratchPad;
}
|
||||
|
||||
/**
 * @brief Creates scratchpad memory matching the given primitive descriptor.
 *
 * Queries @p pd for its scratchpad memory descriptor, asks the runtime
 * scratchpad to create memory for it, stores the result in this node's
 * scratchpadMem member and returns it.
 *
 * NOTE(review): the runtime scratchpad is dereferenced without a null check;
 * this relies on setRuntimeScratchPad() having been called first — confirm for
 * all node creation paths.
 *
 * @param pd primitive descriptor whose scratchpad requirements are queried
 * @return the newly created scratchpad memory
 */
MemoryPtr getScratchPadMem(const const_dnnl_primitive_desc_t& pd) {
    const auto scratchpadDesc = DnnlExtensionUtils::query_md(pd, dnnl::query::scratchpad_md);

    scratchpadMem = getRuntimeScratchPad()->createScratchPadMem(scratchpadDesc);
    return scratchpadMem;
}
|
||||
|
||||
std::vector<VectorDims> lastInputDims = {};
|
||||
|
||||
std::shared_ptr<IShapeInfer> shapeInference;
|
||||
@ -775,6 +791,8 @@ private:
|
||||
PerfCounters profiling;
|
||||
|
||||
MultiCachePtr rtParamsCache;
|
||||
DnnlScratchPadPtr rtScratchPad;
|
||||
MemoryPtr scratchpadMem;
|
||||
|
||||
bool isEdgesEmpty(const std::vector<EdgeWeakPtr>& edges) const;
|
||||
|
||||
|
@ -50,5 +50,9 @@ bool DnnlExecutor::needReordering() const {
|
||||
return !inputReorders.empty() || !outputReorders.empty();
|
||||
}
|
||||
|
||||
} // namespace intel_cpu
|
||||
/**
 * @brief Returns the executor's execPrim member by value.
 * @return copy of the stored Primitive
 */
Primitive DnnlExecutor::getExecPrim() const {
    return this->execPrim;
}
|
||||
|
||||
} // namespace intel_cpu
|
||||
} // namespace ov
|
||||
|
@ -29,6 +29,7 @@ class DnnlExecutor {
|
||||
void exec(std::unordered_map<int, dnnl::memory> primArgs, dnnl::stream strm);
|
||||
bool needReordering() const;
|
||||
virtual ~DnnlExecutor() = default;
|
||||
Primitive getExecPrim() const;
|
||||
|
||||
protected:
|
||||
DnnlExecutor() = default;
|
||||
|
@ -1316,6 +1316,7 @@ void Convolution::prepareParams() {
|
||||
else
|
||||
addZeroPoints(attr);
|
||||
setPostOps(attr, outMemoryDesc->getShape().getStaticDims(), preferLegacyPostOps, true);
|
||||
attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
|
||||
|
||||
return std::make_shared<dnnl::primitive_attr>(std::move(attr));
|
||||
};
|
||||
@ -1460,6 +1461,10 @@ void Convolution::prepareParams() {
|
||||
appendZeroPointsArgs();
|
||||
|
||||
Node::appendPostOpArgs(*pAttrLocal, primArgs, convPostOpsArgs[preferLegacyPostOps]);
|
||||
|
||||
auto pd = (*(execPtr->getExecPrim())).get_primitive_desc();
|
||||
auto scratchpadMem = getScratchPadMem(pd);
|
||||
primArgs[DNNL_ARG_SCRATCHPAD] = scratchpadMem->GetPrimitive();
|
||||
} else {
|
||||
IE_THROW() << "Primitive descriptor was not found for node " << getName() << ".";
|
||||
}
|
||||
@ -1489,6 +1494,7 @@ void Convolution::execute(dnnl::stream strm) {
|
||||
if (!execPtr) {
|
||||
IE_THROW() << "Can't execute Convolution node with name: " << getName() << ", because executor is not compiled";
|
||||
}
|
||||
|
||||
execPtr->exec(primArgs, strm);
|
||||
}
|
||||
|
||||
|
@ -562,6 +562,7 @@ void Deconvolution::execute(dnnl::stream strm) {
|
||||
if (!execPtr) {
|
||||
IE_THROW() << "Can't execute Deconvolution node with name: " << getName() << ", because executor is not compiled";
|
||||
}
|
||||
|
||||
execPtr->exec(primArgs, strm);
|
||||
|
||||
if (externOutShape) {
|
||||
@ -752,6 +753,7 @@ void Deconvolution::prepareParams() {
|
||||
} else {
|
||||
pAttrLocal = makePrimitiveAttr(dstMemPtr->getStaticDims());
|
||||
}
|
||||
(*pAttrLocal).set_scratchpad_mode(dnnl::scratchpad_mode::user);
|
||||
|
||||
DnnlMemoryDescCPtr wghDesc;
|
||||
if (isInt8) {
|
||||
@ -872,6 +874,10 @@ void Deconvolution::prepareParams() {
|
||||
primArgs[DNNL_ARG_DIFF_SRC] = dstMemPtr->GetPrimitive();
|
||||
}
|
||||
Node::appendPostOpArgs(*pAttrLocal, primArgs, postOpsArgs);
|
||||
|
||||
auto pd = (*(execPtr->getExecPrim())).get_primitive_desc();
|
||||
auto scratchpadMem = getScratchPadMem(pd);
|
||||
primArgs[DNNL_ARG_SCRATCHPAD] = scratchpadMem->GetPrimitive();
|
||||
} else {
|
||||
IE_THROW() << "Primitive descriptor was not found for node " << getName() << ".";
|
||||
}
|
||||
|
@ -226,6 +226,7 @@ void FullyConnected::prepareParams() {
|
||||
|
||||
AttrPtr attr = std::make_shared<dnnl::primitive_attr>();
|
||||
setPostOps(*attr, dstMemPtr->getStaticDims());
|
||||
(*attr).set_scratchpad_mode(dnnl::scratchpad_mode::user);
|
||||
|
||||
DnnlMemoryDescCPtr weightDesc = wghMemPtr->GetDescWithType<DnnlMemoryDesc>();
|
||||
DnnlMemoryDescCPtr biasDesc = nullptr;
|
||||
@ -311,6 +312,10 @@ void FullyConnected::prepareParams() {
|
||||
|
||||
appendPostOpArgs(*attr, primArgs, postOpsArgs);
|
||||
|
||||
auto pd = (*prim).get_primitive_desc();
|
||||
auto scratchpadMem = getScratchPadMem(pd);
|
||||
primArgs[DNNL_ARG_SCRATCHPAD] = scratchpadMem->GetPrimitive();
|
||||
|
||||
auto reshapeMemory = [this](int argType) {
|
||||
auto param = primArgs.find(argType);
|
||||
if (param != primArgs.end()) {
|
||||
|
@ -178,8 +178,10 @@ void Lrn::prepareParams() {
|
||||
DnnlDesriptor desc(std::shared_ptr<dnnl::lrn_forward::desc>(
|
||||
new dnnl::lrn_forward::desc(dnnl::prop_kind::forward_scoring, key.alg, key.inp0->getDnnlDesc(), key.size, key.alpha, key.beta, key.k)));
|
||||
|
||||
dnnl::primitive_attr attr;
|
||||
attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
|
||||
dnnl::lrn_forward::primitive_desc prim_desc;
|
||||
dnnl::primitive_desc_iterator itpd = desc.createPrimitiveDescriptorIterator(engine);
|
||||
dnnl::primitive_desc_iterator itpd = desc.createPrimitiveDescriptorIterator(engine, attr);
|
||||
while (static_cast<bool>(itpd)) {
|
||||
impl_desc_type impl_type = parse_impl_name(itpd.impl_info_str());
|
||||
if (impl_type == key.implType) {
|
||||
@ -199,9 +201,12 @@ void Lrn::prepareParams() {
|
||||
}
|
||||
prim = result.first;
|
||||
|
||||
auto pd = (*prim).get_primitive_desc();
|
||||
auto scratchpadMem = getScratchPadMem(pd);
|
||||
|
||||
auto src = srcMemPtr->GetPrimitive();
|
||||
auto dst = dstMemPtr->GetPrimitive();
|
||||
primArgs = { {DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst} };
|
||||
primArgs = { {DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}, {DNNL_ARG_SCRATCHPAD, scratchpadMem->GetPrimitive()} };
|
||||
}
|
||||
|
||||
bool Lrn::created() const {
|
||||
|
@ -196,6 +196,8 @@ Node::AttrPtr MatMul::initPrimitiveAttr(const VectorDims &dims) {
|
||||
|
||||
setPostOps(*attr, dims, true);
|
||||
|
||||
(*attr).set_scratchpad_mode(dnnl::scratchpad_mode::user);
|
||||
|
||||
return attr;
|
||||
}
|
||||
|
||||
@ -569,6 +571,10 @@ void MatMul::prepareParams() {
|
||||
|
||||
prim = result.first;
|
||||
|
||||
auto pd = (*prim).get_primitive_desc();
|
||||
auto scratchpadMem = getScratchPadMem(pd);
|
||||
|
||||
primArgs[DNNL_ARG_SCRATCHPAD] = scratchpadMem->GetPrimitive();
|
||||
primArgs[DNNL_ARG_SRC_0] = src0MemPtr->GetPrimitive();
|
||||
primArgs[DNNL_ARG_WEIGHTS_0] = src1MemPtr->GetPrimitive();
|
||||
primArgs[DNNL_ARG_DST] = dstMemPtr->GetPrimitive();
|
||||
@ -579,7 +585,7 @@ void MatMul::prepareParams() {
|
||||
}
|
||||
|
||||
void MatMul::executeDynamicImpl(dnnl::stream strm) {
|
||||
Node::execute(strm);
|
||||
execute(strm);
|
||||
}
|
||||
|
||||
const std::vector<impl_desc_type>& MatMul::getPrimitivesPriority() {
|
||||
@ -618,7 +624,6 @@ const std::vector<impl_desc_type>& MatMul::getPrimitivesPriority() {
|
||||
}
|
||||
return implPriorities;
|
||||
}
|
||||
|
||||
} // namespace node
|
||||
} // namespace intel_cpu
|
||||
} // namespace ov
|
||||
|
@ -387,9 +387,11 @@ void Pooling::prepareParams() {
|
||||
|
||||
prim = result.first;
|
||||
|
||||
auto pd = (*prim).get_primitive_desc();
|
||||
auto scratchpadMem = getScratchPadMem(pd);
|
||||
auto src = getParentEdgesAtPort(0)[0]->getMemoryPtr()->GetPrimitive();
|
||||
auto dst = getChildEdgesAtPort(0)[0]->getMemoryPtr()->GetPrimitive();
|
||||
primArgs = {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}};
|
||||
primArgs = {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}, {DNNL_ARG_SCRATCHPAD, scratchpadMem->GetPrimitive()}};
|
||||
|
||||
Node::appendPostOpArgs(*attr, primArgs, postOpsArgs);
|
||||
}
|
||||
@ -616,6 +618,8 @@ Node::AttrPtr Pooling::initPrimitiveAttr() {
|
||||
|
||||
setPostOps(*attr);
|
||||
|
||||
(*attr).set_scratchpad_mode(dnnl::scratchpad_mode::user);
|
||||
|
||||
return attr;
|
||||
}
|
||||
|
||||
@ -635,6 +639,6 @@ void Pooling::setPostOps(dnnl::primitive_attr &attr) {
|
||||
attr.set_post_ops(ops);
|
||||
}
|
||||
|
||||
} // namespace node
|
||||
} // namespace node
|
||||
} // namespace intel_cpu
|
||||
} // namespace ov
|
||||
|
@ -825,19 +825,21 @@ void RNN::prepareParams() {
|
||||
|
||||
auto builder = [this](const RNNKey& key) -> std::shared_ptr<dnnl::primitive> {
|
||||
fillDescs();
|
||||
dnnl::primitive_attr attr;
|
||||
attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
|
||||
|
||||
if (key.cellType == dnnl::algorithm::vanilla_rnn) {
|
||||
std::shared_ptr<vanilla_rnn_forward::desc> desc = descs[0];
|
||||
return std::make_shared<vanilla_rnn_forward>(vanilla_rnn_forward::primitive_desc(*desc, getEngine()));
|
||||
return std::make_shared<vanilla_rnn_forward>(vanilla_rnn_forward::primitive_desc(*desc, attr, getEngine()));
|
||||
} else if (key.cellType == dnnl::algorithm::vanilla_gru) {
|
||||
std::shared_ptr<gru_forward::desc> desc = descs[0];
|
||||
return std::make_shared<gru_forward>(gru_forward::primitive_desc(*desc, getEngine()));
|
||||
return std::make_shared<gru_forward>(gru_forward::primitive_desc(*desc, attr, getEngine()));
|
||||
} else if (key.cellType == dnnl::algorithm::lbr_gru) {
|
||||
std::shared_ptr<lbr_gru_forward::desc> desc = descs[0];
|
||||
return std::make_shared<lbr_gru_forward>(lbr_gru_forward::primitive_desc(*desc, getEngine()));
|
||||
return std::make_shared<lbr_gru_forward>(lbr_gru_forward::primitive_desc(*desc, attr, getEngine()));
|
||||
} else if (key.cellType == dnnl::algorithm::vanilla_lstm) {
|
||||
std::shared_ptr<lstm_forward::desc> desc = descs[0];
|
||||
return std::make_shared<lstm_forward>(lstm_forward::primitive_desc(*desc, getEngine()));
|
||||
return std::make_shared<lstm_forward>(lstm_forward::primitive_desc(*desc, attr, getEngine()));
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
@ -852,6 +854,9 @@ void RNN::prepareParams() {
|
||||
|
||||
prim = result.first;
|
||||
|
||||
auto pd = (*prim).get_primitive_desc();
|
||||
scratchpadMem = getScratchPadMem(pd);
|
||||
|
||||
if (!wasMemoryPrepared || wFormatWasChanged) {
|
||||
auto pd = (*prim).get_primitive_desc();
|
||||
auto query_weights_md = [&](int idx = 0) -> dnnl::memory::desc {
|
||||
@ -896,6 +901,7 @@ void RNN::execute(dnnl::stream strm) {
|
||||
{DNNL_ARG_WEIGHTS_ITER, wgh_stat_mem->GetPrimitive()},
|
||||
{DNNL_ARG_BIAS, wgh_bias_mem->GetPrimitive()},
|
||||
{DNNL_ARG_DST_LAYER, dst_data_mem->GetPrimitive()},
|
||||
{DNNL_ARG_SCRATCHPAD, scratchpadMem->GetPrimitive()}
|
||||
};
|
||||
|
||||
int state_i_tags[] {DNNL_ARG_SRC_ITER, DNNL_ARG_SRC_ITER_C};
|
||||
|
@ -120,6 +120,7 @@ private:
|
||||
static constexpr size_t batchDimDummyValue = 64lu;
|
||||
|
||||
bool wasMemoryPrepared = false;
|
||||
MemoryPtr scratchpadMem;
|
||||
};
|
||||
|
||||
} // namespace node
|
||||
|
@ -152,7 +152,9 @@ void SoftMax::prepareParams() {
|
||||
softmax_forward::primitive_desc prim_desc;
|
||||
DnnlDesriptor desc(std::shared_ptr<softmax_forward::desc>(
|
||||
new softmax_forward::desc(prop_kind::forward_scoring, key.inp0->getDnnlDesc(), key.axis)));
|
||||
primitive_desc_iterator itpd = desc.createPrimitiveDescriptorIterator(engine);
|
||||
dnnl::primitive_attr attr;
|
||||
attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
|
||||
primitive_desc_iterator itpd = desc.createPrimitiveDescriptorIterator(engine, attr);
|
||||
|
||||
while (itpd) {
|
||||
impl_desc_type impl_type = parse_impl_name(itpd.impl_info_str());
|
||||
@ -180,9 +182,12 @@ void SoftMax::prepareParams() {
|
||||
|
||||
prim = result.first;
|
||||
|
||||
auto pd = (*prim).get_primitive_desc();
|
||||
auto scratchpadMem = getScratchPadMem(pd);
|
||||
|
||||
auto src = getParentEdgesAtPort(0)[0]->getMemoryPtr()->GetPrimitive();
|
||||
auto dst = getChildEdgesAtPort(0)[0]->getMemoryPtr()->GetPrimitive();
|
||||
primArgs = {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}};
|
||||
primArgs = {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}, {DNNL_ARG_SCRATCHPAD, scratchpadMem->GetPrimitive()}};
|
||||
}
|
||||
|
||||
void SoftMax::executeDynamicImpl(dnnl::stream strm) {
|
||||
|
2
src/plugins/intel_cpu/thirdparty/onednn
vendored
2
src/plugins/intel_cpu/thirdparty/onednn
vendored
@ -1 +1 @@
|
||||
Subproject commit fac2cfd709aa87f0c74b3bcfd3461e325cbc28d8
|
||||
Subproject commit 6df930dab5ab0a7dfaea6100acd03b479e2fa0a8
|
Loading…
Reference in New Issue
Block a user