[CPU] Enable CPU Plugin cache for roi_pooling (#9502)
This commit is contained in:
parent
b6d60a2c82
commit
171863e3ce
@ -15,6 +15,7 @@
|
|||||||
#include "emitters/jit_load_store_emitters.hpp"
|
#include "emitters/jit_load_store_emitters.hpp"
|
||||||
|
|
||||||
#include <cpu/x64/jit_generator.hpp>
|
#include <cpu/x64/jit_generator.hpp>
|
||||||
|
#include <common/primitive_hashing_utils.hpp>
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
@ -149,7 +150,7 @@ private:
|
|||||||
|
|
||||||
mov(aux_reg_input, reg_input);
|
mov(aux_reg_input, reg_input);
|
||||||
|
|
||||||
const int src_c_off = jpp_.ih * jpp_.iw * jpp_.c_block * jpp_.src_data_size;
|
const int src_c_off = jpp_.ih * jpp_.iw * jpp_.c_block * jpp_.src_prc.size();
|
||||||
for (int i = 0; i < c_blocks; i++) {
|
for (int i = 0; i < c_blocks; i++) {
|
||||||
Vmm vmm_max = get_acc_reg(i);
|
Vmm vmm_max = get_acc_reg(i);
|
||||||
|
|
||||||
@ -184,21 +185,21 @@ private:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
add(aux_reg_input1, jpp_.c_block * jpp_.src_data_size);
|
add(aux_reg_input1, jpp_.c_block * jpp_.src_prc.size());
|
||||||
|
|
||||||
inc(w_iter);
|
inc(w_iter);
|
||||||
cmp(w_iter, reg_kw);
|
cmp(w_iter, reg_kw);
|
||||||
jl(w_loop_label, T_NEAR);
|
jl(w_loop_label, T_NEAR);
|
||||||
}
|
}
|
||||||
|
|
||||||
add(aux_reg_input, jpp_.iw * jpp_.c_block * jpp_.src_data_size);
|
add(aux_reg_input, jpp_.iw * jpp_.c_block * jpp_.src_prc.size());
|
||||||
|
|
||||||
inc(h_iter);
|
inc(h_iter);
|
||||||
cmp(h_iter, reg_kh);
|
cmp(h_iter, reg_kh);
|
||||||
jl(h_loop_label, T_NEAR);
|
jl(h_loop_label, T_NEAR);
|
||||||
}
|
}
|
||||||
|
|
||||||
const int dst_c_off = jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_data_size;
|
const int dst_c_off = jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_prc.size();
|
||||||
for (int i = 0; i < c_blocks; i++) {
|
for (int i = 0; i < c_blocks; i++) {
|
||||||
Vmm vmm_dst = get_acc_reg(i);
|
Vmm vmm_dst = get_acc_reg(i);
|
||||||
|
|
||||||
@ -220,7 +221,7 @@ private:
|
|||||||
Vmm vmm_src11 = get_src_reg(3);
|
Vmm vmm_src11 = get_src_reg(3);
|
||||||
|
|
||||||
for (int i = 0; i < c_blocks; i++) {
|
for (int i = 0; i < c_blocks; i++) {
|
||||||
const int src_c_off = i * jpp_.ih * jpp_.iw * jpp_.c_block * jpp_.src_data_size;
|
const int src_c_off = i * jpp_.ih * jpp_.iw * jpp_.c_block * jpp_.src_prc.size();
|
||||||
const auto load_context = std::make_shared<load_emitter_context>(jpp_.src_prc, Precision::FP32, step, src_c_off);
|
const auto load_context = std::make_shared<load_emitter_context>(jpp_.src_prc, Precision::FP32, step, src_c_off);
|
||||||
|
|
||||||
mov(aux_reg_input, reg_input);
|
mov(aux_reg_input, reg_input);
|
||||||
@ -253,7 +254,7 @@ private:
|
|||||||
uni_vsubps(vmm_src11, vmm_src11, vmm_src01);
|
uni_vsubps(vmm_src11, vmm_src11, vmm_src01);
|
||||||
uni_vfmadd213ps(vmm_src11, vmm_yf, vmm_src01);
|
uni_vfmadd213ps(vmm_src11, vmm_yf, vmm_src01);
|
||||||
|
|
||||||
const int dst_c_off = i * jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_data_size;
|
const int dst_c_off = i * jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_prc.size();
|
||||||
|
|
||||||
store_emitter->emit_code({static_cast<size_t>(vmm_src11.getIdx())}, {static_cast<size_t>(reg_output.getIdx())},
|
store_emitter->emit_code({static_cast<size_t>(vmm_src11.getIdx())}, {static_cast<size_t>(reg_output.getIdx())},
|
||||||
std::make_shared<store_emitter_context>(Precision::FP32, jpp_.dst_prc, step, dst_c_off),
|
std::make_shared<store_emitter_context>(Precision::FP32, jpp_.dst_prc, step, dst_c_off),
|
||||||
@ -264,7 +265,7 @@ private:
|
|||||||
void empty_roi(int c_blocks) {
|
void empty_roi(int c_blocks) {
|
||||||
uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
|
uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
|
||||||
|
|
||||||
const int dst_c_off = jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_data_size;
|
const int dst_c_off = jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_prc.size();
|
||||||
for (int i = 0; i < c_blocks; i++) {
|
for (int i = 0; i < c_blocks; i++) {
|
||||||
store_emitter->emit_code({static_cast<size_t>(vmm_zero.getIdx())}, {static_cast<size_t>(reg_output.getIdx())},
|
store_emitter->emit_code({static_cast<size_t>(vmm_zero.getIdx())}, {static_cast<size_t>(reg_output.getIdx())},
|
||||||
std::make_shared<store_emitter_context>(jpp_.src_prc, jpp_.dst_prc, step, i * dst_c_off),
|
std::make_shared<store_emitter_context>(jpp_.src_prc, jpp_.dst_prc, step, i * dst_c_off),
|
||||||
@ -285,8 +286,8 @@ private:
|
|||||||
roi_pool_bilinear(c_blocks);
|
roi_pool_bilinear(c_blocks);
|
||||||
|
|
||||||
if (isa == cpu::x64::sse41) {
|
if (isa == cpu::x64::sse41) {
|
||||||
add(reg_input, 4 * jpp_.src_data_size);
|
add(reg_input, 4 * jpp_.src_prc.size());
|
||||||
add(reg_output, 4 * jpp_.dst_data_size);
|
add(reg_output, 4 * jpp_.dst_prc.size());
|
||||||
|
|
||||||
if (jpp_.alg == Algorithm::ROIPoolingMax)
|
if (jpp_.alg == Algorithm::ROIPoolingMax)
|
||||||
roi_pool_max(c_blocks);
|
roi_pool_max(c_blocks);
|
||||||
@ -298,7 +299,7 @@ private:
|
|||||||
L(empty_roi_label);
|
L(empty_roi_label);
|
||||||
empty_roi(c_blocks);
|
empty_roi(c_blocks);
|
||||||
if (isa == cpu::x64::sse41) {
|
if (isa == cpu::x64::sse41) {
|
||||||
add(reg_output, 4 * jpp_.dst_data_size);
|
add(reg_output, 4 * jpp_.dst_prc.size());
|
||||||
empty_roi(c_blocks);
|
empty_roi(c_blocks);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -306,6 +307,62 @@ private:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
struct RoiPoolingKey {
|
||||||
|
jit_roi_pooling_params refParams;
|
||||||
|
|
||||||
|
size_t hash() const;
|
||||||
|
bool operator==(const RoiPoolingKey& rhs) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
size_t RoiPoolingKey::hash() const {
|
||||||
|
using namespace dnnl::impl;
|
||||||
|
using namespace dnnl::impl::primitive_hashing;
|
||||||
|
|
||||||
|
size_t seed = 0;
|
||||||
|
|
||||||
|
seed = hash_combine(seed, refParams.mb);
|
||||||
|
seed = hash_combine(seed, refParams.c);
|
||||||
|
seed = hash_combine(seed, refParams.nb_c);
|
||||||
|
seed = hash_combine(seed, refParams.c_block);
|
||||||
|
seed = hash_combine(seed, refParams.nb_c_blocking);
|
||||||
|
seed = hash_combine(seed, refParams.ih);
|
||||||
|
seed = hash_combine(seed, refParams.iw);
|
||||||
|
seed = hash_combine(seed, refParams.oh);
|
||||||
|
seed = hash_combine(seed, refParams.ow);
|
||||||
|
seed = hash_combine(seed, refParams.alg);
|
||||||
|
seed = hash_combine(seed, refParams.src_prc.getPrecVal());
|
||||||
|
seed = hash_combine(seed, refParams.dst_prc.getPrecVal());
|
||||||
|
seed = hash_combine(seed, refParams.spatial_scale);
|
||||||
|
seed = hash_combine(seed, refParams.pooled_h);
|
||||||
|
seed = hash_combine(seed, refParams.pooled_w);
|
||||||
|
|
||||||
|
return seed;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool RoiPoolingKey::operator==(const RoiPoolingKey &rhs) const {
|
||||||
|
return refParams == rhs.refParams;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
/**
 * Field-wise equality of two ROI pooling parameter sets.
 *
 * Returns true only when every compared field matches. `spatial_scale` is
 * compared with exact `!=`, with no tolerance applied.
 */
bool MKLDNNPlugin::jit_roi_pooling_params::operator==(const MKLDNNPlugin::jit_roi_pooling_params &rhs) const noexcept {
    // Batch, channels and input/output spatial sizes.
    if (mb != rhs.mb || c != rhs.c ||
        ih != rhs.ih || iw != rhs.iw ||
        oh != rhs.oh || ow != rhs.ow) {
        return false;
    }

    // Channel blocking configuration.
    if (c_block != rhs.c_block || nb_c != rhs.nb_c || nb_c_blocking != rhs.nb_c_blocking) {
        return false;
    }

    // Pooling attributes.
    if (spatial_scale != rhs.spatial_scale || pooled_h != rhs.pooled_h || pooled_w != rhs.pooled_w) {
        return false;
    }

    // Source / destination precisions (negated `==` keeps the same comparison operator in use).
    if (!(src_prc == rhs.src_prc) || !(dst_prc == rhs.dst_prc)) {
        return false;
    }

    return alg == rhs.alg;
}
|
||||||
|
|
||||||
bool MKLDNNROIPoolingNode::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept {
|
bool MKLDNNROIPoolingNode::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept {
|
||||||
try {
|
try {
|
||||||
auto roiPooling = ngraph::as_type_ptr<const ngraph::opset2::ROIPooling>(op);
|
auto roiPooling = ngraph::as_type_ptr<const ngraph::opset2::ROIPooling>(op);
|
||||||
@ -383,8 +440,6 @@ void MKLDNNROIPoolingNode::initSupportedPrimitiveDescriptors() {
|
|||||||
refParams.src_prc = Precision::FP32;
|
refParams.src_prc = Precision::FP32;
|
||||||
}
|
}
|
||||||
|
|
||||||
src_data_size = dst_data_size = refParams.src_prc.size();
|
|
||||||
|
|
||||||
auto format = mayiuse(avx512_common) ? LayoutType::nCsp16c : LayoutType::nCsp8c;
|
auto format = mayiuse(avx512_common) ? LayoutType::nCsp16c : LayoutType::nCsp8c;
|
||||||
impl_desc_type impl_type;
|
impl_desc_type impl_type;
|
||||||
if (mayiuse(cpu::x64::avx512_common)) {
|
if (mayiuse(cpu::x64::avx512_common)) {
|
||||||
@ -415,8 +470,6 @@ void MKLDNNROIPoolingNode::createPrimitive() {
|
|||||||
const auto& config = selectedPD->getConfig();
|
const auto& config = selectedPD->getConfig();
|
||||||
refParams.src_prc = config.inConfs[0].desc->getPrecision();
|
refParams.src_prc = config.inConfs[0].desc->getPrecision();
|
||||||
refParams.dst_prc = config.outConfs[0].desc->getPrecision();
|
refParams.dst_prc = config.outConfs[0].desc->getPrecision();
|
||||||
refParams.src_data_size = refParams.src_prc.size();
|
|
||||||
refParams.dst_data_size = refParams.dst_prc.size();
|
|
||||||
|
|
||||||
if (inputShapesDefined()) {
|
if (inputShapesDefined()) {
|
||||||
if (needPrepareParams() && isExecutable())
|
if (needPrepareParams() && isExecutable())
|
||||||
@ -464,7 +517,13 @@ void MKLDNNROIPoolingNode::prepareParams() {
|
|||||||
refParams.oh = outDims[2];
|
refParams.oh = outDims[2];
|
||||||
refParams.ow = outDims[3];
|
refParams.ow = outDims[3];
|
||||||
|
|
||||||
execPtr = ROIPoolingExecutor::createROIPoolingNewExecutor(refParams);
|
RoiPoolingKey key = {refParams};
|
||||||
|
auto builder = [](const RoiPoolingKey& key) {
|
||||||
|
return ROIPoolingExecutor::createROIPoolingNewExecutor(key.refParams);
|
||||||
|
};
|
||||||
|
auto cache = getRuntimeCache();
|
||||||
|
auto result = cache->getOrCreate(key, builder);
|
||||||
|
execPtr = result.first;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
@ -25,10 +25,10 @@ struct jit_roi_pooling_params {
|
|||||||
|
|
||||||
InferenceEngine::Precision src_prc;
|
InferenceEngine::Precision src_prc;
|
||||||
InferenceEngine::Precision dst_prc;
|
InferenceEngine::Precision dst_prc;
|
||||||
int src_data_size = 0;
|
|
||||||
int dst_data_size = 0;
|
|
||||||
|
|
||||||
Algorithm alg;
|
Algorithm alg;
|
||||||
|
|
||||||
|
bool operator==(const jit_roi_pooling_params& rhs) const noexcept;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct jit_roi_pooling_call_args {
|
struct jit_roi_pooling_call_args {
|
||||||
@ -83,9 +83,6 @@ private:
|
|||||||
template<typename T> void execute();
|
template<typename T> void execute();
|
||||||
template<typename T> struct ROIPoolingExecute;
|
template<typename T> struct ROIPoolingExecute;
|
||||||
|
|
||||||
size_t src_data_size = 0;
|
|
||||||
size_t dst_data_size = 0;
|
|
||||||
|
|
||||||
jit_roi_pooling_params refParams = {};
|
jit_roi_pooling_params refParams = {};
|
||||||
|
|
||||||
std::string errorPrefix;
|
std::string errorPrefix;
|
||||||
|
@ -259,7 +259,7 @@ const std::vector<roiPoolingShapes> inShapes = {
|
|||||||
{-1, -1, -1, -1},
|
{-1, -1, -1, -1},
|
||||||
// static
|
// static
|
||||||
{
|
{
|
||||||
{3, 4, 50, 50}, {3, 4, 50, 50}, {3, 4, 50, 50}, {1, 3, 8, 8}, {1, 3, 8, 8}, {1, 3, 8, 8}
|
{3, 4, 50, 50}, {3, 4, 50, 50}, {3, 4, 50, 50}, {1, 3, 8, 8}, {1, 3, 8, 8}, {3, 4, 50, 50}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
// input 1
|
// input 1
|
||||||
@ -279,7 +279,7 @@ const std::vector<roiPoolingShapes> inShapes = {
|
|||||||
{-1, {3, 5}, {7, 60}, -1},
|
{-1, {3, 5}, {7, 60}, -1},
|
||||||
// static
|
// static
|
||||||
{
|
{
|
||||||
{3, 4, 50, 50}, {1, 3, 7, 8}, {1, 5, 59, 8}, {3, 5, 60, 8},
|
{3, 4, 50, 50}, {1, 3, 7, 8}, {3, 4, 50, 50}, {1, 3, 7, 8},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
// input 1
|
// input 1
|
||||||
@ -288,7 +288,7 @@ const std::vector<roiPoolingShapes> inShapes = {
|
|||||||
{{1, 5}, 5},
|
{{1, 5}, 5},
|
||||||
// static
|
// static
|
||||||
{
|
{
|
||||||
{1, 5}, {3, 5}, {4, 5}, {5, 5}
|
{1, 5}, {2, 5}, {1, 5}, {2, 5}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -299,7 +299,7 @@ const std::vector<roiPoolingShapes> inShapes = {
|
|||||||
{{1, 8}, {3, 5}, {7, 60}, {5, 50}},
|
{{1, 8}, {3, 5}, {7, 60}, {5, 50}},
|
||||||
// static
|
// static
|
||||||
{
|
{
|
||||||
{3, 4, 50, 50}, {1, 3, 7, 8}, {8, 5, 59, 5}, {3, 5, 60, 8},
|
{3, 4, 50, 50}, {1, 3, 7, 8}, {8, 5, 59, 5}, {1, 3, 7, 8},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
// input 1
|
// input 1
|
||||||
@ -308,7 +308,7 @@ const std::vector<roiPoolingShapes> inShapes = {
|
|||||||
{{1, 5}, 5},
|
{{1, 5}, 5},
|
||||||
// static
|
// static
|
||||||
{
|
{
|
||||||
{1, 5}, {2, 5}, {4, 5}, {5, 5}
|
{1, 5}, {2, 5}, {1, 5}, {2, 5}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
Loading…
Reference in New Issue
Block a user