From 5b847fabe842a9124efbfccb39655e73f4dac19b Mon Sep 17 00:00:00 2001
From: Egor Duplensky
Date: Wed, 16 Jun 2021 14:42:26 +0300
Subject: [PATCH] [CPU] Refactor load_emitter_context semantics. Update store
 emitter (#5824)

* [CPU] Refactor load_emitter_context semantics. Update store emitter

The new load_emitter_context constructor argument order is more
convenient: offset_byte now comes right after load_num, so call sites
that only need an offset can rely on the fill defaults. The store
emitter now emits its bf16 emulation data itself.
---
 .../mkldnn_plugin/emitters/jit_load_store_emitters.cpp  | 5 +++++
 .../mkldnn_plugin/emitters/jit_load_store_emitters.hpp  | 6 ++++--
 .../src/mkldnn_plugin/nodes/mkldnn_mvn_node.cpp         | 7 +++----
 .../src/mkldnn_plugin/nodes/mkldnn_roi_pooling_node.cpp | 9 ++++-----
 4 files changed, 16 insertions(+), 11 deletions(-)

diff --git a/inference-engine/src/mkldnn_plugin/emitters/jit_load_store_emitters.cpp b/inference-engine/src/mkldnn_plugin/emitters/jit_load_store_emitters.cpp
index 83bc04c530d..4d1e3819394 100644
--- a/inference-engine/src/mkldnn_plugin/emitters/jit_load_store_emitters.cpp
+++ b/inference-engine/src/mkldnn_plugin/emitters/jit_load_store_emitters.cpp
@@ -510,6 +510,11 @@ size_t jit_store_emitter::aux_vecs_count() const {
 
 size_t jit_store_emitter::get_inputs_num() const { return 1; }
 
+void jit_store_emitter::emit_data() const {
+    if (emu_vcvtneps2bf16)
+        emu_vcvtneps2bf16->emit_data();
+}
+
 void jit_store_emitter::emit_impl(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
                                   const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs,
                                   const emitter_context *emit_context) const {
diff --git a/inference-engine/src/mkldnn_plugin/emitters/jit_load_store_emitters.hpp b/inference-engine/src/mkldnn_plugin/emitters/jit_load_store_emitters.hpp
index 00c2e49262d..ec863d0c69e 100644
--- a/inference-engine/src/mkldnn_plugin/emitters/jit_load_store_emitters.hpp
+++ b/inference-engine/src/mkldnn_plugin/emitters/jit_load_store_emitters.hpp
@@ -18,8 +18,8 @@ struct load_emitter_context : public emitter_context {
     load_emitter_context() : src_prc_(Precision::FP32), dst_prc_(Precision::FP32),
     load_num_(8), offset_byte_(0), is_fill_(false), fill_value_("zero") {}
 
-    load_emitter_context(Precision src_prc, Precision dst_prc, int load_num, bool is_fill = false, std::string fill_value = "zero", int offset_byte = 0):
-        src_prc_(src_prc), dst_prc_(dst_prc), load_num_(load_num), is_fill_(is_fill), fill_value_(fill_value), offset_byte_(offset_byte) {}
+    load_emitter_context(Precision src_prc, Precision dst_prc, int load_num, int offset_byte = 0, bool is_fill = false, std::string fill_value = "zero"):
+        src_prc_(src_prc), dst_prc_(dst_prc), load_num_(load_num), offset_byte_(offset_byte), is_fill_(is_fill), fill_value_(fill_value) {}
 
     int offset_byte_;
     int load_num_;
@@ -124,6 +124,8 @@ public:
 
     size_t get_inputs_num() const override;
 
+    void emit_data() const override;
+
     std::shared_ptr<jit_emu_vcvtneps2bf16> get_emu_vcvtneps2bf16() const {
         return emu_vcvtneps2bf16;
     }
diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_mvn_node.cpp b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_mvn_node.cpp
index f27a40e3bd2..baff79e5d75 100644
--- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_mvn_node.cpp
+++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_mvn_node.cpp
@@ -306,8 +306,8 @@ private:
     inline void worker_tail_planar() {
         Precision dst_prc = isFloatCompatible(jcp_.src_prc) ? Precision::FP32 : Precision::I32;
         load_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_val.getIdx())},
-                                std::make_shared<load_emitter_context>(jcp_.src_prc, dst_prc, tail_num, true, "zero"),
-                                {}, {load_pool_gpr_idxs});
+                                std::make_shared<load_emitter_context>(jcp_.src_prc, dst_prc, tail_num, 0, true),
+                                {}, {load_pool_gpr_idxs});
 
         if (jcp_.normalize_variance) {
             if (!isFloatCompatible(jcp_.src_prc))
@@ -477,8 +477,7 @@ struct jit_uni_mvn_kernel_f32 : public jit_uni_mvn_kernel, public jit_generator
         this->postamble();
 
         load_emitter->emit_data();
-        if (!mayiuse(avx512_core_bf16) && mayiuse(avx512_core) && store_emitter != nullptr && store_emitter->get_emu_vcvtneps2bf16() != nullptr)
-            store_emitter->get_emu_vcvtneps2bf16()->emit_data();
+        store_emitter->emit_data();
 
         for (auto& inj : eltwise_injectors)
             inj->prepare_table();
diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_roi_pooling_node.cpp b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_roi_pooling_node.cpp
index 77db7621692..a1a7f8329a5 100644
--- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_roi_pooling_node.cpp
+++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_roi_pooling_node.cpp
@@ -88,8 +88,7 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi
         this->postamble();
 
         load_emitter->emit_data();
-        if (!mayiuse(avx512_core_bf16) && mayiuse(avx512_core) && store_emitter != nullptr && store_emitter->get_emu_vcvtneps2bf16() != nullptr)
-            store_emitter->get_emu_vcvtneps2bf16()->emit_data();
+        store_emitter->emit_data();
     }
 
 private:
@@ -155,7 +154,7 @@ private:
             Vmm vmm_max = get_acc_reg(i);
 
             load_emitter->emit_code({static_cast<size_t>(reg_input.getIdx())}, {static_cast<size_t>(vmm_max.getIdx())},
-                                    std::make_shared<load_emitter_context>(jpp_.src_prc, Precision::FP32, step, false, "zero", i * src_c_off),
+                                    std::make_shared<load_emitter_context>(jpp_.src_prc, Precision::FP32, step, i * src_c_off),
                                     {}, load_pool_gpr_idxs);
         }
 
@@ -169,7 +168,7 @@ private:
             Vmm vmm_src = get_src_reg(i);
 
             load_emitter->emit_code({static_cast<size_t>(aux_reg_input1.getIdx())}, {static_cast<size_t>(vmm_src.getIdx())},
-                                    std::make_shared<load_emitter_context>(jpp_.src_prc, Precision::FP32, step, false, "zero", i * src_c_off),
+                                    std::make_shared<load_emitter_context>(jpp_.src_prc, Precision::FP32, step, i * src_c_off),
                                     {}, load_pool_gpr_idxs);
 
             if (isa == cpu::x64::sse41) {
@@ -222,7 +221,7 @@ private:
         for (int i = 0; i < c_blocks; i++) {
            const int src_c_off = i * jpp_.ih * jpp_.iw * jpp_.c_block * jpp_.src_data_size;
 
-            const auto load_context = std::make_shared<load_emitter_context>(jpp_.src_prc, Precision::FP32, step, false, "zero", src_c_off);
+            const auto load_context = std::make_shared<load_emitter_context>(jpp_.src_prc, Precision::FP32, step, src_c_off);
 
             mov(aux_reg_input, reg_input);
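
Note for reviewers: a minimal before/after sketch of a call site under the
new signatures above. The names reg_src, vmm_dst and byte_offset are
illustrative placeholders, not identifiers taken from this patch:

    // Old order: (src_prc, dst_prc, load_num, is_fill, fill_value, offset_byte)
    // An offset-only load had to spell out the fill defaults first:
    load_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())},
                            {static_cast<size_t>(vmm_dst.getIdx())},
                            std::make_shared<load_emitter_context>(
                                    jpp_.src_prc, Precision::FP32, step, false, "zero", byte_offset),
                            {}, load_pool_gpr_idxs);

    // New order: (src_prc, dst_prc, load_num, offset_byte, is_fill, fill_value)
    // The same load now passes just the offset:
    load_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())},
                            {static_cast<size_t>(vmm_dst.getIdx())},
                            std::make_shared<load_emitter_context>(
                                    jpp_.src_prc, Precision::FP32, step, byte_offset),
                            {}, load_pool_gpr_idxs);

    // Kernel epilogue: the avx512_core_bf16/avx512_core checks and the
    // emulator null check now live inside jit_store_emitter::emit_data(),
    // so generators call both emitters symmetrically:
    load_emitter->emit_data();
    store_emitter->emit_data();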