[CPU] Enable cache for LRN (#9589)

Author: Mang Guo
Date: 2022-01-13 14:54:00 +08:00 (committed by GitHub)
Parent: c3d2af7501
Commit: f562e5572f
2 changed files with 70 additions and 16 deletions


@@ -8,10 +8,54 @@
 #include <ngraph/opsets/opset1.hpp>
 #include <memory_desc/cpu_memory_desc_utils.h>
 #include "memory_desc/dnnl_blocked_memory_desc.h"
+#include <common/primitive_hashing_utils.hpp>

 using namespace MKLDNNPlugin;
 using namespace InferenceEngine;

+namespace {
+struct LrnKey {
+    DnnlMemoryDescCPtr inp0;
+    impl_desc_type implType;
+    mkldnn::algorithm alg;
+    size_t size;
+    int k;
+    float alpha;
+    float beta;
+
+    size_t hash() const;
+    bool operator==(const LrnKey& rhs) const;
+};
+
+size_t LrnKey::hash() const {
+    using namespace dnnl::impl;
+    using namespace dnnl::impl::primitive_hashing;
+
+    size_t seed = 0;
+    seed = hash_combine(seed, get_md_hash(inp0->getDnnlDesc().data));
+    seed = hash_combine(seed, implType);
+    seed = hash_combine(seed, alg);
+    seed = hash_combine(seed, size);
+    seed = hash_combine(seed, k);
+    seed = hash_combine(seed, alpha);
+    seed = hash_combine(seed, beta);
+    return seed;
+}
+
+bool LrnKey::operator==(const LrnKey& rhs) const {
+    bool retVal = true;
+    if (inp0 != rhs.inp0) {
+        retVal = retVal && inp0 && rhs.inp0 && inp0->getDnnlDesc() == rhs.inp0->getDnnlDesc();
+    }
+    retVal = retVal && implType == rhs.implType && alg == rhs.alg && size == rhs.size && k == rhs.k &&
+             alpha == rhs.alpha && beta == rhs.beta;
+    return retVal;
+}
+} // namespace
+
 bool MKLDNNLrnNode::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept {
     try {
         auto lrn = ngraph::as_type_ptr<const ngraph::opset1::LRN>(op);
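The hunk above is what makes the key cacheable: hash() folds every field that affects primitive compilation into one seed, and operator== guards against hash collisions, which is exactly the pair of operations an unordered associative container needs. Below is a minimal, self-contained sketch of that pattern, outside the diff; hash_combine here is a simplified boost-style mix standing in for dnnl's helper, and Key, KeyHasher, and Primitive are illustrative names, not plugin code.

#include <cstddef>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>

// Simplified stand-in for dnnl::impl::hash_combine (boost-style mixing).
static size_t hash_combine(size_t seed, size_t value) {
    return seed ^ (value + 0x9e3779b9 + (seed << 6) + (seed >> 2));
}

// Hypothetical key mirroring the shape of LrnKey, with plain fields in
// place of memory descriptors and impl types.
struct Key {
    std::string desc;   // stands in for the input memory descriptor
    int alg;
    size_t size;
    int k;
    float alpha;
    float beta;

    size_t hash() const {
        size_t seed = 0;
        seed = hash_combine(seed, std::hash<std::string>()(desc));
        seed = hash_combine(seed, std::hash<int>()(alg));
        seed = hash_combine(seed, std::hash<size_t>()(size));
        seed = hash_combine(seed, std::hash<int>()(k));
        seed = hash_combine(seed, std::hash<float>()(alpha));
        seed = hash_combine(seed, std::hash<float>()(beta));
        return seed;
    }
    bool operator==(const Key& rhs) const {
        return desc == rhs.desc && alg == rhs.alg && size == rhs.size &&
               k == rhs.k && alpha == rhs.alpha && beta == rhs.beta;
    }
};

// Adaptor so an unordered container can call Key::hash().
struct KeyHasher {
    size_t operator()(const Key& key) const { return key.hash(); }
};

using Primitive = int;  // placeholder for a compiled primitive
std::unordered_map<Key, std::shared_ptr<Primitive>, KeyHasher> primitiveCache;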
@@ -123,24 +167,34 @@ void MKLDNNLrnNode::prepareParams() {
         IE_THROW() << errorPrefix << "preferable primitive descriptor did not set";

     auto inpDesc = getParentEdgeAt(0)->getMemory().GetDescWithType<DnnlMemoryDesc>();
-    const auto& in_candidate = inpDesc->getDnnlDesc();
-    MKLDNNDescriptor desc(std::shared_ptr<mkldnn::lrn_forward::desc>(
-        new mkldnn::lrn_forward::desc(mkldnn::prop_kind::forward_scoring, alg, in_candidate, size, alpha, beta, k)));
-
-    mkldnn::lrn_forward::primitive_desc prim_desc;
-    dnnl::primitive_desc_iterator itpd = desc.createPrimitiveDescriptorIterator(getEngine());
-    while (static_cast<bool>(itpd)) {
-        impl_desc_type impl_type = parse_impl_name(itpd.impl_info_str());
-
-        if (impl_type == selected_pd->getImplementationType()) {
-            prim_desc = itpd.get();
-            break;
-        }
-        if (!itpd.next_impl())
-            IE_THROW() << "Primitive descriptor was not found for node " << getName() << ".";
-    }
-
-    prim.reset(new mkldnn::lrn_forward(prim_desc));
+
+    LrnKey key = {inpDesc, selected_pd->getImplementationType(), alg, size, k, alpha, beta};
+    auto engine = getEngine();
+
+    auto builder = [&engine](const LrnKey& key) -> std::shared_ptr<mkldnn::primitive> {
+        MKLDNNDescriptor desc(std::shared_ptr<mkldnn::lrn_forward::desc>(
+            new mkldnn::lrn_forward::desc(mkldnn::prop_kind::forward_scoring, key.alg, key.inp0->getDnnlDesc(),
+                                          key.size, key.alpha, key.beta, key.k)));
+
+        mkldnn::lrn_forward::primitive_desc prim_desc;
+        dnnl::primitive_desc_iterator itpd = desc.createPrimitiveDescriptorIterator(engine);
+        while (static_cast<bool>(itpd)) {
+            impl_desc_type impl_type = parse_impl_name(itpd.impl_info_str());
+            if (impl_type == key.implType) {
+                prim_desc = itpd.get();
+                break;
+            }
+            if (!itpd.next_impl())
+                return nullptr;
+        }
+        return std::make_shared<mkldnn::lrn_forward>(prim_desc);
+    };
+
+    auto cache = getRuntimeCache();
+    auto result = cache->getOrCreate(key, builder);
+
+    if (!result.first) {
+        IE_THROW() << "Primitive descriptor was not found for node " << getName() << ".";
+    }
+
+    prim = result.first;

     auto src = srcMemPtr->GetPrimitive();
     auto dst = dstMemPtr->GetPrimitive();
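getOrCreate above is a lookup-or-build call: the builder lambda runs only on a cache miss, and a nullptr from the builder (no primitive descriptor matched the selected implementation type) is handed back so the node can throw. A minimal sketch of that contract follows, under stated assumptions: this RuntimeCache and its bool hit flag are illustrative only, and the plugin's real runtime cache presumably adds concerns such as capacity limits that this sketch omits.

#include <functional>
#include <memory>
#include <unordered_map>
#include <utility>

// Hypothetical minimal cache illustrating the getOrCreate contract.
// Hasher is assumed to forward to Key::hash(), as sketched earlier.
template <typename Key, typename Value, typename Hasher>
class RuntimeCache {
public:
    using ValuePtr = std::shared_ptr<Value>;
    using Builder = std::function<ValuePtr(const Key&)>;

    // Returns the value and whether it was a cache hit. The builder is
    // invoked only on a miss; a nullptr result is returned uncached so
    // the caller can report the failure (IE_THROW in the node code).
    std::pair<ValuePtr, bool> getOrCreate(const Key& key, const Builder& build) {
        auto it = map_.find(key);
        if (it != map_.end())
            return {it->second, true};  // hit: reuse the compiled primitive
        ValuePtr value = build(key);    // miss: compile a new one
        if (value)
            map_.emplace(key, value);
        return {value, false};
    }

private:
    std::unordered_map<Key, ValuePtr, Hasher> map_;
};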


@@ -91,13 +91,13 @@ const std::vector<InputShape> inputShapes = {
     InputShape{
         // dynamic
         {-1, -1, -1, -1},
         // static
-        {{15, 5, 7, 8}, {10, 10, 3, 8}, {1, 3, 5, 5}}
+        {{15, 5, 7, 8}, {10, 10, 3, 8}, {1, 3, 5, 5}, {10, 10, 3, 8}}
     },
     InputShape{
         // dynamic
         {{1, 15}, {3, 10}, {3, 7}, {5, 8}},
         // static
-        {{15, 5, 7, 8}, {10, 10, 3, 8}, {1, 3, 5, 5}}
+        {{15, 5, 7, 8}, {10, 10, 3, 8}, {1, 3, 5, 5}, {10, 10, 3, 8}}
     },
 };
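The only test change is the extra {10, 10, 3, 8} at the end of each static-shape list: the dynamic-shape test now replays a shape it has already inferred, so the second occurrence should be served from the runtime cache instead of rebuilding the LRN primitive. A toy driver showing what that exercises, reusing the hypothetical Key/RuntimeCache sketches above; the build counter and the LRN parameters 5/1e-4/0.75/1 (common defaults) are purely for illustration.

#include <cassert>
#include <string>
#include <vector>

int main() {
    RuntimeCache<Key, Primitive, KeyHasher> cache;
    int builds = 0;
    auto builder = [&builds](const Key&) {
        ++builds;                              // count real compilations
        return std::make_shared<Primitive>(0);
    };

    // Mirrors the static shape list: the last shape repeats the second.
    std::vector<std::string> shapes = {"15x5x7x8", "10x10x3x8", "1x3x5x5", "10x10x3x8"};
    for (const auto& s : shapes)
        cache.getOrCreate(Key{s, 0, 5, 1, 1e-4f, 0.75f}, builder);

    assert(builds == 3);  // the repeated shape is a cache hit
    return 0;
}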