[CPU] Impl extract_image_patches cache (#9525)
This commit is contained in:
parent
1a3d0adb3e
commit
986f0eaac6
@ -12,6 +12,7 @@
|
||||
#include "list.hpp"
|
||||
#include <cpu/x64/jit_generator.hpp>
|
||||
#include "caseless.hpp"
|
||||
#include <common/primitive_hashing_utils.hpp>
|
||||
|
||||
using namespace MKLDNNPlugin;
|
||||
using namespace InferenceEngine;
|
||||
@ -290,6 +291,40 @@ bool MKLDNNExtractImagePatchesNode::isSupportedOperation(const std::shared_ptr<c
|
||||
return true;
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct ExtractImagePatchesKey {
|
||||
VectorDims inDims;
|
||||
VectorDims outDims;
|
||||
VectorDims kSizes;
|
||||
VectorDims strides;
|
||||
VectorDims rates;
|
||||
MKLDNNExtractImagePatchesNode::ExtImgPatcherPadType padType;
|
||||
size_t prcSize;
|
||||
size_t hash() const;
|
||||
bool operator==(const ExtractImagePatchesKey& rhs) const;
|
||||
};
|
||||
|
||||
size_t ExtractImagePatchesKey::hash() const {
|
||||
using namespace dnnl::impl::primitive_hashing;
|
||||
using namespace dnnl::impl;
|
||||
size_t seed = 0;
|
||||
seed = get_vector_hash(seed, inDims);
|
||||
seed = get_vector_hash(seed, outDims);
|
||||
seed = get_vector_hash(seed, kSizes);
|
||||
seed = get_vector_hash(seed, strides);
|
||||
seed = get_vector_hash(seed, rates);
|
||||
seed = hash_combine(seed, padType);
|
||||
seed = hash_combine(seed, prcSize);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool ExtractImagePatchesKey::operator==(const ExtractImagePatchesKey& rhs) const {
|
||||
bool result = inDims == rhs.inDims && outDims == rhs.outDims && kSizes == rhs.kSizes && strides == rhs.strides &&
|
||||
rates == rhs.rates && padType == rhs.padType && prcSize == rhs.prcSize;
|
||||
return result;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
MKLDNNExtractImagePatchesNode::MKLDNNExtractImagePatchesNode(const std::shared_ptr<ngraph::Node>& op, const mkldnn::engine& eng,
|
||||
MKLDNNWeightsSharing::Ptr &cache) : MKLDNNNode(op, eng, cache) {
|
||||
std::string errorMessage;
|
||||
@ -340,11 +375,30 @@ void MKLDNNExtractImagePatchesNode::prepareParams() {
|
||||
const auto& in_dims = getParentEdgeAt(0)->getMemory().getStaticDims();
|
||||
const auto& out_dims = getChildEdgesAtPort(0)[0]->getMemory().getStaticDims();
|
||||
const auto prcSize = getOriginalInputPrecisionAtPort(0).size();
|
||||
if (mayiuse(x64::sse41)) {
|
||||
execPtr = std::make_shared<ExtractImagePatchesJitExecutor>(in_dims, out_dims, _ksizes, _strides, _rates, _auto_pad, prcSize);
|
||||
ExtractImagePatchesKey key = {in_dims, out_dims, _ksizes, _strides, _rates, _auto_pad, prcSize};
|
||||
const auto isJit = mayiuse(x64::sse41);
|
||||
auto buildExecutor = [&isJit](const ExtractImagePatchesKey& key) -> executorPtr {
|
||||
if (isJit) {
|
||||
return std::make_shared<ExtractImagePatchesJitExecutor>(key.inDims,
|
||||
key.outDims,
|
||||
key.kSizes,
|
||||
key.strides,
|
||||
key.rates,
|
||||
key.padType,
|
||||
key.prcSize);
|
||||
} else {
|
||||
execPtr = std::make_shared<ExtractImagePatchesRefExecutor>(in_dims, out_dims, _ksizes, _strides, _rates, _auto_pad, prcSize);
|
||||
return std::make_shared<ExtractImagePatchesRefExecutor>(key.inDims,
|
||||
key.outDims,
|
||||
key.kSizes,
|
||||
key.strides,
|
||||
key.rates,
|
||||
key.padType,
|
||||
key.prcSize);
|
||||
}
|
||||
};
|
||||
auto cache = getRuntimeCache();
|
||||
auto result = cache->getOrCreate(key, buildExecutor);
|
||||
execPtr = result.first;
|
||||
}
|
||||
|
||||
void MKLDNNExtractImagePatchesNode::initSupportedPrimitiveDescriptors() {
|
||||
|
@ -52,14 +52,13 @@ public:
|
||||
void prepareParams() override;
|
||||
|
||||
static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;
|
||||
|
||||
private:
|
||||
enum class ExtImgPatcherPadType {
|
||||
VALID,
|
||||
SAME_LOWER,
|
||||
SAME_UPPER
|
||||
};
|
||||
|
||||
private:
|
||||
std::vector<size_t> _ksizes;
|
||||
std::vector<size_t> _strides;
|
||||
std::vector<size_t> _rates;
|
||||
|
@ -79,13 +79,13 @@ const std::vector<InputShape> inputShapes = {
|
||||
// dynamic
|
||||
{-1, -1, -1, -1},
|
||||
// static
|
||||
{{2, 3, 13, 37}, {6, 4, 14, 14}, {8, 12, 15, 16}}
|
||||
{{2, 3, 13, 37}, {6, 4, 14, 14}, {8, 12, 15, 16}, {2, 3, 13, 37}}
|
||||
},
|
||||
InputShape{
|
||||
// dynamic
|
||||
{{5, 15}, {6, 17}, {10, 15}, {13, 16}},
|
||||
// static
|
||||
{{5, 17, 10, 15}, {15, 10, 12, 13}, {10, 10, 15, 16}}
|
||||
{{5, 17, 10, 15}, {15, 10, 12, 13}, {10, 10, 15, 16}, {5, 17, 10, 15}}
|
||||
},
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user