[CPU] Impl extract_image_patches cache (#9525)

This commit is contained in:
Zhang Yi 2022-01-11 16:03:10 +08:00 committed by GitHub
parent 1a3d0adb3e
commit 986f0eaac6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 62 additions and 9 deletions

View File

@ -12,6 +12,7 @@
#include "list.hpp"
#include <cpu/x64/jit_generator.hpp>
#include "caseless.hpp"
#include <common/primitive_hashing_utils.hpp>
using namespace MKLDNNPlugin;
using namespace InferenceEngine;
@ -290,6 +291,40 @@ bool MKLDNNExtractImagePatchesNode::isSupportedOperation(const std::shared_ptr<c
return true;
}
namespace {
struct ExtractImagePatchesKey {
VectorDims inDims;
VectorDims outDims;
VectorDims kSizes;
VectorDims strides;
VectorDims rates;
MKLDNNExtractImagePatchesNode::ExtImgPatcherPadType padType;
size_t prcSize;
size_t hash() const;
bool operator==(const ExtractImagePatchesKey& rhs) const;
};
size_t ExtractImagePatchesKey::hash() const {
using namespace dnnl::impl::primitive_hashing;
using namespace dnnl::impl;
size_t seed = 0;
seed = get_vector_hash(seed, inDims);
seed = get_vector_hash(seed, outDims);
seed = get_vector_hash(seed, kSizes);
seed = get_vector_hash(seed, strides);
seed = get_vector_hash(seed, rates);
seed = hash_combine(seed, padType);
seed = hash_combine(seed, prcSize);
return seed;
}
bool ExtractImagePatchesKey::operator==(const ExtractImagePatchesKey& rhs) const {
bool result = inDims == rhs.inDims && outDims == rhs.outDims && kSizes == rhs.kSizes && strides == rhs.strides &&
rates == rhs.rates && padType == rhs.padType && prcSize == rhs.prcSize;
return result;
}
} // namespace
MKLDNNExtractImagePatchesNode::MKLDNNExtractImagePatchesNode(const std::shared_ptr<ngraph::Node>& op, const mkldnn::engine& eng,
MKLDNNWeightsSharing::Ptr &cache) : MKLDNNNode(op, eng, cache) {
std::string errorMessage;
@ -340,11 +375,30 @@ void MKLDNNExtractImagePatchesNode::prepareParams() {
const auto& in_dims = getParentEdgeAt(0)->getMemory().getStaticDims();
const auto& out_dims = getChildEdgesAtPort(0)[0]->getMemory().getStaticDims();
const auto prcSize = getOriginalInputPrecisionAtPort(0).size();
if (mayiuse(x64::sse41)) {
execPtr = std::make_shared<ExtractImagePatchesJitExecutor>(in_dims, out_dims, _ksizes, _strides, _rates, _auto_pad, prcSize);
ExtractImagePatchesKey key = {in_dims, out_dims, _ksizes, _strides, _rates, _auto_pad, prcSize};
const auto isJit = mayiuse(x64::sse41);
auto buildExecutor = [&isJit](const ExtractImagePatchesKey& key) -> executorPtr {
if (isJit) {
return std::make_shared<ExtractImagePatchesJitExecutor>(key.inDims,
key.outDims,
key.kSizes,
key.strides,
key.rates,
key.padType,
key.prcSize);
} else {
execPtr = std::make_shared<ExtractImagePatchesRefExecutor>(in_dims, out_dims, _ksizes, _strides, _rates, _auto_pad, prcSize);
return std::make_shared<ExtractImagePatchesRefExecutor>(key.inDims,
key.outDims,
key.kSizes,
key.strides,
key.rates,
key.padType,
key.prcSize);
}
};
auto cache = getRuntimeCache();
auto result = cache->getOrCreate(key, buildExecutor);
execPtr = result.first;
}
void MKLDNNExtractImagePatchesNode::initSupportedPrimitiveDescriptors() {

View File

@ -52,14 +52,13 @@ public:
void prepareParams() override;
static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;
private:
enum class ExtImgPatcherPadType {
VALID,
SAME_LOWER,
SAME_UPPER
};
private:
std::vector<size_t> _ksizes;
std::vector<size_t> _strides;
std::vector<size_t> _rates;

View File

@ -79,13 +79,13 @@ const std::vector<InputShape> inputShapes = {
// dynamic
{-1, -1, -1, -1},
// static
{{2, 3, 13, 37}, {6, 4, 14, 14}, {8, 12, 15, 16}}
{{2, 3, 13, 37}, {6, 4, 14, 14}, {8, 12, 15, 16}, {2, 3, 13, 37}}
},
InputShape{
// dynamic
{{5, 15}, {6, 17}, {10, 15}, {13, 16}},
// static
{{5, 17, 10, 15}, {15, 10, 12, 13}, {10, 10, 15, 16}}
{{5, 17, 10, 15}, {15, 10, 12, 13}, {10, 10, 15, 16}, {5, 17, 10, 15}}
},
};