[CPU] Impl extract_image_patches cache (#9525)

This commit is contained in:
Zhang Yi 2022-01-11 16:03:10 +08:00 committed by GitHub
parent 1a3d0adb3e
commit 986f0eaac6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 62 additions and 9 deletions

View File

@ -12,6 +12,7 @@
#include "list.hpp" #include "list.hpp"
#include <cpu/x64/jit_generator.hpp> #include <cpu/x64/jit_generator.hpp>
#include "caseless.hpp" #include "caseless.hpp"
#include <common/primitive_hashing_utils.hpp>
using namespace MKLDNNPlugin; using namespace MKLDNNPlugin;
using namespace InferenceEngine; using namespace InferenceEngine;
@ -290,6 +291,40 @@ bool MKLDNNExtractImagePatchesNode::isSupportedOperation(const std::shared_ptr<c
return true; return true;
} }
namespace {
struct ExtractImagePatchesKey {
VectorDims inDims;
VectorDims outDims;
VectorDims kSizes;
VectorDims strides;
VectorDims rates;
MKLDNNExtractImagePatchesNode::ExtImgPatcherPadType padType;
size_t prcSize;
size_t hash() const;
bool operator==(const ExtractImagePatchesKey& rhs) const;
};
size_t ExtractImagePatchesKey::hash() const {
using namespace dnnl::impl::primitive_hashing;
using namespace dnnl::impl;
size_t seed = 0;
seed = get_vector_hash(seed, inDims);
seed = get_vector_hash(seed, outDims);
seed = get_vector_hash(seed, kSizes);
seed = get_vector_hash(seed, strides);
seed = get_vector_hash(seed, rates);
seed = hash_combine(seed, padType);
seed = hash_combine(seed, prcSize);
return seed;
}
bool ExtractImagePatchesKey::operator==(const ExtractImagePatchesKey& rhs) const {
bool result = inDims == rhs.inDims && outDims == rhs.outDims && kSizes == rhs.kSizes && strides == rhs.strides &&
rates == rhs.rates && padType == rhs.padType && prcSize == rhs.prcSize;
return result;
}
} // namespace
MKLDNNExtractImagePatchesNode::MKLDNNExtractImagePatchesNode(const std::shared_ptr<ngraph::Node>& op, const mkldnn::engine& eng, MKLDNNExtractImagePatchesNode::MKLDNNExtractImagePatchesNode(const std::shared_ptr<ngraph::Node>& op, const mkldnn::engine& eng,
MKLDNNWeightsSharing::Ptr &cache) : MKLDNNNode(op, eng, cache) { MKLDNNWeightsSharing::Ptr &cache) : MKLDNNNode(op, eng, cache) {
std::string errorMessage; std::string errorMessage;
@ -340,11 +375,30 @@ void MKLDNNExtractImagePatchesNode::prepareParams() {
const auto& in_dims = getParentEdgeAt(0)->getMemory().getStaticDims(); const auto& in_dims = getParentEdgeAt(0)->getMemory().getStaticDims();
const auto& out_dims = getChildEdgesAtPort(0)[0]->getMemory().getStaticDims(); const auto& out_dims = getChildEdgesAtPort(0)[0]->getMemory().getStaticDims();
const auto prcSize = getOriginalInputPrecisionAtPort(0).size(); const auto prcSize = getOriginalInputPrecisionAtPort(0).size();
if (mayiuse(x64::sse41)) { ExtractImagePatchesKey key = {in_dims, out_dims, _ksizes, _strides, _rates, _auto_pad, prcSize};
execPtr = std::make_shared<ExtractImagePatchesJitExecutor>(in_dims, out_dims, _ksizes, _strides, _rates, _auto_pad, prcSize); const auto isJit = mayiuse(x64::sse41);
} else { auto buildExecutor = [&isJit](const ExtractImagePatchesKey& key) -> executorPtr {
execPtr = std::make_shared<ExtractImagePatchesRefExecutor>(in_dims, out_dims, _ksizes, _strides, _rates, _auto_pad, prcSize); if (isJit) {
} return std::make_shared<ExtractImagePatchesJitExecutor>(key.inDims,
key.outDims,
key.kSizes,
key.strides,
key.rates,
key.padType,
key.prcSize);
} else {
return std::make_shared<ExtractImagePatchesRefExecutor>(key.inDims,
key.outDims,
key.kSizes,
key.strides,
key.rates,
key.padType,
key.prcSize);
}
};
auto cache = getRuntimeCache();
auto result = cache->getOrCreate(key, buildExecutor);
execPtr = result.first;
} }
void MKLDNNExtractImagePatchesNode::initSupportedPrimitiveDescriptors() { void MKLDNNExtractImagePatchesNode::initSupportedPrimitiveDescriptors() {

View File

@ -52,14 +52,13 @@ public:
void prepareParams() override; void prepareParams() override;
static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept; static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;
private:
enum class ExtImgPatcherPadType { enum class ExtImgPatcherPadType {
VALID, VALID,
SAME_LOWER, SAME_LOWER,
SAME_UPPER SAME_UPPER
}; };
private:
std::vector<size_t> _ksizes; std::vector<size_t> _ksizes;
std::vector<size_t> _strides; std::vector<size_t> _strides;
std::vector<size_t> _rates; std::vector<size_t> _rates;

View File

@ -79,13 +79,13 @@ const std::vector<InputShape> inputShapes = {
// dynamic // dynamic
{-1, -1, -1, -1}, {-1, -1, -1, -1},
// static // static
{{2, 3, 13, 37}, {6, 4, 14, 14}, {8, 12, 15, 16}} {{2, 3, 13, 37}, {6, 4, 14, 14}, {8, 12, 15, 16}, {2, 3, 13, 37}}
}, },
InputShape{ InputShape{
// dynamic // dynamic
{{5, 15}, {6, 17}, {10, 15}, {13, 16}}, {{5, 15}, {6, 17}, {10, 15}, {13, 16}},
// static // static
{{5, 17, 10, 15}, {15, 10, 12, 13}, {10, 10, 15, 16}} {{5, 17, 10, 15}, {15, 10, 12, 13}, {10, 10, 15, 16}, {5, 17, 10, 15}}
}, },
}; };