[CPU] Enable CPU Plugin cache for roi_pooling (#9502)

This commit is contained in:
Mang Guo 2022-01-12 20:21:51 +08:00 committed by GitHub
parent b6d60a2c82
commit 171863e3ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 81 additions and 25 deletions

View File

@ -15,6 +15,7 @@
#include "emitters/jit_load_store_emitters.hpp"
#include <cpu/x64/jit_generator.hpp>
#include <common/primitive_hashing_utils.hpp>
#include <string>
#include <vector>
@ -149,7 +150,7 @@ private:
mov(aux_reg_input, reg_input);
const int src_c_off = jpp_.ih * jpp_.iw * jpp_.c_block * jpp_.src_data_size;
const int src_c_off = jpp_.ih * jpp_.iw * jpp_.c_block * jpp_.src_prc.size();
for (int i = 0; i < c_blocks; i++) {
Vmm vmm_max = get_acc_reg(i);
@ -184,21 +185,21 @@ private:
}
}
add(aux_reg_input1, jpp_.c_block * jpp_.src_data_size);
add(aux_reg_input1, jpp_.c_block * jpp_.src_prc.size());
inc(w_iter);
cmp(w_iter, reg_kw);
jl(w_loop_label, T_NEAR);
}
add(aux_reg_input, jpp_.iw * jpp_.c_block * jpp_.src_data_size);
add(aux_reg_input, jpp_.iw * jpp_.c_block * jpp_.src_prc.size());
inc(h_iter);
cmp(h_iter, reg_kh);
jl(h_loop_label, T_NEAR);
}
const int dst_c_off = jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_data_size;
const int dst_c_off = jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_prc.size();
for (int i = 0; i < c_blocks; i++) {
Vmm vmm_dst = get_acc_reg(i);
@ -220,7 +221,7 @@ private:
Vmm vmm_src11 = get_src_reg(3);
for (int i = 0; i < c_blocks; i++) {
const int src_c_off = i * jpp_.ih * jpp_.iw * jpp_.c_block * jpp_.src_data_size;
const int src_c_off = i * jpp_.ih * jpp_.iw * jpp_.c_block * jpp_.src_prc.size();
const auto load_context = std::make_shared<load_emitter_context>(jpp_.src_prc, Precision::FP32, step, src_c_off);
mov(aux_reg_input, reg_input);
@ -253,7 +254,7 @@ private:
uni_vsubps(vmm_src11, vmm_src11, vmm_src01);
uni_vfmadd213ps(vmm_src11, vmm_yf, vmm_src01);
const int dst_c_off = i * jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_data_size;
const int dst_c_off = i * jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_prc.size();
store_emitter->emit_code({static_cast<size_t>(vmm_src11.getIdx())}, {static_cast<size_t>(reg_output.getIdx())},
std::make_shared<store_emitter_context>(Precision::FP32, jpp_.dst_prc, step, dst_c_off),
@ -264,7 +265,7 @@ private:
void empty_roi(int c_blocks) {
uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
const int dst_c_off = jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_data_size;
const int dst_c_off = jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_prc.size();
for (int i = 0; i < c_blocks; i++) {
store_emitter->emit_code({static_cast<size_t>(vmm_zero.getIdx())}, {static_cast<size_t>(reg_output.getIdx())},
std::make_shared<store_emitter_context>(jpp_.src_prc, jpp_.dst_prc, step, i * dst_c_off),
@ -285,8 +286,8 @@ private:
roi_pool_bilinear(c_blocks);
if (isa == cpu::x64::sse41) {
add(reg_input, 4 * jpp_.src_data_size);
add(reg_output, 4 * jpp_.dst_data_size);
add(reg_input, 4 * jpp_.src_prc.size());
add(reg_output, 4 * jpp_.dst_prc.size());
if (jpp_.alg == Algorithm::ROIPoolingMax)
roi_pool_max(c_blocks);
@ -298,7 +299,7 @@ private:
L(empty_roi_label);
empty_roi(c_blocks);
if (isa == cpu::x64::sse41) {
add(reg_output, 4 * jpp_.dst_data_size);
add(reg_output, 4 * jpp_.dst_prc.size());
empty_roi(c_blocks);
}
@ -306,6 +307,62 @@ private:
}
};
namespace {
struct RoiPoolingKey {
jit_roi_pooling_params refParams;
size_t hash() const;
bool operator==(const RoiPoolingKey& rhs) const;
};
size_t RoiPoolingKey::hash() const {
using namespace dnnl::impl;
using namespace dnnl::impl::primitive_hashing;
size_t seed = 0;
seed = hash_combine(seed, refParams.mb);
seed = hash_combine(seed, refParams.c);
seed = hash_combine(seed, refParams.nb_c);
seed = hash_combine(seed, refParams.c_block);
seed = hash_combine(seed, refParams.nb_c_blocking);
seed = hash_combine(seed, refParams.ih);
seed = hash_combine(seed, refParams.iw);
seed = hash_combine(seed, refParams.oh);
seed = hash_combine(seed, refParams.ow);
seed = hash_combine(seed, refParams.alg);
seed = hash_combine(seed, refParams.src_prc.getPrecVal());
seed = hash_combine(seed, refParams.dst_prc.getPrecVal());
seed = hash_combine(seed, refParams.spatial_scale);
seed = hash_combine(seed, refParams.pooled_h);
seed = hash_combine(seed, refParams.pooled_w);
return seed;
}
bool RoiPoolingKey::operator==(const RoiPoolingKey &rhs) const {
return refParams == rhs.refParams;
}
} // namespace
bool MKLDNNPlugin::jit_roi_pooling_params::operator==(const MKLDNNPlugin::jit_roi_pooling_params &rhs) const noexcept {
return mb == rhs.mb &&
c == rhs.c &&
ih == rhs.ih &&
iw == rhs.iw &&
oh == rhs.oh &&
ow == rhs.ow &&
c_block == rhs.c_block &&
nb_c == rhs.nb_c &&
nb_c_blocking == rhs.nb_c_blocking &&
spatial_scale == rhs.spatial_scale &&
pooled_h == rhs.pooled_h &&
pooled_w == rhs.pooled_w &&
src_prc == rhs.src_prc &&
dst_prc == rhs.dst_prc &&
alg == rhs.alg;
}
bool MKLDNNROIPoolingNode::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept {
try {
auto roiPooling = ngraph::as_type_ptr<const ngraph::opset2::ROIPooling>(op);
@ -383,8 +440,6 @@ void MKLDNNROIPoolingNode::initSupportedPrimitiveDescriptors() {
refParams.src_prc = Precision::FP32;
}
src_data_size = dst_data_size = refParams.src_prc.size();
auto format = mayiuse(avx512_common) ? LayoutType::nCsp16c : LayoutType::nCsp8c;
impl_desc_type impl_type;
if (mayiuse(cpu::x64::avx512_common)) {
@ -415,8 +470,6 @@ void MKLDNNROIPoolingNode::createPrimitive() {
const auto& config = selectedPD->getConfig();
refParams.src_prc = config.inConfs[0].desc->getPrecision();
refParams.dst_prc = config.outConfs[0].desc->getPrecision();
refParams.src_data_size = refParams.src_prc.size();
refParams.dst_data_size = refParams.dst_prc.size();
if (inputShapesDefined()) {
if (needPrepareParams() && isExecutable())
@ -464,7 +517,13 @@ void MKLDNNROIPoolingNode::prepareParams() {
refParams.oh = outDims[2];
refParams.ow = outDims[3];
execPtr = ROIPoolingExecutor::createROIPoolingNewExecutor(refParams);
RoiPoolingKey key = {refParams};
auto builder = [](const RoiPoolingKey& key) {
return ROIPoolingExecutor::createROIPoolingNewExecutor(key.refParams);
};
auto cache = getRuntimeCache();
auto result = cache->getOrCreate(key, builder);
execPtr = result.first;
}
template <typename T>

View File

@ -25,10 +25,10 @@ struct jit_roi_pooling_params {
InferenceEngine::Precision src_prc;
InferenceEngine::Precision dst_prc;
int src_data_size = 0;
int dst_data_size = 0;
Algorithm alg;
bool operator==(const jit_roi_pooling_params& rhs) const noexcept;
};
struct jit_roi_pooling_call_args {
@ -83,9 +83,6 @@ private:
template<typename T> void execute();
template<typename T> struct ROIPoolingExecute;
size_t src_data_size = 0;
size_t dst_data_size = 0;
jit_roi_pooling_params refParams = {};
std::string errorPrefix;

View File

@ -259,7 +259,7 @@ const std::vector<roiPoolingShapes> inShapes = {
{-1, -1, -1, -1},
// static
{
{3, 4, 50, 50}, {3, 4, 50, 50}, {3, 4, 50, 50}, {1, 3, 8, 8}, {1, 3, 8, 8}, {1, 3, 8, 8}
{3, 4, 50, 50}, {3, 4, 50, 50}, {3, 4, 50, 50}, {1, 3, 8, 8}, {1, 3, 8, 8}, {3, 4, 50, 50}
}
},
// input 1
@ -279,7 +279,7 @@ const std::vector<roiPoolingShapes> inShapes = {
{-1, {3, 5}, {7, 60}, -1},
// static
{
{3, 4, 50, 50}, {1, 3, 7, 8}, {1, 5, 59, 8}, {3, 5, 60, 8},
{3, 4, 50, 50}, {1, 3, 7, 8}, {3, 4, 50, 50}, {1, 3, 7, 8},
}
},
// input 1
@ -288,7 +288,7 @@ const std::vector<roiPoolingShapes> inShapes = {
{{1, 5}, 5},
// static
{
{1, 5}, {3, 5}, {4, 5}, {5, 5}
{1, 5}, {2, 5}, {1, 5}, {2, 5}
}
},
},
@ -299,7 +299,7 @@ const std::vector<roiPoolingShapes> inShapes = {
{{1, 8}, {3, 5}, {7, 60}, {5, 50}},
// static
{
{3, 4, 50, 50}, {1, 3, 7, 8}, {8, 5, 59, 5}, {3, 5, 60, 8},
{3, 4, 50, 50}, {1, 3, 7, 8}, {8, 5, 59, 5}, {1, 3, 7, 8},
}
},
// input 1
@ -308,7 +308,7 @@ const std::vector<roiPoolingShapes> inShapes = {
{{1, 5}, 5},
// static
{
{1, 5}, {2, 5}, {4, 5}, {5, 5}
{1, 5}, {2, 5}, {1, 5}, {2, 5}
}
},
},