[CPU] Optimize reorder inside tensor_iterator (#14689)

Tingqian Li, 2023-02-13 18:26:40 +08:00, committed by GitHub
parent b8a7b3bb43
commit b9a1b45a82
27 changed files with 209 additions and 364 deletions
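Summary of the change: the custom Primitive wrapper (a shared_ptr around dnnl::primitive) is removed, and nodes now hold oneDNN primitives and memories by value, since the dnnl:: handle types are already reference-counted and report emptiness through their bool conversion. Reorder creation is factored into a cache-aware helper, getReorderPrim(), and the TensorIterator port mappers use it so that the reorders executed inside the loop body are looked up in the parameter cache instead of being rebuilt every time the iterator is (re)prepared.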

View File

@@ -52,14 +52,14 @@ void Memory::Create(const dnnl::memory::desc& desc, const void *data, bool pads_
     // ========================
     // Equivalent of constructor memory(const primitive_desc &desc, void *hdl)
     // but with ability to skipp pads zeroing.
-    prim.reset(new memory(desc, eng, DNNL_MEMORY_NONE));
+    prim = memory(desc, eng, DNNL_MEMORY_NONE);
     //
     // ========================

     if (data != nullptr) {
         if (pads_zeroing)
-            prim->set_data_handle(const_cast<void*>(data));
+            prim.set_data_handle(const_cast<void*>(data));
         else
-            prim->set_data_handle_no_pads_proc(const_cast<void*>(data));
+            prim.set_data_handle_no_pads_proc(const_cast<void*>(data));
     }
 }

@@ -97,13 +97,13 @@ void Memory::SetData(const Memory& src, bool ftz) const {
     if (ftz
         && src.GetDataType() == memory::data_type::f32
-        && prim->get_desc().data.format_kind != dnnl_format_kind_wino
+        && prim.get_desc().data.format_kind != dnnl_format_kind_wino
         // WA: to avoid zero filling auxiliary information
-        && prim->get_desc().data.format_kind != dnnl_format_kind_rnn_packed
+        && prim.get_desc().data.format_kind != dnnl_format_kind_rnn_packed
         && GetDataType() != memory::data_type::bf16) {
         // Internal blobs haven't strides yet.
         auto *memData = static_cast<float *>(GetData());
-        memData += prim->get_desc().data.offset0;
+        memData += prim.get_desc().data.offset0;
         setSubnormalsToZero(memData, GetSize() / sizeof(float));
     }
 }

@@ -116,7 +116,7 @@ void Memory::FillZero() {
 void *Memory::GetPtr() const {
     auto ptr = static_cast<uint8_t*>(GetData());
-    const dnnl_memory_desc_t md = prim->get_desc().data;
+    const dnnl_memory_desc_t md = prim.get_desc().data;
     dnnl::impl::memory_desc_wrapper wrapper(md);
     ptr += wrapper.offset0() * wrapper.data_type_size();
     return ptr;

@@ -144,12 +144,12 @@ void Memory::setDataHandle(void *data) {
     size_t maxMemSize = pMemDesc->hasDefinedMaxSize() ? pMemDesc->getMaxMemSize() : 0;
     mgrHandle->setExtBuff(data, maxMemSize);
-    prim->set_data_handle(mgrHandle->getRawPtr()); // for pads zeroing, to preserve dnnl::memory::set_data_handle behaviour
+    prim.set_data_handle(mgrHandle->getRawPtr()); // for pads zeroing, to preserve dnnl::memory::set_data_handle behaviour
 }

 void Memory::update() {
     if (isAllocated()) {
-        prim->set_data_handle_no_pads_proc(mgrHandle->getRawPtr());
+        prim.set_data_handle_no_pads_proc(mgrHandle->getRawPtr());
     }
 }

View File

@@ -166,14 +166,14 @@ public:
     dnnl::memory GetPrimitive() const {
         if (isAllocated()) {
-            return *prim;
+            return prim;
         } else {
             IE_THROW() << "Can not perform GetPrimitive call to the not allocated memory";
         }
     }

     bool isAllocated() const noexcept {
-        return prim != nullptr;
+        return static_cast<bool>(prim);
     }

     /**
@@ -263,7 +263,7 @@ private:
 private:
     MemoryDescPtr pMemDesc;
-    std::shared_ptr<dnnl::memory> prim;
+    dnnl::memory prim;
     dnnl::engine eng;
     DnnlMemMngrHandle mgrHandle;
 };
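
Note on the change above: dnnl::memory, like all oneDNN C++ handle types, is a thin reference-counted wrapper over the underlying C object, so holding it by value costs the same as the shared_ptr it replaces, and a default-constructed handle is empty and converts to false. A minimal standalone sketch of the semantics isAllocated() now relies on (illustrative, not code from the patch):

    #include <cassert>
    #include "dnnl.hpp"

    int main() {
        dnnl::engine eng(dnnl::engine::kind::cpu, 0);
        dnnl::memory::desc desc({2, 2}, dnnl::memory::data_type::f32,
                                dnnl::memory::format_tag::ab);

        dnnl::memory m;                 // default-constructed: empty handle
        assert(!static_cast<bool>(m));  // the check the new isAllocated() performs

        // Attach a handle without a data buffer, as Memory::Create() does.
        m = dnnl::memory(desc, eng, DNNL_MEMORY_NONE);
        assert(static_cast<bool>(m));   // non-empty even though no buffer is set

        dnnl::memory copy = m;          // copying only bumps a refcount
        assert(copy.get() == m.get());  // both refer to the same C object
        return 0;
    }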

View File

@@ -880,7 +880,7 @@ void Graph::CreatePrimitives() {
         node->createPrimitive();
 #ifdef CPU_DEBUG_CAPS
         if (node->prim) {
-            auto pd_c = (*node->prim).get_primitive_desc();
+            auto pd_c = node->prim.get_primitive_desc();
             auto* pd = reinterpret_cast<const dnnl_primitive_desc*>(pd_c);
             DEBUG_LOG("verbose##", node->getName(), "##", pd->info(), "\n");
         }

View File

@@ -533,7 +533,7 @@ std::vector<memory::format_tag> Node::getAvailableFormatsForDims(const Shape &di
 void Node::execute(dnnl::stream strm) {
     if (prim) {
-        (*prim).execute(strm, primArgs);
+        prim.execute(strm, primArgs);
     }
 }

@@ -555,7 +555,7 @@ void Node::updateDynamicParams() {
     prepareParams();
 #ifdef CPU_DEBUG_CAPS
     if (prim) {
-        auto pd_c = (*prim).get_primitive_desc();
+        auto pd_c = prim.get_primitive_desc();
         auto* pd = reinterpret_cast<const dnnl_primitive_desc*>(pd_c);
         DEBUG_LOG("verbose##", getName(), "##", pd->info(), "\n");
     }

View File

@@ -18,7 +18,6 @@
 #include "onednn/dnnl.h"
 #include "onednn/iml_type_mapper.h"
 #include "extension_mngr.h"
-#include "primitive.h"
 #include "weights_cache.hpp"
 #include "dnnl_scratch_pad.h"
 #include <openvino/itt.hpp>

@@ -371,54 +370,6 @@ public:
      */
     virtual void init() {}

-    template <class PD, class D, typename FPD = bool>
-    PD createPrimitiveDescriptor(const dnnl::primitive_attr &attr = dnnl::primitive_attr()) {
-        auto descsCompatible = [](const std::vector<MemoryDescPtr>& srcDescs,
-                                  const std::vector<PortConfig>& selectedDescs) {
-            if (srcDescs.empty() && selectedDescs.empty())
-                return true;
-            if (srcDescs.empty() || selectedDescs.empty())
-                return false;
-            for (size_t i = 0; i < srcDescs.size() && i < selectedDescs.size(); i++) {
-                if (!srcDescs[i]->isCompatible(*selectedDescs[i].getMemDesc()))
-                    return false;
-            }
-            return true;
-        };
-
-        const NodeDesc *selected_pd = getSelectedPrimitiveDescriptor();
-        if (selected_pd == nullptr)
-            IE_THROW() << "Preferable primitive descriptor is not set for node " << getName() << ".";
-
-        for (const auto& desc : descs) {
-            auto itpd = desc.createPrimitiveDescriptorIterator(engine, attr);
-            while (static_cast<bool>(itpd)) {
-                std::vector<MemoryDescPtr> srcDescs;
-                for (size_t i = 0; i < descInputNumbers(desc); i++)
-                    srcDescs.push_back(getSrcMemDesc(itpd, i));
-
-                std::vector<MemoryDescPtr> dstDescs;
-                for (size_t i = 0; i < descOutputNumbers(desc); i++)
-                    dstDescs.push_back(getDstMemDesc(itpd, i));
-
-                impl_desc_type impl_type = parse_impl_name(itpd.impl_info_str());
-
-                if (impl_type == selected_pd->getImplementationType() &&
-                    descsCompatible(srcDescs, selected_pd->getConfig().inConfs) &&
-                    descsCompatible(dstDescs, selected_pd->getConfig().outConfs)) {
-                    prepareMemory(itpd);
-                    PD prim_desc = createPd<PD, D, FPD>(desc);
-                    return {itpd.get()};
-                }
-                if (!itpd.next_impl())
-                    break;
-            }
-        }
-
-        IE_THROW() << "Primitive descriptor was not found for node " << getName() << " with type " << NameFromType(getType()) << ".";
-    }
-
     int getExecIndex() const {
         return execIndex;
     }

@@ -627,7 +578,7 @@ protected:
     std::vector<NodeDesc> supportedPrimitiveDescriptors;
     std::unordered_map<int, dnnl::memory> primArgs;
     std::unordered_map<int, MemoryPtr> postOpsArgs;
-    Primitive prim;
+    dnnl::primitive prim;
     std::vector<DnnlDesriptor> descs;

     const GraphContext::CPtr context;

@@ -733,21 +684,6 @@ private:
     bool isEdgesEmpty(const std::vector<EdgeWeakPtr>& edges) const;

-    template <class PD, class D, typename FPD>
-    typename std::enable_if<!std::is_same<FPD, bool>::value, PD>::type
-    createPd(DnnlDesriptor desc) {
-        std::shared_ptr<D> selected_desc_ptr = desc;
-        std::shared_ptr<FPD> backward_prim_desc_ptr = desc;
-        return PD(*selected_desc_ptr, engine, *backward_prim_desc_ptr);
-    }
-
-    template <class PD, class D, typename FPD>
-    typename std::enable_if<std::is_same<FPD, bool>::value, PD>::type
-    createPd(DnnlDesriptor desc) {
-        std::shared_ptr<D> selected_desc_ptr = desc;
-        return PD(*selected_desc_ptr, engine);
-    }
-
     enum LOOK { LOOK_UP = 1, LOOK_DOWN = 2 };
     ConstantType checkConstant(LOOK look, std::vector<NodePtr>& checkNodes);

View File

@@ -40,7 +40,7 @@ void DnnlExecutor::exec(std::unordered_map<int, dnnl::memory> primArgs, dnnl::st
             IE_THROW() << "DnnlExecutor has reorder for output " << outReorder.first << ", but doesn't have destination memory";
         }
     }
-    (*execPrim).execute(strm, primArgs);
+    execPrim.execute(strm, primArgs);
     for (auto &outReorder : outputReorders) {
         outReorder.second.exec(primArgs[outReorder.first], outputMem[outReorder.first], strm);
     }

@@ -50,12 +50,12 @@ bool DnnlExecutor::needReordering() const {
     return !inputReorders.empty() || !outputReorders.empty();
 }

-Primitive DnnlExecutor::getExecPrim() const {
+dnnl::primitive DnnlExecutor::getExecPrim() const {
     return execPrim;
 }

 const_dnnl_primitive_desc_t DnnlExecutor::getPrimitiveDesc() const {
-    return (*execPrim).get_primitive_desc();
+    return execPrim.get_primitive_desc();
 }

 dnnl::memory::desc DnnlExecutor::getSrcDesc() const {

View File

@@ -5,7 +5,6 @@
 #pragma once

 #include <cpu_memory.h>
-#include <primitive.h>
 #include <onednn/iml_type_mapper.h>

 namespace ov {

@@ -30,7 +29,7 @@ class DnnlExecutor {
         void exec(std::unordered_map<int, dnnl::memory> primArgs, dnnl::stream strm);
         bool needReordering() const;
         virtual ~DnnlExecutor() = default;
-        Primitive getExecPrim() const;
+        dnnl::primitive getExecPrim() const;
         const_dnnl_primitive_desc_t getPrimitiveDesc() const;
         dnnl::memory::desc getSrcDesc() const;
         dnnl::memory::desc getWeightDesc() const;

@@ -39,7 +38,7 @@ class DnnlExecutor {
     protected:
         DnnlExecutor() = default;
-        Primitive execPrim;
+        dnnl::primitive execPrim;
         // key is the port number for the primitive that needs memory reordering
         std::unordered_map<int, IntermReorder> inputReorders;
         std::unordered_map<int, IntermReorder> outputReorders;

View File

@ -0,0 +1,68 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "reorder_prim.h"
#include <dnnl_extension_utils.h>
#include <dnnl_types.h>
#include <algorithm>
#include <common/primitive_hashing_utils.hpp>
#include <cpu/x64/cpu_isa_traits.hpp>
#include <memory>
#include <string>
#include "utils/general_utils.h"
namespace ov {
namespace intel_cpu {
struct ReorderKey {
dnnl::memory::desc src;
dnnl::memory::desc dest;
size_t hash() const;
bool operator==(const ReorderKey& rhs) const;
};
size_t ReorderKey::hash() const {
using namespace dnnl::impl;
using namespace dnnl::impl::primitive_hashing;
size_t seed = 0;
seed = hash_combine(seed, get_md_hash(src.data));
seed = hash_combine(seed, get_md_hash(dest.data));
return seed;
}
bool ReorderKey::operator==(const ReorderKey& rhs) const {
bool retVal = true;
retVal = src == rhs.src && dest == rhs.dest;
return retVal;
}
dnnl::reorder getReorderPrim(MultiCachePtr cache,
const dnnl::engine& engine,
const dnnl::memory::desc& src,
const dnnl::memory::desc& dest) {
auto builder = [&engine](const ReorderKey& key) {
dnnl::primitive_attr attr;
DEBUG_LOG(key.src, "->", key.dest);
dnnl::reorder::primitive_desc pd = dnnl::reorder::primitive_desc(engine, key.src, engine, key.dest, attr, true);
if (!pd) {
return dnnl::reorder();
}
return dnnl::reorder(pd);
};
ReorderKey key = {src, dest};
if (cache) {
auto result = cache->getOrCreate(key, builder);
return result.first;
}
return builder(key);
}
} // namespace intel_cpu
} // namespace ov

View File

@ -0,0 +1,21 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <ie_common.h>
#include <node.h>
#include <memory>
namespace ov {
namespace intel_cpu {
dnnl::reorder getReorderPrim(MultiCachePtr cache,
const dnnl::engine& engine,
const dnnl::memory::desc& src,
const dnnl::memory::desc& dest);
} // namespace intel_cpu
} // namespace ov
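
A minimal sketch of calling the new helper directly (illustrative descriptors; as the implementation above shows, passing an empty MultiCachePtr makes getReorderPrim() build the primitive without caching):

    // Plain f32 nchw -> nhwc reorder obtained through the helper.
    dnnl::engine eng(dnnl::engine::kind::cpu, 0);
    dnnl::memory::desc src({1, 8, 4, 4}, dnnl::memory::data_type::f32,
                           dnnl::memory::format_tag::nchw);
    dnnl::memory::desc dst({1, 8, 4, 4}, dnnl::memory::data_type::f32,
                           dnnl::memory::format_tag::nhwc);

    dnnl::reorder r = getReorderPrim(nullptr, eng, src, dst);  // no cache: built directly
    if (r) {  // an empty handle means oneDNN has no reorder for this descriptor pair
        dnnl::memory src_mem(src, eng), dst_mem(dst, eng);
        dnnl::stream strm(eng);
        r.execute(strm, {{DNNL_ARG_FROM, src_mem}, {DNNL_ARG_TO, dst_mem}});
        strm.wait();
    }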

View File

@@ -426,7 +426,7 @@ void Concat::prepareParams() {
         }

         auto primitive_desc = concat::primitive_desc(desc, static_cast<int>(axis), srcs_d, getEngine());
-        prim.reset(new concat(primitive_desc));
+        prim = concat(primitive_desc);
     }
 }

@@ -557,7 +557,7 @@ void Concat::execute(dnnl::stream strm) {
             mem_ags[DNNL_ARG_MULTIPLE_SRC + nonZeroInShapes] = srcMem.GetPrimitive();
             nonZeroInShapes++;
         }
-        (*prim).execute(strm, mem_ags);
+        prim.execute(strm, mem_ags);
     }
 }

View File

@@ -1465,7 +1465,7 @@ Convolution::ConvolutionExecutor::ConvolutionExecutor(const dnnl::convolution_fo
                                                       const dnnl::memory::desc& weightMemDesc,
                                                       const dnnl::memory::desc& outMemDesc,
                                                       const dnnl::engine& engine) {
-    execPrim.reset(new dnnl::convolution_forward(pd));
+    execPrim = dnnl::convolution_forward(pd);

     if (inMemDesc != pd.src_desc()) {
         inputReorders.insert({DNNL_ARG_SRC, IntermReorder(inMemDesc, pd.src_desc(), engine)});

View File

@@ -1045,7 +1045,7 @@ Deconvolution::DeconvExecutorDefault::DeconvExecutorDefault(const dnnl::convolut
                                                             const dnnl::memory::desc& weightMemDesc,
                                                             const dnnl::memory::desc& outMemDesc,
                                                             const dnnl::engine& engine) {
-    execPrim.reset(new dnnl::convolution_backward_data(pd));
+    execPrim = dnnl::convolution_backward_data(pd);

     if (inMemDesc != pd.diff_dst_desc()) {
         inputReorders.insert({DNNL_ARG_DIFF_DST, IntermReorder(inMemDesc, pd.diff_dst_desc(), engine)});

@@ -1065,7 +1065,7 @@ Deconvolution::DeconvExecutorInt8::DeconvExecutorInt8(const dnnl::deconvolution_
                                                       const dnnl::memory::desc& weightMemDesc,
                                                       const dnnl::memory::desc& outMemDesc,
                                                       const dnnl::engine& engine) {
-    execPrim.reset(new dnnl::deconvolution_forward(pd));
+    execPrim = dnnl::deconvolution_forward(pd);

     if (inMemDesc != pd.src_desc()) {
         inputReorders.insert({DNNL_ARG_SRC, IntermReorder(inMemDesc, pd.src_desc(), engine)});

View File

@@ -892,11 +892,11 @@ bool FullyConnected::canBeExecutedInConv1x1() const {
 }

 FullyConnected::ExecutorInnerProduct::ExecutorInnerProduct(const dnnl::inner_product_forward::primitive_desc& pd) {
-    execPrim.reset(new dnnl::inner_product_forward(pd));
+    execPrim = dnnl::inner_product_forward(pd);
 }

 FullyConnected::ExecutorConv1x1::ExecutorConv1x1(const dnnl::convolution_forward::primitive_desc& pd) {
-    execPrim.reset(new dnnl::convolution_forward(pd));
+    execPrim = dnnl::convolution_forward(pd);
 }

 MemoryPtr FullyConnected::prepareWeightMemory(DnnlMemoryDescPtr weightDesc) {

View File

@@ -241,7 +241,7 @@ void Interaction::execRef(dnnl::stream strm) {
     float* scales = fqScales.empty() ? nullptr : fqScales.data();
     for (int64_t start = 0; start < batchSize; start++) {
         cat(reinterpret_cast<uint8_t*>(inputMemPtr->GetPtr()), inputPtrs, featureSizes, start, dataPrecision.size());
-        (*prim).execute(strm, mem_ags);
+        prim.execute(strm, mem_ags);
         flat_triangle(reinterpret_cast<const uint8_t*>(outputMemPtr->GetPtr()),
                       reinterpret_cast<uint8_t*>(flatMemPtr->GetPtr()),
                       inputSizes,

@@ -296,7 +296,7 @@ void Interaction::prepareParams() {
     auto matmul_d = matmul::desc(src_md, weights_md, dst_md);
     primitive_attr matmul_attr;
     auto matmul_pd = matmul::primitive_desc(matmul_d, matmul_attr, getEngine());
-    prim.reset(new matmul(matmul_pd));
+    prim = matmul(matmul_pd);
     featureSizes.assign(inputSizes, featureSize);
     auto initMemoryPtr = [&](const InferenceEngine::Precision &prc, const intel_cpu::Shape& shape,
                              MemoryPtr& ptr) {

View File

@@ -175,7 +175,7 @@ void Lrn::prepareParams() {
     LrnKey key = {inpDesc, selected_pd->getImplementationType(), alg, size, k, alpha, beta};
     auto engine = getEngine();

-    auto builder = [&engine](const LrnKey& key) -> std::shared_ptr<dnnl::primitive> {
+    auto builder = [&engine](const LrnKey& key) -> dnnl::primitive {
         DnnlDesriptor desc(std::shared_ptr<dnnl::lrn_forward::desc>(
             new dnnl::lrn_forward::desc(dnnl::prop_kind::forward_scoring, key.alg, key.inp0->getDnnlDesc(), key.size, key.alpha, key.beta, key.k)));

@@ -190,9 +190,9 @@ void Lrn::prepareParams() {
                 break;
             }
             if (!itpd.next_impl())
-                return nullptr;
+                return dnnl::lrn_forward();
         }
-        return std::make_shared<dnnl::lrn_forward>(prim_desc);
+        return dnnl::lrn_forward(prim_desc);
     };

     auto cache = context->getParamsCache();

@@ -202,7 +202,7 @@ void Lrn::prepareParams() {
     }

     prim = result.first;

-    auto pd = (*prim).get_primitive_desc();
+    auto pd = prim.get_primitive_desc();
     auto scratchpadMem = getScratchPadMem(pd);

     auto src = srcMemPtr->GetPrimitive();
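
The builder lambdas here and in MatMul, Pooling, SoftMax, and RNN below all switch to the same convention: on failure they return a default-constructed (empty) primitive instead of a null shared_ptr, and the caller tests the handle itself. A sketch of the calling side, assuming the getOrCreate() cache API these nodes already use (the error text is illustrative):

    auto result = cache->getOrCreate(key, builder);  // builder returns the primitive by value
    if (!result.first)                               // empty handle: no implementation matched
        IE_THROW() << "Primitive descriptor was not found for node " << getName() << ".";
    prim = result.first;                             // cheap copy of a refcounted handle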

View File

@@ -534,7 +534,7 @@ void MatMul::prepareParams() {
     auto engine = getEngine();

-    auto builder = [&engine](const MatMulKey& key) -> std::shared_ptr<dnnl::primitive> {
+    auto builder = [&engine](const MatMulKey& key) -> dnnl::primitive {
         std::shared_ptr<dnnl::matmul::desc> matmul_desc;

         if (key.bias) {

@@ -560,9 +560,9 @@ void MatMul::prepareParams() {
                 break;
             }
             if (!itpd.next_impl())
-                return nullptr;
+                return matmul();
         }
-        return std::make_shared<matmul>(prim_desc);
+        return matmul(prim_desc);
     };

     auto cache = context->getParamsCache();

@@ -574,7 +574,7 @@ void MatMul::prepareParams() {
     prim = result.first;

-    auto pd = (*prim).get_primitive_desc();
+    auto pd = prim.get_primitive_desc();
     auto scratchpadMem = getScratchPadMem(pd);
     primArgs[DNNL_ARG_SCRATCHPAD] = scratchpadMem->GetPrimitive();

View File

@@ -352,7 +352,7 @@ void Pooling::prepareParams() {
                       alg,
                       selected_pd->getImplementationType()};
     auto engine = getEngine();
-    auto builder = [&engine](const PoolingKey& key) -> std::shared_ptr<dnnl::primitive> {
+    auto builder = [&engine](const PoolingKey& key) -> dnnl::primitive {
         auto desc_ptr = createDescriptorHelper(key.inp->getDnnlDesc(),
                                                key.out->getDnnlDesc(),
                                                key.alg,

@@ -375,7 +375,7 @@ void Pooling::prepareParams() {
             if (!itpd.next_impl())
                 break;
         }
-        return std::make_shared<pooling_v2_forward>(prim_desc);
+        return pooling_v2_forward(prim_desc);
     };

     auto cache = context->getParamsCache();

@@ -387,7 +387,7 @@ void Pooling::prepareParams() {
     prim = result.first;

-    auto pd = (*prim).get_primitive_desc();
+    auto pd = prim.get_primitive_desc();
     auto scratchpadMem = getScratchPadMem(pd);
     auto src = getParentEdgesAtPort(0)[0]->getMemoryPtr()->GetPrimitive();
     auto dst = getChildEdgesAtPort(0)[0]->getMemoryPtr()->GetPrimitive();

View File

@@ -2008,7 +2008,7 @@ void Reduce::execute(dnnl::stream strm) {
     if (jit_mode) {
         if (is_hybrid_layout) {
-            dst_data = reinterpret_cast<uint8_t *>(prc_mem->get_data_handle());
+            dst_data = reinterpret_cast<uint8_t *>(prc_mem.get_data_handle());
         }
         reduce_type(src_data, dst_data, dst_size);
     } else {

@@ -2639,7 +2639,7 @@ inline void Reduce::create_working_memory() {
                              : (mayiuse(cpu::x64::avx512_core) ? memory::format_tag::nCdhw16c : memory::format_tag::nCdhw8c));
     auto prc_dims = rank == 4 ? std::vector<size_t>{OB, OC, OH, OW} : std::vector<size_t>{OB, OC, OD, OH, OW};
     auto desc = dnnl::memory::desc(DnnlExtensionUtils::convertToDnnlDims(prc_dims), DnnlExtensionUtils::IEPrecisionToDataType(output_prec), format);
-    prc_mem = std::make_shared<dnnl::memory>(desc, getEngine());
+    prc_mem = dnnl::memory(desc, getEngine());
     dst_size = desc.get_size();
 }

View File

@@ -157,7 +157,7 @@ private:
     std::vector<const void*> postOpsDataPtrs;

-    std::shared_ptr<dnnl::memory> prc_mem;
+    dnnl::memory prc_mem;
     std::vector<uint8_t> vec_reduceDH_prc;

    std::shared_ptr<jit_uni_reduce_kernel> reduce_kernel;

View File

@@ -13,6 +13,7 @@
 #include <cpu/x64/cpu_isa_traits.hpp>
 #include "nodes/common/cpu_memcpy.h"
 #include "nodes/common/cpu_convert.h"
+#include "nodes/common/reorder_prim.h"
 #include "convert.h"
 #include <common/primitive_hashing_utils.hpp>
 #include <utils/shape_inference/shape_inference_pass_through.hpp>

@@ -23,33 +24,6 @@ using namespace InferenceEngine;
 namespace ov {
 namespace intel_cpu {
 namespace node {
-namespace {
-struct ReorderKey {
-    dnnl::memory::desc src;
-    dnnl::memory::desc dest;
-    size_t hash() const;
-    bool operator==(const ReorderKey& rhs) const;
-};
-
-size_t ReorderKey::hash() const {
-    using namespace dnnl::impl;
-    using namespace dnnl::impl::primitive_hashing;
-
-    size_t seed = 0;
-    seed = hash_combine(seed, get_md_hash(src.data));
-    seed = hash_combine(seed, get_md_hash(dest.data));
-    return seed;
-}
-
-bool ReorderKey::operator==(const ReorderKey& rhs) const {
-    bool retVal = true;
-    retVal = src == rhs.src && dest == rhs.dest;
-    return retVal;
-}
-}  // namespace

 bool Reorder::isExecutable() const {
     return Node::isExecutable() && !isOptimized;

@@ -232,25 +206,8 @@ void Reorder::createReorderPrimitive(const dnnl::memory::desc& srcDesc,
         src_desc = src_desc.permute_axes(src_permutation);
     }
-    impl_desc_type impl_type = selectedPD->getImplementationType();
-    ReorderKey key = {src_desc, dst_blocked->GetPrimitive().get_desc()};
-    auto builder = [&engine, &impl_type](const ReorderKey& key) -> std::shared_ptr<dnnl::primitive> {
-        dnnl::primitive_attr attr;
-        DEBUG_LOG(key.src, "->", key.dest);
-        reorder::primitive_desc pd = dnnl::reorder::primitive_desc(engine, key.src, engine, key.dest, attr, true);
-        if (!pd)
-            return nullptr;
-        auto info = pd.impl_info_str();
-        impl_type = parse_impl_name(info);
-        return std::make_shared<dnnl::reorder>(pd);
-    };
-    auto cache = context->getParamsCache();
-    std::pair<std::shared_ptr<dnnl::primitive>, CacheEntryBase::LookUpStatus> result{
-        nullptr,
-        CacheEntryBase::LookUpStatus::Miss};
+    auto dst_desc = dst_blocked->GetPrimitive().get_desc();

     // TODO: We should keep shape consistency for const and expected shape for node.
     //       If it requires reshape operation it should explicitly injected into graph.
     //

@@ -272,17 +229,18 @@ void Reorder::createReorderPrimitive(const dnnl::memory::desc& srcDesc,
                                           newFormat);
         src_blocked->Create(DnnlExtensionUtils::makeDescriptor(newDesc), srcPtr, false);

-        key.src = src_blocked->GetPrimitive().get_desc();
-        result = cache->getOrCreate(key, builder);
-    } else {
-        result = cache->getOrCreate(key, builder);
+        src_desc = src_blocked->GetPrimitive().get_desc();
     }

-    if (!result.first) {
+    auto result = getReorderPrim(context->getParamsCache(), getEngine(), src_desc, dst_desc);
+    if (!result) {
         IE_THROW() << "Cannot create reorder primitive: unsupported reorder case";
     }

-    prim = result.first;
-    supportedPrimitiveDescriptors[0].setImplementationType(impl_type);
+    prim = result;
+    selectedPD->setImplementationType(
+        parse_impl_name(DnnlExtensionUtils::query_impl_info_str(prim.get_primitive_desc())));

     auto src = getParentEdgesAtPort(0)[0]->getMemoryPtr()->GetPrimitive();
     auto dst = getChildEdgesAtPort(0)[0]->getMemoryPtr()->GetPrimitive();
     primArgs = {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}};

@@ -431,41 +389,15 @@ void Reorder::reorderData(const Memory &input, const Memory &output, MultiCacheP
         auto copySize = output.GetSize();
         cpu_memcpy(dstPtr, srcPtr, copySize);
     } else {
-        auto getReorder = [] (MultiCachePtr& cache, const dnnl::memory& srcMemory, const dnnl::memory& dstMemory)
-                -> std::shared_ptr<dnnl::reorder> {
-            const auto& engine = dstMemory.get_engine();
-            auto builder = [&engine](const ReorderKey& key) -> std::shared_ptr<dnnl::reorder> {
-                dnnl::primitive_attr attr;
-                reorder::primitive_desc pd = dnnl::reorder::primitive_desc(engine, key.src, engine, key.dest, attr, true);
-                DEBUG_LOG(key.src, "->", key.dest);
-                if (!pd)
-                    return nullptr;
-                return std::make_shared<dnnl::reorder>(pd);
-            };
-
-            std::shared_ptr<dnnl::reorder> reorder;
-            auto src_desc = srcMemory.get_desc();
-            auto dst_desc = dstMemory.get_desc();
-            ReorderKey key = {src_desc, dst_desc};
-            if (!cache) {
-                reorder = builder(key);
-            } else {
-                auto result = cache->getOrCreate(key, builder);
-                reorder = std::move(result.first);
-            }
-            return reorder;
-        };
-
-        std::shared_ptr<dnnl::reorder> pReorder;
+        dnnl::reorder reorder;
         std::vector<uint8_t> tmpBuff;

         auto srcMemory = input.GetPrimitive();
         auto dstMemory = output.GetPrimitive();
         auto engine = output.getEngine();
         // try directly reorder
-        pReorder = getReorder(cache, srcMemory, dstMemory);
-        if (!pReorder) {
+        reorder = getReorderPrim(cache, dstMemory.get_engine(), srcMemory.get_desc(), dstMemory.get_desc());
+        if (!reorder) {
             // try precision conversion then do the reorder
             if (output.GetDataType() != input.GetDataType() && Convert::isSupportedDesc(input.getDesc()) &&
                 Convert::isSupportedDesc(output.getDesc())) {

@@ -483,16 +415,16 @@ void Reorder::reorderData(const Memory &input, const Memory &output, MultiCacheP
                 tmpMem.Create(std::move(tmpDesc), tmpBuff.data());

                 srcMemory = tmpMem.GetPrimitive();
-                pReorder = getReorder(cache, srcMemory, dstMemory);
+                reorder = getReorderPrim(cache, dstMemory.get_engine(), srcMemory.get_desc(), dstMemory.get_desc());
             }
-            if (!pReorder) {
+            if (!reorder) {
                 IE_THROW() << "No reorder available for the following tensor descriptors: "
                     << input.getDesc().serializeFormat() << " and " << output.getDesc().serializeFormat();
             }
         }
-        if (pReorder) {
+        if (reorder) {
             dnnl::stream loc_stream(engine, dnnl::stream::flags::in_order);
-            pReorder->execute(loc_stream, srcMemory, dstMemory);
+            reorder.execute(loc_stream, {{DNNL_ARG_FROM, srcMemory}, {DNNL_ARG_TO, dstMemory}});
         } else {
             IE_THROW() << "Could not make onednn reorder.";
         }

View File

@@ -1035,29 +1035,29 @@ void RNN::prepareParams() {
     const auto attr = initPrimitiveAttr();

-    auto builder = [this, attr](const RNNKey& key) -> std::shared_ptr<dnnl::primitive> {
+    auto builder = [this, attr](const RNNKey& key) -> dnnl::primitive {
         fillDescs();

         if (key.cellType == dnnl::algorithm::vanilla_rnn) {
             std::shared_ptr<vanilla_rnn_forward::desc> desc = descs[0];
-            return std::make_shared<vanilla_rnn_forward>(vanilla_rnn_forward::primitive_desc(*desc, *attr, getEngine()));
+            return vanilla_rnn_forward(vanilla_rnn_forward::primitive_desc(*desc, *attr, getEngine()));
         } else if (key.cellType == dnnl::algorithm::vanilla_gru) {
             std::shared_ptr<gru_forward::desc> desc = descs[0];
-            return std::make_shared<gru_forward>(gru_forward::primitive_desc(*desc, *attr, getEngine()));
+            return gru_forward(gru_forward::primitive_desc(*desc, *attr, getEngine()));
         } else if (key.cellType == dnnl::algorithm::lbr_gru) {
             std::shared_ptr<lbr_gru_forward::desc> desc = descs[0];
-            return std::make_shared<lbr_gru_forward>(lbr_gru_forward::primitive_desc(*desc, *attr, getEngine()));
+            return lbr_gru_forward(lbr_gru_forward::primitive_desc(*desc, *attr, getEngine()));
         } else if (key.cellType == dnnl::algorithm::vanilla_lstm) {
             std::shared_ptr<lstm_forward::desc> desc = descs[0];
-            return std::make_shared<lstm_forward>(lstm_forward::primitive_desc(*desc, *attr, getEngine()));
+            return lstm_forward(lstm_forward::primitive_desc(*desc, *attr, getEngine()));
         } else if (key.cellType == dnnl::algorithm::vanilla_augru) {
             std::shared_ptr<augru_forward::desc> desc = descs[0];
-            return std::make_shared<augru_forward>(augru_forward::primitive_desc(*desc, *attr, getEngine()));
+            return augru_forward(augru_forward::primitive_desc(*desc, *attr, getEngine()));
         } else if (key.cellType == dnnl::algorithm::lbr_augru) {
             std::shared_ptr<lbr_augru_forward::desc> desc = descs[0];
-            return std::make_shared<lbr_augru_forward>(lbr_augru_forward::primitive_desc(*desc, *attr, getEngine()));
+            return lbr_augru_forward(lbr_augru_forward::primitive_desc(*desc, *attr, getEngine()));
         } else {
-            return nullptr;
+            return dnnl::primitive();
         }
     };

@@ -1070,11 +1070,11 @@ void RNN::prepareParams() {
     prim = result.first;

-    auto pd = (*prim).get_primitive_desc();
+    auto pd = prim.get_primitive_desc();
     scratchpadMem = getScratchPadMem(pd);

     if (!wasMemoryPrepared || wFormatWasChanged) {
-        auto pd = (*prim).get_primitive_desc();
+        auto pd = prim.get_primitive_desc();
         auto query_weights_md = [&](int idx = 0) -> dnnl::memory::desc {
             auto what = dnnl::convert_to_c(dnnl::query::weights_md);
             const dnnl_memory_desc_t *cdesc = dnnl_primitive_desc_query_md(pd, what, idx);

@@ -1143,7 +1143,7 @@ void RNN::execute(dnnl::stream strm) {
         }
     }

-    (*prim).execute(strm, args);
+    prim.execute(strm, args);
 }

 void RNN::executeDynamicImpl(dnnl::stream strm) {

View File

@@ -149,7 +149,7 @@ void SoftMax::prepareParams() {
     SoftmaxKey key = {inpDesc, selected_pd->getImplementationType(), axis};
     auto engine = getEngine();

-    auto builder = [&engine](const SoftmaxKey& key) -> std::shared_ptr<dnnl::primitive> {
+    auto builder = [&engine](const SoftmaxKey& key) -> dnnl::primitive {
         softmax_forward::primitive_desc prim_desc;
         DnnlDesriptor desc(std::shared_ptr<softmax_forward::desc>(
             new softmax_forward::desc(prop_kind::forward_scoring, key.inp0->getDnnlDesc(), key.axis)));

@@ -169,9 +169,9 @@ void SoftMax::prepareParams() {
                 break;
             }
             if (!itpd.next_impl())
-                return nullptr;
+                return softmax_forward();
         }
-        return std::make_shared<softmax_forward>(prim_desc);
+        return softmax_forward(prim_desc);
     };

     auto cache = context->getParamsCache();

@@ -183,7 +183,7 @@ void SoftMax::prepareParams() {
     prim = result.first;

-    auto pd = (*prim).get_primitive_desc();
+    auto pd = prim.get_primitive_desc();
     auto scratchpadMem = getScratchPadMem(pd);

     auto src = getParentEdgesAtPort(0)[0]->getMemoryPtr()->GetPrimitive();

View File

@@ -13,6 +13,7 @@
 #include "utils/ngraph_utils.hpp"
 #include "transformations/utils/utils.hpp"
 #include "common/cpu_memcpy.h"
+#include "common/reorder_prim.h"
 #include <utils/shape_inference/shape_inference_internal_dyn.hpp>

 using namespace dnnl;

@@ -80,7 +81,7 @@ static void nullifyUndefinedDims(VectorDims& dims) {
 class PortIteratorHelper : public PortMapHelper {
 public:
-    PortIteratorHelper(const MemoryPtr &from, const MemoryPtr &to, bool sliced_src,
+    PortIteratorHelper(MultiCachePtr cache, const MemoryPtr &from, const MemoryPtr &to, bool sliced_src,
                        const PortMap &slice_rule, const dnnl::engine& eng)
             : sliced_src(sliced_src) {
         const auto &full_blob = sliced_src ? from : to;

@@ -122,7 +123,7 @@ public:
            mem_holder_src = from->GetPrimitive();
            mem_holder_dst = chunk_mem;
        }
-        reorder = {mem_holder_src, mem_holder_dst};
+        reorder = getReorderPrim(cache, mem_holder_dst.get_engine(), mem_holder_src.get_desc(), mem_holder_dst.get_desc());
     }

     void execute(dnnl::stream strm, int iter) override {

@@ -132,7 +133,7 @@ public:
         chunk_mem.set_data_handle(static_cast<uint8_t *>(full_mem.get_data_handle()) +
                                   chunk_offset_in_byte + chunk_stride_in_byte * iter);

-        reorder.execute(strm, mem_holder_src, mem_holder_dst);
+        reorder.execute(strm, {{DNNL_ARG_FROM, mem_holder_src}, {DNNL_ARG_TO, mem_holder_dst}});
     }

 private:

@@ -147,15 +148,15 @@ private:
 class BackEdgePortHelper : public PortMapHelper {
 public:
-    BackEdgePortHelper(const MemoryPtr &from, const MemoryPtr &to, const dnnl::engine& eng) {
+    BackEdgePortHelper(MultiCachePtr cache, const MemoryPtr &from, const MemoryPtr &to, const dnnl::engine& eng) {
         mem_holder_src = from->GetPrimitive();
         mem_holder_dst = to->GetPrimitive();
-        reorder = {mem_holder_src, mem_holder_dst};
+        reorder = getReorderPrim(cache, mem_holder_dst.get_engine(), mem_holder_src.get_desc(), mem_holder_dst.get_desc());
     }

     void execute(dnnl::stream strm, int iter = -1) override {
         if (iter != 0) {
-            reorder.execute(strm, mem_holder_src, mem_holder_dst);
+            reorder.execute(strm, {{DNNL_ARG_FROM, mem_holder_src}, {DNNL_ARG_TO, mem_holder_dst}});
         }
     }
 };

@@ -258,16 +259,16 @@ void DynamicBuffer::init(const dnnl::engine& eng) {
     count = std::accumulate(dims.begin(), dims.begin() + map_rule.axis, size_t(1), std::multiplies<size_t>());
     len = std::accumulate(dims.begin() + map_rule.axis + 1, dims.end(), elem_size, std::multiplies<size_t>());

-    mem_holder_buffer.reset(new memory(src_desc, eng));
-    copy(reinterpret_cast<const uint8_t*>(from->GetPtr()), get_ptr(*mem_holder_buffer.get()), 0, 0, 1, from->GetSize());
+    mem_holder_buffer = memory(src_desc, eng);
+    copy(reinterpret_cast<const uint8_t*>(from->GetPtr()), get_ptr(mem_holder_buffer), 0, 0, 1, from->GetSize());
 }

-std::shared_ptr<dnnl::memory> DynamicBuffer::create_buffer(const dnnl::engine& eng) {
+dnnl::memory DynamicBuffer::create_buffer(const dnnl::engine& eng) {
     const auto axis = map_rule.axis;
     const auto stride = map_rule.stride;
     const auto abs_stride = std::abs(stride);

-    const auto old_desc = mem_holder_buffer->get_desc();
+    const auto old_desc = mem_holder_buffer.get_desc();
     auto dims = old_desc.dims();

     if (from->getStaticDims()[axis] != abs_stride)

@@ -283,15 +284,15 @@ std::shared_ptr<dnnl::memory> DynamicBuffer::create_buffer(const dnnl::engine& e
         buffer_offset_in_byte = from->GetPrimitive().get_desc().data.format_desc.blocking.strides[axis] * elem_size * abs_stride;
     }

-    return std::make_shared<dnnl::memory>(new_buffer_desc, eng);
+    return dnnl::memory(new_buffer_desc, eng);
 }

-void DynamicBuffer::move_buffer(std::shared_ptr<dnnl::memory> new_buffer) {
+void DynamicBuffer::move_buffer(dnnl::memory new_buffer) {
     const auto axis = map_rule.axis;

-    const auto src_stride = mem_holder_buffer->get_desc().dims()[axis] * len;
-    const auto dst_stride = new_buffer->get_desc().dims()[axis] * len;
+    const auto src_stride = mem_holder_buffer.get_desc().dims()[axis] * len;
+    const auto dst_stride = new_buffer.get_desc().dims()[axis] * len;

-    copy(get_ptr(*mem_holder_buffer.get()), get_ptr(*new_buffer.get()) + buffer_offset_in_byte,
+    copy(get_ptr(mem_holder_buffer), get_ptr(new_buffer) + buffer_offset_in_byte,
          src_stride, dst_stride, count, src_stride);

     mem_holder_buffer = new_buffer;
 }

@@ -299,19 +300,19 @@ void DynamicBuffer::move_buffer(std::shared_ptr<dnnl::memory> new_buffer) {
 void DynamicBuffer::move_data() {
     const auto axis = map_rule.axis;
     const auto src_stride = abs(map_rule.stride) * len;
-    const auto dst_stride = mem_holder_buffer->get_desc().dims()[axis] * len;
+    const auto dst_stride = mem_holder_buffer.get_desc().dims()[axis] * len;

-    copy(reinterpret_cast<const uint8_t*>(from->GetPtr()), get_ptr(*mem_holder_buffer.get()) + chunk_offset_in_byte,
+    copy(reinterpret_cast<const uint8_t*>(from->GetPtr()), get_ptr(mem_holder_buffer) + chunk_offset_in_byte,
          src_stride, dst_stride, count, src_stride);
 }

 void DynamicBuffer::transfer(const Node* node) {
     if (mem_holder_buffer) {
         const auto desc = node->getBaseMemDescAtOutputPort(map_rule.from)->cloneWithNewDims(
-            DnnlExtensionUtils::convertToVectorDims(mem_holder_buffer->get_desc().dims()));
+            DnnlExtensionUtils::convertToVectorDims(mem_holder_buffer.get_desc().dims()));
         redefineToMemories(to, desc);

-        copy(get_ptr(*mem_holder_buffer.get()), reinterpret_cast<uint8_t*>(to.front()->GetPtr()), 0, 0, 1, to.front()->GetSize());
+        copy(get_ptr(mem_holder_buffer), reinterpret_cast<uint8_t*>(to.front()->GetPtr()), 0, 0, 1, to.front()->GetSize());
     } else {
         VectorDims newDims = to.front()->GetShape().getDims();
         nullifyUndefinedDims(newDims);

@@ -320,7 +321,7 @@
         redefineToMemories(to, desc);
     }

-    mem_holder_buffer.reset();
+    mem_holder_buffer.reset(nullptr);
 }

 void DynamicBuffer::copy(const uint8_t* src, uint8_t* dst, const size_t src_stride, const size_t dst_stride, const size_t count, const size_t len) {

@@ -590,10 +591,10 @@ void TensorIterator::prepareInputPorts() {
         auto &to_mem = input_mems[map_rule.to].front();  // first memory is enough to access the shared underlying physical memory

         if (map_rule.axis == -1)
-            first_mappers.emplace_back(std::make_shared<BackEdgePortHelper>(from_mem, to_mem, eng));
+            first_mappers.emplace_back(std::make_shared<BackEdgePortHelper>(context->getParamsCache(), from_mem, to_mem, eng));
         else
             before_mappers.emplace_back(
-                    std::make_shared<PortIteratorHelper>(from_mem, to_mem, true, map_rule, eng));
+                    std::make_shared<PortIteratorHelper>(context->getParamsCache(), from_mem, to_mem, true, map_rule, eng));
     }
 }

@@ -604,9 +605,9 @@ void TensorIterator::prepareOutputPorts() {
         auto &from_mem = output_mem[map_rule.to];

         if (map_rule.axis == -1)
-            last_mappers.emplace_back(std::make_shared<BackEdgePortHelper>(from_mem, to_mem, eng));
+            last_mappers.emplace_back(std::make_shared<BackEdgePortHelper>(context->getParamsCache(), from_mem, to_mem, eng));
         else
-            after_mappers.emplace_back(std::make_shared<PortIteratorHelper>(from_mem, to_mem, false, map_rule, eng));
+            after_mappers.emplace_back(std::make_shared<PortIteratorHelper>(context->getParamsCache(), from_mem, to_mem, false, map_rule, eng));
     }
 }

@@ -616,7 +617,7 @@ void TensorIterator::prepareBackEdges() {
         auto from_mem = output_mem[map_rule.from];
         auto to_mem = input_mems[map_rule.to].front();

-        before_mappers.emplace_back(std::make_shared<BackEdgePortHelper>(from_mem, to_mem, eng));
+        before_mappers.emplace_back(std::make_shared<BackEdgePortHelper>(context->getParamsCache(), from_mem, to_mem, eng));
     }
 }

@@ -630,7 +631,7 @@ void TensorIterator::prepareDynamicBackEdges() {
         redefineToMemories(to_mems, from_mem->getDescPtr());

         // first memory is enough to get common memory ptr
-        back_mappers.emplace_back(std::make_shared<BackEdgePortHelper>(from_mem, to_mems.front(), eng));
+        back_mappers.emplace_back(std::make_shared<BackEdgePortHelper>(context->getParamsCache(), from_mem, to_mems.front(), eng));
     }
 }

@@ -715,7 +716,7 @@ void TensorIterator::reshapeAndFillOutput(dnnl::stream strm) {
             redefineToMemories(to_mems, desc);

             if (!newShape.isDynamic()) {
-                BackEdgePortHelper mapper(from_mem, to_mems.front(), eng);
+                BackEdgePortHelper mapper(context->getParamsCache(), from_mem, to_mems.front(), eng);
                 mapper.execute(strm);
             }
         }
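
This file carries the optimization the commit title refers to: each port mapper now obtains its reorder through getReorderPrim() once, at construction, so repeated (re)preparation of the TensorIterator hits the parameter cache instead of rebuilding the primitive, and execution passes explicit argument indices. Schematically (placeholder names, not code from the patch):

    // Build or fetch the cached reorder once per mapper...
    dnnl::reorder reorder = getReorderPrim(cache, dst_mem.get_engine(),
                                           src_mem.get_desc(), dst_mem.get_desc());
    // ...then run it on every iteration, only re-pointing the chunk's data
    // handle between iterations as PortIteratorHelper::execute() does.
    for (int iter = 0; iter < num_iterations; ++iter) {
        reorder.execute(strm, {{DNNL_ARG_FROM, src_mem}, {DNNL_ARG_TO, dst_mem}});
    }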

View File

@@ -38,7 +38,7 @@ public:
     virtual ~PortMapHelper() = default;
     virtual void execute(dnnl::stream strm, int n_iter = -1) = 0;

 protected:
-    dnnl::reorder reorder;
+    dnnl::primitive reorder;
     dnnl::memory mem_holder_src;
     dnnl::memory mem_holder_dst;
 };

@@ -74,8 +74,8 @@ private:
     void init(const dnnl::engine& eng);

     /* methods for resize and refill buffer */
-    std::shared_ptr<dnnl::memory> create_buffer(const dnnl::engine& eng);
-    void move_buffer(std::shared_ptr<dnnl::memory> new_buffer);
+    dnnl::memory create_buffer(const dnnl::engine& eng);
+    void move_buffer(dnnl::memory new_buffer);
     void move_data();
     static void copy(const uint8_t* src, uint8_t* dst, const size_t src_stride, const size_t dst_stride, const size_t count, const size_t len);

@@ -91,7 +91,7 @@ private:
     std::vector<MemoryPtr> to;
     PortMap map_rule;

-    std::shared_ptr<dnnl::memory> mem_holder_buffer;
+    dnnl::memory mem_holder_buffer;
 };

 class TensorIterator : public Node {

View File

@@ -4,6 +4,7 @@
 #include "transpose.h"

 #include "ie_parallel.hpp"
+#include "nodes/common/reorder_prim.h"

 #include <algorithm>
 #include <string>

@@ -16,31 +17,6 @@ using namespace InferenceEngine;
 namespace ov {
 namespace intel_cpu {
 namespace node {
-namespace {
-struct TransposeAsReorderKey {
-    dnnl::memory::desc src;
-    dnnl::memory::desc dest;
-    size_t hash() const;
-    bool operator==(const TransposeAsReorderKey& rhs) const;
-};
-
-size_t TransposeAsReorderKey::hash() const {
-    using namespace dnnl::impl;
-    using namespace dnnl::impl::primitive_hashing;
-
-    size_t seed = 0;
-    seed = hash_combine(seed, get_md_hash(src.data));
-    seed = hash_combine(seed, get_md_hash(dest.data));
-    return seed;
-}
-
-bool TransposeAsReorderKey::operator==(const TransposeAsReorderKey& rhs) const {
-    bool retVal = true;
-    retVal = src == rhs.src && dest == rhs.dest;
-    return retVal;
-}
-}  // namespace

 bool Transpose::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept {
     try {

@@ -147,48 +123,21 @@ bool Transpose::needPrepareParams() const {
 void Transpose::prepareParams() {
     if (performAsReorder) {
-        dnnl::primitive_attr attr;
-        const auto engine = getEngine();
-        auto& dstMemPtr = getChildEdgeAt(0)->getMemoryPtr();
         // Transpose(order={0,3,1,2}) can be performed as Reorder(acdb=>abcd)
         auto& srcMemPtr = getParentEdgeAt(INPUT_DATA_IDX)->getMemoryPtr();
-        MemoryPtr src_blocked = std::make_shared<Memory>(engine);
-        MemoryPtr dst_blocked = std::make_shared<Memory>(engine);
-        dst_blocked->Create(
-            DnnlExtensionUtils::makeDescriptor(dstMemPtr->GetDescWithType<DnnlMemoryDesc>()->getDnnlDesc()),
-            dstMemPtr->GetData(), false);
-        const auto newDims = dst_blocked->getStaticDims();
-        auto newDesc = dnnl::memory::desc(DnnlExtensionUtils::convertToDnnlDims(newDims),
-                                          dst_blocked->GetDataType(),
-                                          memory::format_tag::acdb);
-        src_blocked->Create(DnnlExtensionUtils::makeDescriptor(newDesc), srcMemPtr->GetData(), false);
-
-        impl_desc_type impl_type = getSelectedPrimitiveDescriptor()->getImplementationType();
-        TransposeAsReorderKey key = {src_blocked->GetPrimitive().get_desc(), dst_blocked->GetPrimitive().get_desc()};
-        auto builder = [&engine, &impl_type](const TransposeAsReorderKey& key) -> std::shared_ptr<dnnl::primitive> {
-            dnnl::primitive_attr attr;
-            reorder::primitive_desc pd = dnnl::reorder::primitive_desc(engine, key.src, engine, key.dest, attr, true);
-            if (!pd)
-                return nullptr;
-
-            auto info = pd.impl_info_str();
-            impl_type = parse_impl_name(info);
-            return std::make_shared<dnnl::reorder>(pd);
-        };
-
-        auto cache = context->getParamsCache();
-        auto result = cache->getOrCreate(key, builder);
-        if (!result.first) {
+        auto& dstMemPtr = getChildEdgeAt(0)->getMemoryPtr();
+        auto dstDesc = dstMemPtr->GetDescWithType<DnnlMemoryDesc>()->getDnnlDesc();
+        auto srcDesc = dnnl::memory::desc(dstDesc.dims(), dstDesc.data_type(), memory::format_tag::acdb);
+        auto result = getReorderPrim(context->getParamsCache(), getEngine(), srcDesc, dstDesc);
+        if (!result) {
             IE_THROW() << "Reorder primitive descriptor was not found for Transpose node " << getName() << ".";
         }

-        prim = result.first;
-        supportedPrimitiveDescriptors[0].setImplementationType(impl_type);
-        primArgs = {{DNNL_ARG_SRC, getParentEdgesAtPort(INPUT_DATA_IDX)[0]->getMemoryPtr()->GetPrimitive()},
-                    {DNNL_ARG_DST, getChildEdgesAtPort(0)[0]->getMemoryPtr()->GetPrimitive()}};
+        prim = result;
+        getSelectedPrimitiveDescriptor()->setImplementationType(
+            parse_impl_name(DnnlExtensionUtils::query_impl_info_str(prim.get_primitive_desc())));
+        primArgs = {{DNNL_ARG_SRC, srcMemPtr->GetPrimitive()}, {DNNL_ARG_DST, dstMemPtr->GetPrimitive()}};
         return;
     }

@@ -358,7 +307,7 @@ void Transpose::optimizedExecute(const int MB, const MemoryPtr& srcMemPtr, Memor
 void Transpose::execute(dnnl::stream strm) {
     if (prim) {
-        (*prim).execute(strm, primArgs);
+        prim.execute(strm, primArgs);
     } else if (execPtr) {
         auto &dstMemPtr = getChildEdgeAt(0)->getMemoryPtr();
         auto &srcMemPtr = getParentEdgeAt(INPUT_DATA_IDX)->getMemoryPtr();
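
For reference, the trick behind performAsReorder above: a Transpose with order {0,3,1,2} produces exactly the bytes of a reorder that reads the source buffer through an acdb (strided) view of the destination dims and writes it out in plain abcd order, because enumerating the destination dims in a, c, d, b order visits the source elements in their natural layout. That is why prepareParams() only has to fabricate an acdb source descriptor over dstDesc.dims() and hand the pair to the cached reorder helper; the source data itself is never rearranged up front.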

View File

@ -1,31 +0,0 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <dnnl_types.h>
#include "primitive.h"
namespace ov {
namespace intel_cpu {
Primitive::Primitive() {}
Primitive::operator bool() const {
return prim ? true : false;
}
dnnl::primitive Primitive::operator*() const {
return *prim;
}
void Primitive::reset(dnnl::primitive* primitive) {
prim.reset(primitive);
}
Primitive &Primitive::operator=(const std::shared_ptr<dnnl::primitive>& primitive) {
prim = primitive;
return *this;
}
} // namespace intel_cpu
} // namespace ov

View File

@ -1,30 +0,0 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <functional>
#include <ie_common.h>
#include <vector>
#include <memory>
#include "onednn/dnnl.h"
namespace ov {
namespace intel_cpu {
class Primitive {
public:
Primitive();
operator bool() const;
Primitive& operator=(const std::shared_ptr<dnnl::primitive>& primitive);
dnnl::primitive operator*() const;
void reset(dnnl::primitive* primitive);
private:
std::shared_ptr<dnnl::primitive> prim;
};
} // namespace intel_cpu
} // namespace ov
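
With the dnnl:: handle classes already providing shared ownership and an emptiness test of their own, this wrapper added no behavior, which is why every (*prim). call site above collapses to prim. and every prim.reset(new ...) to a plain assignment.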