[CPU] Impl extract_image_patches cache (#9525)
This commit is contained in:
parent
1a3d0adb3e
commit
986f0eaac6
@ -12,6 +12,7 @@
|
|||||||
#include "list.hpp"
|
#include "list.hpp"
|
||||||
#include <cpu/x64/jit_generator.hpp>
|
#include <cpu/x64/jit_generator.hpp>
|
||||||
#include "caseless.hpp"
|
#include "caseless.hpp"
|
||||||
|
#include <common/primitive_hashing_utils.hpp>
|
||||||
|
|
||||||
using namespace MKLDNNPlugin;
|
using namespace MKLDNNPlugin;
|
||||||
using namespace InferenceEngine;
|
using namespace InferenceEngine;
|
||||||
@ -290,6 +291,40 @@ bool MKLDNNExtractImagePatchesNode::isSupportedOperation(const std::shared_ptr<c
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
struct ExtractImagePatchesKey {
|
||||||
|
VectorDims inDims;
|
||||||
|
VectorDims outDims;
|
||||||
|
VectorDims kSizes;
|
||||||
|
VectorDims strides;
|
||||||
|
VectorDims rates;
|
||||||
|
MKLDNNExtractImagePatchesNode::ExtImgPatcherPadType padType;
|
||||||
|
size_t prcSize;
|
||||||
|
size_t hash() const;
|
||||||
|
bool operator==(const ExtractImagePatchesKey& rhs) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
size_t ExtractImagePatchesKey::hash() const {
|
||||||
|
using namespace dnnl::impl::primitive_hashing;
|
||||||
|
using namespace dnnl::impl;
|
||||||
|
size_t seed = 0;
|
||||||
|
seed = get_vector_hash(seed, inDims);
|
||||||
|
seed = get_vector_hash(seed, outDims);
|
||||||
|
seed = get_vector_hash(seed, kSizes);
|
||||||
|
seed = get_vector_hash(seed, strides);
|
||||||
|
seed = get_vector_hash(seed, rates);
|
||||||
|
seed = hash_combine(seed, padType);
|
||||||
|
seed = hash_combine(seed, prcSize);
|
||||||
|
return seed;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ExtractImagePatchesKey::operator==(const ExtractImagePatchesKey& rhs) const {
|
||||||
|
bool result = inDims == rhs.inDims && outDims == rhs.outDims && kSizes == rhs.kSizes && strides == rhs.strides &&
|
||||||
|
rates == rhs.rates && padType == rhs.padType && prcSize == rhs.prcSize;
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
MKLDNNExtractImagePatchesNode::MKLDNNExtractImagePatchesNode(const std::shared_ptr<ngraph::Node>& op, const mkldnn::engine& eng,
|
MKLDNNExtractImagePatchesNode::MKLDNNExtractImagePatchesNode(const std::shared_ptr<ngraph::Node>& op, const mkldnn::engine& eng,
|
||||||
MKLDNNWeightsSharing::Ptr &cache) : MKLDNNNode(op, eng, cache) {
|
MKLDNNWeightsSharing::Ptr &cache) : MKLDNNNode(op, eng, cache) {
|
||||||
std::string errorMessage;
|
std::string errorMessage;
|
||||||
@ -340,11 +375,30 @@ void MKLDNNExtractImagePatchesNode::prepareParams() {
|
|||||||
const auto& in_dims = getParentEdgeAt(0)->getMemory().getStaticDims();
|
const auto& in_dims = getParentEdgeAt(0)->getMemory().getStaticDims();
|
||||||
const auto& out_dims = getChildEdgesAtPort(0)[0]->getMemory().getStaticDims();
|
const auto& out_dims = getChildEdgesAtPort(0)[0]->getMemory().getStaticDims();
|
||||||
const auto prcSize = getOriginalInputPrecisionAtPort(0).size();
|
const auto prcSize = getOriginalInputPrecisionAtPort(0).size();
|
||||||
if (mayiuse(x64::sse41)) {
|
ExtractImagePatchesKey key = {in_dims, out_dims, _ksizes, _strides, _rates, _auto_pad, prcSize};
|
||||||
execPtr = std::make_shared<ExtractImagePatchesJitExecutor>(in_dims, out_dims, _ksizes, _strides, _rates, _auto_pad, prcSize);
|
const auto isJit = mayiuse(x64::sse41);
|
||||||
} else {
|
auto buildExecutor = [&isJit](const ExtractImagePatchesKey& key) -> executorPtr {
|
||||||
execPtr = std::make_shared<ExtractImagePatchesRefExecutor>(in_dims, out_dims, _ksizes, _strides, _rates, _auto_pad, prcSize);
|
if (isJit) {
|
||||||
}
|
return std::make_shared<ExtractImagePatchesJitExecutor>(key.inDims,
|
||||||
|
key.outDims,
|
||||||
|
key.kSizes,
|
||||||
|
key.strides,
|
||||||
|
key.rates,
|
||||||
|
key.padType,
|
||||||
|
key.prcSize);
|
||||||
|
} else {
|
||||||
|
return std::make_shared<ExtractImagePatchesRefExecutor>(key.inDims,
|
||||||
|
key.outDims,
|
||||||
|
key.kSizes,
|
||||||
|
key.strides,
|
||||||
|
key.rates,
|
||||||
|
key.padType,
|
||||||
|
key.prcSize);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
auto cache = getRuntimeCache();
|
||||||
|
auto result = cache->getOrCreate(key, buildExecutor);
|
||||||
|
execPtr = result.first;
|
||||||
}
|
}
|
||||||
|
|
||||||
void MKLDNNExtractImagePatchesNode::initSupportedPrimitiveDescriptors() {
|
void MKLDNNExtractImagePatchesNode::initSupportedPrimitiveDescriptors() {
|
||||||
|
@ -52,14 +52,13 @@ public:
|
|||||||
void prepareParams() override;
|
void prepareParams() override;
|
||||||
|
|
||||||
static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;
|
static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;
|
||||||
|
|
||||||
private:
|
|
||||||
enum class ExtImgPatcherPadType {
|
enum class ExtImgPatcherPadType {
|
||||||
VALID,
|
VALID,
|
||||||
SAME_LOWER,
|
SAME_LOWER,
|
||||||
SAME_UPPER
|
SAME_UPPER
|
||||||
};
|
};
|
||||||
|
|
||||||
|
private:
|
||||||
std::vector<size_t> _ksizes;
|
std::vector<size_t> _ksizes;
|
||||||
std::vector<size_t> _strides;
|
std::vector<size_t> _strides;
|
||||||
std::vector<size_t> _rates;
|
std::vector<size_t> _rates;
|
||||||
|
@ -79,13 +79,13 @@ const std::vector<InputShape> inputShapes = {
|
|||||||
// dynamic
|
// dynamic
|
||||||
{-1, -1, -1, -1},
|
{-1, -1, -1, -1},
|
||||||
// static
|
// static
|
||||||
{{2, 3, 13, 37}, {6, 4, 14, 14}, {8, 12, 15, 16}}
|
{{2, 3, 13, 37}, {6, 4, 14, 14}, {8, 12, 15, 16}, {2, 3, 13, 37}}
|
||||||
},
|
},
|
||||||
InputShape{
|
InputShape{
|
||||||
// dynamic
|
// dynamic
|
||||||
{{5, 15}, {6, 17}, {10, 15}, {13, 16}},
|
{{5, 15}, {6, 17}, {10, 15}, {13, 16}},
|
||||||
// static
|
// static
|
||||||
{{5, 17, 10, 15}, {15, 10, 12, 13}, {10, 10, 15, 16}}
|
{{5, 17, 10, 15}, {15, 10, 12, 13}, {10, 10, 15, 16}, {5, 17, 10, 15}}
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user