[CPU] Enable CPU Plugin cache for roi_pooling (#9502)
This commit is contained in:
parent
b6d60a2c82
commit
171863e3ce
@ -15,6 +15,7 @@
|
||||
#include "emitters/jit_load_store_emitters.hpp"
|
||||
|
||||
#include <cpu/x64/jit_generator.hpp>
|
||||
#include <common/primitive_hashing_utils.hpp>
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
@ -149,7 +150,7 @@ private:
|
||||
|
||||
mov(aux_reg_input, reg_input);
|
||||
|
||||
const int src_c_off = jpp_.ih * jpp_.iw * jpp_.c_block * jpp_.src_data_size;
|
||||
const int src_c_off = jpp_.ih * jpp_.iw * jpp_.c_block * jpp_.src_prc.size();
|
||||
for (int i = 0; i < c_blocks; i++) {
|
||||
Vmm vmm_max = get_acc_reg(i);
|
||||
|
||||
@ -184,21 +185,21 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
add(aux_reg_input1, jpp_.c_block * jpp_.src_data_size);
|
||||
add(aux_reg_input1, jpp_.c_block * jpp_.src_prc.size());
|
||||
|
||||
inc(w_iter);
|
||||
cmp(w_iter, reg_kw);
|
||||
jl(w_loop_label, T_NEAR);
|
||||
}
|
||||
|
||||
add(aux_reg_input, jpp_.iw * jpp_.c_block * jpp_.src_data_size);
|
||||
add(aux_reg_input, jpp_.iw * jpp_.c_block * jpp_.src_prc.size());
|
||||
|
||||
inc(h_iter);
|
||||
cmp(h_iter, reg_kh);
|
||||
jl(h_loop_label, T_NEAR);
|
||||
}
|
||||
|
||||
const int dst_c_off = jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_data_size;
|
||||
const int dst_c_off = jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_prc.size();
|
||||
for (int i = 0; i < c_blocks; i++) {
|
||||
Vmm vmm_dst = get_acc_reg(i);
|
||||
|
||||
@ -220,7 +221,7 @@ private:
|
||||
Vmm vmm_src11 = get_src_reg(3);
|
||||
|
||||
for (int i = 0; i < c_blocks; i++) {
|
||||
const int src_c_off = i * jpp_.ih * jpp_.iw * jpp_.c_block * jpp_.src_data_size;
|
||||
const int src_c_off = i * jpp_.ih * jpp_.iw * jpp_.c_block * jpp_.src_prc.size();
|
||||
const auto load_context = std::make_shared<load_emitter_context>(jpp_.src_prc, Precision::FP32, step, src_c_off);
|
||||
|
||||
mov(aux_reg_input, reg_input);
|
||||
@ -253,7 +254,7 @@ private:
|
||||
uni_vsubps(vmm_src11, vmm_src11, vmm_src01);
|
||||
uni_vfmadd213ps(vmm_src11, vmm_yf, vmm_src01);
|
||||
|
||||
const int dst_c_off = i * jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_data_size;
|
||||
const int dst_c_off = i * jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_prc.size();
|
||||
|
||||
store_emitter->emit_code({static_cast<size_t>(vmm_src11.getIdx())}, {static_cast<size_t>(reg_output.getIdx())},
|
||||
std::make_shared<store_emitter_context>(Precision::FP32, jpp_.dst_prc, step, dst_c_off),
|
||||
@ -264,7 +265,7 @@ private:
|
||||
void empty_roi(int c_blocks) {
|
||||
uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
|
||||
|
||||
const int dst_c_off = jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_data_size;
|
||||
const int dst_c_off = jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_prc.size();
|
||||
for (int i = 0; i < c_blocks; i++) {
|
||||
store_emitter->emit_code({static_cast<size_t>(vmm_zero.getIdx())}, {static_cast<size_t>(reg_output.getIdx())},
|
||||
std::make_shared<store_emitter_context>(jpp_.src_prc, jpp_.dst_prc, step, i * dst_c_off),
|
||||
@ -285,8 +286,8 @@ private:
|
||||
roi_pool_bilinear(c_blocks);
|
||||
|
||||
if (isa == cpu::x64::sse41) {
|
||||
add(reg_input, 4 * jpp_.src_data_size);
|
||||
add(reg_output, 4 * jpp_.dst_data_size);
|
||||
add(reg_input, 4 * jpp_.src_prc.size());
|
||||
add(reg_output, 4 * jpp_.dst_prc.size());
|
||||
|
||||
if (jpp_.alg == Algorithm::ROIPoolingMax)
|
||||
roi_pool_max(c_blocks);
|
||||
@ -298,7 +299,7 @@ private:
|
||||
L(empty_roi_label);
|
||||
empty_roi(c_blocks);
|
||||
if (isa == cpu::x64::sse41) {
|
||||
add(reg_output, 4 * jpp_.dst_data_size);
|
||||
add(reg_output, 4 * jpp_.dst_prc.size());
|
||||
empty_roi(c_blocks);
|
||||
}
|
||||
|
||||
@ -306,6 +307,62 @@ private:
|
||||
}
|
||||
};
|
||||
|
||||
namespace {
|
||||
struct RoiPoolingKey {
|
||||
jit_roi_pooling_params refParams;
|
||||
|
||||
size_t hash() const;
|
||||
bool operator==(const RoiPoolingKey& rhs) const;
|
||||
};
|
||||
|
||||
size_t RoiPoolingKey::hash() const {
|
||||
using namespace dnnl::impl;
|
||||
using namespace dnnl::impl::primitive_hashing;
|
||||
|
||||
size_t seed = 0;
|
||||
|
||||
seed = hash_combine(seed, refParams.mb);
|
||||
seed = hash_combine(seed, refParams.c);
|
||||
seed = hash_combine(seed, refParams.nb_c);
|
||||
seed = hash_combine(seed, refParams.c_block);
|
||||
seed = hash_combine(seed, refParams.nb_c_blocking);
|
||||
seed = hash_combine(seed, refParams.ih);
|
||||
seed = hash_combine(seed, refParams.iw);
|
||||
seed = hash_combine(seed, refParams.oh);
|
||||
seed = hash_combine(seed, refParams.ow);
|
||||
seed = hash_combine(seed, refParams.alg);
|
||||
seed = hash_combine(seed, refParams.src_prc.getPrecVal());
|
||||
seed = hash_combine(seed, refParams.dst_prc.getPrecVal());
|
||||
seed = hash_combine(seed, refParams.spatial_scale);
|
||||
seed = hash_combine(seed, refParams.pooled_h);
|
||||
seed = hash_combine(seed, refParams.pooled_w);
|
||||
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool RoiPoolingKey::operator==(const RoiPoolingKey &rhs) const {
|
||||
return refParams == rhs.refParams;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
/**
 * Member-wise comparison of two ROI pooling parameter sets.
 *
 * @param rhs parameter set to compare against
 * @return true only if every compared member (batch/channel layout, spatial sizes,
 *         pooling attributes, precisions and algorithm) matches
 */
bool MKLDNNPlugin::jit_roi_pooling_params::operator==(const MKLDNNPlugin::jit_roi_pooling_params &rhs) const noexcept {
    // batch / channel layout
    if (mb != rhs.mb || c != rhs.c)
        return false;
    if (c_block != rhs.c_block || nb_c != rhs.nb_c || nb_c_blocking != rhs.nb_c_blocking)
        return false;

    // input / output spatial sizes
    if (ih != rhs.ih || iw != rhs.iw || oh != rhs.oh || ow != rhs.ow)
        return false;

    // pooling attributes
    if (spatial_scale != rhs.spatial_scale || pooled_h != rhs.pooled_h || pooled_w != rhs.pooled_w)
        return false;

    // precisions are compared through their own equality operator
    if (!(src_prc == rhs.src_prc) || !(dst_prc == rhs.dst_prc))
        return false;

    return alg == rhs.alg;
}
|
||||
|
||||
bool MKLDNNROIPoolingNode::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept {
|
||||
try {
|
||||
auto roiPooling = ngraph::as_type_ptr<const ngraph::opset2::ROIPooling>(op);
|
||||
@ -383,8 +440,6 @@ void MKLDNNROIPoolingNode::initSupportedPrimitiveDescriptors() {
|
||||
refParams.src_prc = Precision::FP32;
|
||||
}
|
||||
|
||||
src_data_size = dst_data_size = refParams.src_prc.size();
|
||||
|
||||
auto format = mayiuse(avx512_common) ? LayoutType::nCsp16c : LayoutType::nCsp8c;
|
||||
impl_desc_type impl_type;
|
||||
if (mayiuse(cpu::x64::avx512_common)) {
|
||||
@ -415,8 +470,6 @@ void MKLDNNROIPoolingNode::createPrimitive() {
|
||||
const auto& config = selectedPD->getConfig();
|
||||
refParams.src_prc = config.inConfs[0].desc->getPrecision();
|
||||
refParams.dst_prc = config.outConfs[0].desc->getPrecision();
|
||||
refParams.src_data_size = refParams.src_prc.size();
|
||||
refParams.dst_data_size = refParams.dst_prc.size();
|
||||
|
||||
if (inputShapesDefined()) {
|
||||
if (needPrepareParams() && isExecutable())
|
||||
@ -464,7 +517,13 @@ void MKLDNNROIPoolingNode::prepareParams() {
|
||||
refParams.oh = outDims[2];
|
||||
refParams.ow = outDims[3];
|
||||
|
||||
execPtr = ROIPoolingExecutor::createROIPoolingNewExecutor(refParams);
|
||||
RoiPoolingKey key = {refParams};
|
||||
auto builder = [](const RoiPoolingKey& key) {
|
||||
return ROIPoolingExecutor::createROIPoolingNewExecutor(key.refParams);
|
||||
};
|
||||
auto cache = getRuntimeCache();
|
||||
auto result = cache->getOrCreate(key, builder);
|
||||
execPtr = result.first;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
@ -25,10 +25,10 @@ struct jit_roi_pooling_params {
|
||||
|
||||
InferenceEngine::Precision src_prc;
|
||||
InferenceEngine::Precision dst_prc;
|
||||
int src_data_size = 0;
|
||||
int dst_data_size = 0;
|
||||
|
||||
Algorithm alg;
|
||||
|
||||
bool operator==(const jit_roi_pooling_params& rhs) const noexcept;
|
||||
};
|
||||
|
||||
struct jit_roi_pooling_call_args {
|
||||
@ -83,9 +83,6 @@ private:
|
||||
template<typename T> void execute();
|
||||
template<typename T> struct ROIPoolingExecute;
|
||||
|
||||
size_t src_data_size = 0;
|
||||
size_t dst_data_size = 0;
|
||||
|
||||
jit_roi_pooling_params refParams = {};
|
||||
|
||||
std::string errorPrefix;
|
||||
|
@ -259,7 +259,7 @@ const std::vector<roiPoolingShapes> inShapes = {
|
||||
{-1, -1, -1, -1},
|
||||
// static
|
||||
{
|
||||
{3, 4, 50, 50}, {3, 4, 50, 50}, {3, 4, 50, 50}, {1, 3, 8, 8}, {1, 3, 8, 8}, {1, 3, 8, 8}
|
||||
{3, 4, 50, 50}, {3, 4, 50, 50}, {3, 4, 50, 50}, {1, 3, 8, 8}, {1, 3, 8, 8}, {3, 4, 50, 50}
|
||||
}
|
||||
},
|
||||
// input 1
|
||||
@ -279,7 +279,7 @@ const std::vector<roiPoolingShapes> inShapes = {
|
||||
{-1, {3, 5}, {7, 60}, -1},
|
||||
// static
|
||||
{
|
||||
{3, 4, 50, 50}, {1, 3, 7, 8}, {1, 5, 59, 8}, {3, 5, 60, 8},
|
||||
{3, 4, 50, 50}, {1, 3, 7, 8}, {3, 4, 50, 50}, {1, 3, 7, 8},
|
||||
}
|
||||
},
|
||||
// input 1
|
||||
@ -288,7 +288,7 @@ const std::vector<roiPoolingShapes> inShapes = {
|
||||
{{1, 5}, 5},
|
||||
// static
|
||||
{
|
||||
{1, 5}, {3, 5}, {4, 5}, {5, 5}
|
||||
{1, 5}, {2, 5}, {1, 5}, {2, 5}
|
||||
}
|
||||
},
|
||||
},
|
||||
@ -299,7 +299,7 @@ const std::vector<roiPoolingShapes> inShapes = {
|
||||
{{1, 8}, {3, 5}, {7, 60}, {5, 50}},
|
||||
// static
|
||||
{
|
||||
{3, 4, 50, 50}, {1, 3, 7, 8}, {8, 5, 59, 5}, {3, 5, 60, 8},
|
||||
{3, 4, 50, 50}, {1, 3, 7, 8}, {8, 5, 59, 5}, {1, 3, 7, 8},
|
||||
}
|
||||
},
|
||||
// input 1
|
||||
@ -308,7 +308,7 @@ const std::vector<roiPoolingShapes> inShapes = {
|
||||
{{1, 5}, 5},
|
||||
// static
|
||||
{
|
||||
{1, 5}, {2, 5}, {4, 5}, {5, 5}
|
||||
{1, 5}, {2, 5}, {1, 5}, {2, 5}
|
||||
}
|
||||
},
|
||||
},
|
||||
|
Loading…
Reference in New Issue
Block a user