diff --git a/src/plugins/intel_cpu/src/nodes/mkldnn_roi_pooling_node.cpp b/src/plugins/intel_cpu/src/nodes/mkldnn_roi_pooling_node.cpp index 572247387e3..d6fa2945438 100644 --- a/src/plugins/intel_cpu/src/nodes/mkldnn_roi_pooling_node.cpp +++ b/src/plugins/intel_cpu/src/nodes/mkldnn_roi_pooling_node.cpp @@ -15,6 +15,7 @@ #include "emitters/jit_load_store_emitters.hpp" #include +#include #include #include @@ -149,7 +150,7 @@ private: mov(aux_reg_input, reg_input); - const int src_c_off = jpp_.ih * jpp_.iw * jpp_.c_block * jpp_.src_data_size; + const int src_c_off = jpp_.ih * jpp_.iw * jpp_.c_block * jpp_.src_prc.size(); for (int i = 0; i < c_blocks; i++) { Vmm vmm_max = get_acc_reg(i); @@ -184,21 +185,21 @@ private: } } - add(aux_reg_input1, jpp_.c_block * jpp_.src_data_size); + add(aux_reg_input1, jpp_.c_block * jpp_.src_prc.size()); inc(w_iter); cmp(w_iter, reg_kw); jl(w_loop_label, T_NEAR); } - add(aux_reg_input, jpp_.iw * jpp_.c_block * jpp_.src_data_size); + add(aux_reg_input, jpp_.iw * jpp_.c_block * jpp_.src_prc.size()); inc(h_iter); cmp(h_iter, reg_kh); jl(h_loop_label, T_NEAR); } - const int dst_c_off = jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_data_size; + const int dst_c_off = jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_prc.size(); for (int i = 0; i < c_blocks; i++) { Vmm vmm_dst = get_acc_reg(i); @@ -220,7 +221,7 @@ private: Vmm vmm_src11 = get_src_reg(3); for (int i = 0; i < c_blocks; i++) { - const int src_c_off = i * jpp_.ih * jpp_.iw * jpp_.c_block * jpp_.src_data_size; + const int src_c_off = i * jpp_.ih * jpp_.iw * jpp_.c_block * jpp_.src_prc.size(); const auto load_context = std::make_shared(jpp_.src_prc, Precision::FP32, step, src_c_off); mov(aux_reg_input, reg_input); @@ -253,7 +254,7 @@ private: uni_vsubps(vmm_src11, vmm_src11, vmm_src01); uni_vfmadd213ps(vmm_src11, vmm_yf, vmm_src01); - const int dst_c_off = i * jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_data_size; + const int dst_c_off = i * jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_prc.size(); store_emitter->emit_code({static_cast(vmm_src11.getIdx())}, {static_cast(reg_output.getIdx())}, std::make_shared(Precision::FP32, jpp_.dst_prc, step, dst_c_off), @@ -264,7 +265,7 @@ private: void empty_roi(int c_blocks) { uni_vpxor(vmm_zero, vmm_zero, vmm_zero); - const int dst_c_off = jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_data_size; + const int dst_c_off = jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_prc.size(); for (int i = 0; i < c_blocks; i++) { store_emitter->emit_code({static_cast(vmm_zero.getIdx())}, {static_cast(reg_output.getIdx())}, std::make_shared(jpp_.src_prc, jpp_.dst_prc, step, i * dst_c_off), @@ -285,8 +286,8 @@ private: roi_pool_bilinear(c_blocks); if (isa == cpu::x64::sse41) { - add(reg_input, 4 * jpp_.src_data_size); - add(reg_output, 4 * jpp_.dst_data_size); + add(reg_input, 4 * jpp_.src_prc.size()); + add(reg_output, 4 * jpp_.dst_prc.size()); if (jpp_.alg == Algorithm::ROIPoolingMax) roi_pool_max(c_blocks); @@ -298,7 +299,7 @@ private: L(empty_roi_label); empty_roi(c_blocks); if (isa == cpu::x64::sse41) { - add(reg_output, 4 * jpp_.dst_data_size); + add(reg_output, 4 * jpp_.dst_prc.size()); empty_roi(c_blocks); } @@ -306,6 +307,62 @@ private: } }; +namespace { +struct RoiPoolingKey { + jit_roi_pooling_params refParams; + + size_t hash() const; + bool operator==(const RoiPoolingKey& rhs) const; +}; + +size_t RoiPoolingKey::hash() const { + using namespace dnnl::impl; + using namespace dnnl::impl::primitive_hashing; + + size_t seed = 0; + + seed = hash_combine(seed, refParams.mb); + seed = hash_combine(seed, refParams.c); + seed = hash_combine(seed, refParams.nb_c); + seed = hash_combine(seed, refParams.c_block); + seed = hash_combine(seed, refParams.nb_c_blocking); + seed = hash_combine(seed, refParams.ih); + seed = hash_combine(seed, refParams.iw); + seed = hash_combine(seed, refParams.oh); + seed = hash_combine(seed, refParams.ow); + seed = hash_combine(seed, refParams.alg); + seed = hash_combine(seed, refParams.src_prc.getPrecVal()); + seed = hash_combine(seed, refParams.dst_prc.getPrecVal()); + seed = hash_combine(seed, refParams.spatial_scale); + seed = hash_combine(seed, refParams.pooled_h); + seed = hash_combine(seed, refParams.pooled_w); + + return seed; +} + +bool RoiPoolingKey::operator==(const RoiPoolingKey &rhs) const { + return refParams == rhs.refParams; +} +} // namespace + +bool MKLDNNPlugin::jit_roi_pooling_params::operator==(const MKLDNNPlugin::jit_roi_pooling_params &rhs) const noexcept { + return mb == rhs.mb && + c == rhs.c && + ih == rhs.ih && + iw == rhs.iw && + oh == rhs.oh && + ow == rhs.ow && + c_block == rhs.c_block && + nb_c == rhs.nb_c && + nb_c_blocking == rhs.nb_c_blocking && + spatial_scale == rhs.spatial_scale && + pooled_h == rhs.pooled_h && + pooled_w == rhs.pooled_w && + src_prc == rhs.src_prc && + dst_prc == rhs.dst_prc && + alg == rhs.alg; +} + bool MKLDNNROIPoolingNode::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { try { auto roiPooling = ngraph::as_type_ptr(op); @@ -383,8 +440,6 @@ void MKLDNNROIPoolingNode::initSupportedPrimitiveDescriptors() { refParams.src_prc = Precision::FP32; } - src_data_size = dst_data_size = refParams.src_prc.size(); - auto format = mayiuse(avx512_common) ? LayoutType::nCsp16c : LayoutType::nCsp8c; impl_desc_type impl_type; if (mayiuse(cpu::x64::avx512_common)) { @@ -415,8 +470,6 @@ void MKLDNNROIPoolingNode::createPrimitive() { const auto& config = selectedPD->getConfig(); refParams.src_prc = config.inConfs[0].desc->getPrecision(); refParams.dst_prc = config.outConfs[0].desc->getPrecision(); - refParams.src_data_size = refParams.src_prc.size(); - refParams.dst_data_size = refParams.dst_prc.size(); if (inputShapesDefined()) { if (needPrepareParams() && isExecutable()) @@ -464,7 +517,13 @@ void MKLDNNROIPoolingNode::prepareParams() { refParams.oh = outDims[2]; refParams.ow = outDims[3]; - execPtr = ROIPoolingExecutor::createROIPoolingNewExecutor(refParams); + RoiPoolingKey key = {refParams}; + auto builder = [](const RoiPoolingKey& key) { + return ROIPoolingExecutor::createROIPoolingNewExecutor(key.refParams); + }; + auto cache = getRuntimeCache(); + auto result = cache->getOrCreate(key, builder); + execPtr = result.first; } template diff --git a/src/plugins/intel_cpu/src/nodes/mkldnn_roi_pooling_node.h b/src/plugins/intel_cpu/src/nodes/mkldnn_roi_pooling_node.h index 3e8a74d72b2..7c3db364fd7 100644 --- a/src/plugins/intel_cpu/src/nodes/mkldnn_roi_pooling_node.h +++ b/src/plugins/intel_cpu/src/nodes/mkldnn_roi_pooling_node.h @@ -25,10 +25,10 @@ struct jit_roi_pooling_params { InferenceEngine::Precision src_prc; InferenceEngine::Precision dst_prc; - int src_data_size = 0; - int dst_data_size = 0; Algorithm alg; + + bool operator==(const jit_roi_pooling_params& rhs) const noexcept; }; struct jit_roi_pooling_call_args { @@ -83,9 +83,6 @@ private: template void execute(); template struct ROIPoolingExecute; - size_t src_data_size = 0; - size_t dst_data_size = 0; - jit_roi_pooling_params refParams = {}; std::string errorPrefix; diff --git a/src/tests/functional/plugin/cpu/single_layer_tests/roi_pooling.cpp b/src/tests/functional/plugin/cpu/single_layer_tests/roi_pooling.cpp index 5b3955f12e3..1aaf7ab7e29 100644 --- a/src/tests/functional/plugin/cpu/single_layer_tests/roi_pooling.cpp +++ b/src/tests/functional/plugin/cpu/single_layer_tests/roi_pooling.cpp @@ -259,7 +259,7 @@ const std::vector inShapes = { {-1, -1, -1, -1}, // static { - {3, 4, 50, 50}, {3, 4, 50, 50}, {3, 4, 50, 50}, {1, 3, 8, 8}, {1, 3, 8, 8}, {1, 3, 8, 8} + {3, 4, 50, 50}, {3, 4, 50, 50}, {3, 4, 50, 50}, {1, 3, 8, 8}, {1, 3, 8, 8}, {3, 4, 50, 50} } }, // input 1 @@ -279,7 +279,7 @@ const std::vector inShapes = { {-1, {3, 5}, {7, 60}, -1}, // static { - {3, 4, 50, 50}, {1, 3, 7, 8}, {1, 5, 59, 8}, {3, 5, 60, 8}, + {3, 4, 50, 50}, {1, 3, 7, 8}, {3, 4, 50, 50}, {1, 3, 7, 8}, } }, // input 1 @@ -288,7 +288,7 @@ const std::vector inShapes = { {{1, 5}, 5}, // static { - {1, 5}, {3, 5}, {4, 5}, {5, 5} + {1, 5}, {2, 5}, {1, 5}, {2, 5} } }, }, @@ -299,7 +299,7 @@ const std::vector inShapes = { {{1, 8}, {3, 5}, {7, 60}, {5, 50}}, // static { - {3, 4, 50, 50}, {1, 3, 7, 8}, {8, 5, 59, 5}, {3, 5, 60, 8}, + {3, 4, 50, 50}, {1, 3, 7, 8}, {8, 5, 59, 5}, {1, 3, 7, 8}, } }, // input 1 @@ -308,7 +308,7 @@ const std::vector inShapes = { {{1, 5}, 5}, // static { - {1, 5}, {2, 5}, {4, 5}, {5, 5} + {1, 5}, {2, 5}, {1, 5}, {2, 5} } }, },