[CPU] Enable cache for LRN (#9589)
This commit is contained in:
@@ -8,10 +8,54 @@
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <memory_desc/cpu_memory_desc_utils.h>
|
||||
#include "memory_desc/dnnl_blocked_memory_desc.h"
|
||||
#include <common/primitive_hashing_utils.hpp>
|
||||
|
||||
using namespace MKLDNNPlugin;
|
||||
using namespace InferenceEngine;
|
||||
|
||||
namespace {
|
||||
struct LrnKey {
|
||||
DnnlMemoryDescCPtr inp0;
|
||||
impl_desc_type implType;
|
||||
mkldnn::algorithm alg;
|
||||
size_t size;
|
||||
int k;
|
||||
float alpha;
|
||||
float beta;
|
||||
|
||||
size_t hash() const;
|
||||
bool operator==(const LrnKey& rhs) const;
|
||||
};
|
||||
|
||||
size_t LrnKey::hash() const {
|
||||
using namespace dnnl::impl;
|
||||
using namespace dnnl::impl::primitive_hashing;
|
||||
|
||||
size_t seed = 0;
|
||||
|
||||
seed = hash_combine(seed, get_md_hash(inp0->getDnnlDesc().data));
|
||||
seed = hash_combine(seed, implType);
|
||||
seed = hash_combine(seed, alg);
|
||||
seed = hash_combine(seed, size);
|
||||
seed = hash_combine(seed, k);
|
||||
seed = hash_combine(seed, alpha);
|
||||
seed = hash_combine(seed, beta);
|
||||
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool LrnKey::operator==(const LrnKey &rhs) const {
|
||||
bool retVal = true;
|
||||
if (inp0 != rhs.inp0) {
|
||||
retVal = retVal && inp0 && rhs.inp0 && inp0->getDnnlDesc() == rhs.inp0->getDnnlDesc();
|
||||
}
|
||||
|
||||
retVal = retVal && implType == rhs.implType && alg == rhs.alg && alg == rhs.alg && size == rhs.size && k == rhs.k &&
|
||||
alpha == rhs.alpha && beta == rhs.beta;
|
||||
return retVal;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool MKLDNNLrnNode::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept {
|
||||
try {
|
||||
auto lrn = ngraph::as_type_ptr<const ngraph::opset1::LRN>(op);
|
||||
@@ -123,24 +167,34 @@ void MKLDNNLrnNode::prepareParams() {
|
||||
IE_THROW() << errorPrefix << "preferable primitive descriptor did not set";
|
||||
|
||||
auto inpDesc = getParentEdgeAt(0)->getMemory().GetDescWithType<DnnlMemoryDesc>();
|
||||
const auto& in_candidate = inpDesc->getDnnlDesc();
|
||||
MKLDNNDescriptor desc(std::shared_ptr<mkldnn::lrn_forward::desc>(
|
||||
new mkldnn::lrn_forward::desc(mkldnn::prop_kind::forward_scoring, alg, in_candidate, size, alpha, beta, k)));
|
||||
|
||||
mkldnn::lrn_forward::primitive_desc prim_desc;
|
||||
dnnl::primitive_desc_iterator itpd = desc.createPrimitiveDescriptorIterator(getEngine());
|
||||
while (static_cast<bool>(itpd)) {
|
||||
impl_desc_type impl_type = parse_impl_name(itpd.impl_info_str());
|
||||
LrnKey key = {inpDesc, selected_pd->getImplementationType(), alg, size, k, alpha, beta};
|
||||
auto engine = getEngine();
|
||||
|
||||
if (impl_type == selected_pd->getImplementationType()) {
|
||||
prim_desc = itpd.get();
|
||||
break;
|
||||
auto builder = [&engine](const LrnKey& key) -> std::shared_ptr<mkldnn::primitive> {
|
||||
MKLDNNDescriptor desc(std::shared_ptr<mkldnn::lrn_forward::desc>(
|
||||
new mkldnn::lrn_forward::desc(mkldnn::prop_kind::forward_scoring, key.alg, key.inp0->getDnnlDesc(), key.size, key.alpha, key.beta, key.k)));
|
||||
|
||||
mkldnn::lrn_forward::primitive_desc prim_desc;
|
||||
dnnl::primitive_desc_iterator itpd = desc.createPrimitiveDescriptorIterator(engine);
|
||||
while (static_cast<bool>(itpd)) {
|
||||
impl_desc_type impl_type = parse_impl_name(itpd.impl_info_str());
|
||||
if (impl_type == key.implType) {
|
||||
prim_desc = itpd.get();
|
||||
break;
|
||||
}
|
||||
if (!itpd.next_impl())
|
||||
return nullptr;
|
||||
}
|
||||
if (!itpd.next_impl())
|
||||
IE_THROW() << "Primitive descriptor was not found for node " << getName() << ".";
|
||||
}
|
||||
return std::make_shared<mkldnn::lrn_forward>(prim_desc);
|
||||
};
|
||||
|
||||
prim.reset(new mkldnn::lrn_forward(prim_desc));
|
||||
auto cache = getRuntimeCache();
|
||||
auto result = cache->getOrCreate(key, builder);
|
||||
if (!result.first) {
|
||||
IE_THROW() << "Primitive descriptor was not found for node " << getName() << ".";
|
||||
}
|
||||
prim = result.first;
|
||||
|
||||
auto src = srcMemPtr->GetPrimitive();
|
||||
auto dst = dstMemPtr->GetPrimitive();
|
||||
|
||||
@@ -91,13 +91,13 @@ const std::vector<InputShape> inputShapes = {
|
||||
// dynamic
|
||||
{-1, -1, -1, -1},
|
||||
// static
|
||||
{{15, 5, 7, 8}, {10, 10, 3, 8}, {1, 3, 5, 5}}
|
||||
{{15, 5, 7, 8}, {10, 10, 3, 8}, {1, 3, 5, 5}, {10, 10, 3, 8}}
|
||||
},
|
||||
InputShape{
|
||||
// dynamic
|
||||
{{1, 15}, {3, 10}, {3, 7}, {5, 8}},
|
||||
// static
|
||||
{{15, 5, 7, 8}, {10, 10, 3, 8}, {1, 3, 5, 5}}
|
||||
{{15, 5, 7, 8}, {10, 10, 3, 8}, {1, 3, 5, 5}, {10, 10, 3, 8}}
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user