[CPU] Refactor load_emitter_context semantic. Update store emitter (#5824)

* [CPU] Refactor load_emitter_context semantic. Update store emitter

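The load_emitter_context constructor now takes its arguments in a more
convenient order: offset_byte comes before is_fill and fill_value.
jit_store_emitter::emit_data() now also emits the data of the bf16 emulation emitter.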
This commit is contained in:
Egor Duplensky 2021-06-16 14:42:26 +03:00 committed by GitHub
parent b05977a536
commit 5b847fabe8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 16 additions and 11 deletions

View File

@ -510,6 +510,11 @@ size_t jit_store_emitter::aux_vecs_count() const {
// Reports the number of inputs of the store emitter: always 1.
size_t jit_store_emitter::get_inputs_num() const { return 1; }
// Emits the data owned by the store emitter. The only data emitted here
// belongs to the optional emu_vcvtneps2bf16 helper emitter; when that helper
// was not created (null pointer) this function does nothing.
void jit_store_emitter::emit_data() const {
if (emu_vcvtneps2bf16)
emu_vcvtneps2bf16->emit_data();
}
void jit_store_emitter::emit_impl(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs,
const emitter_context *emit_context) const {

View File

@ -18,8 +18,8 @@ struct load_emitter_context : public emitter_context {
// Default context: FP32 source and destination precision, load_num_ of 8,
// zero byte offset, filling disabled, fill value "zero".
load_emitter_context() : src_prc_(Precision::FP32), dst_prc_(Precision::FP32), load_num_(8),
offset_byte_(0), is_fill_(false), fill_value_("zero") {}
load_emitter_context(Precision src_prc, Precision dst_prc, int load_num, bool is_fill = false, std::string fill_value = "zero", int offset_byte = 0):
src_prc_(src_prc), dst_prc_(dst_prc), load_num_(load_num), is_fill_(is_fill), fill_value_(fill_value), offset_byte_(offset_byte) {}
// Builds a context from explicit source/destination precisions and load_num.
// offset_byte, is_fill and fill_value are optional and default to 0, false and "zero".
// NOTE(review): load_num is presumably the number of elements to load and
// offset_byte a byte offset applied to the source address - confirm against the load emitter.
load_emitter_context(Precision src_prc, Precision dst_prc, int load_num, int offset_byte = 0, bool is_fill = false, std::string fill_value = "zero"):
src_prc_(src_prc), dst_prc_(dst_prc), load_num_(load_num), offset_byte_(offset_byte), is_fill_(is_fill), fill_value_(fill_value) {}
int offset_byte_;
int load_num_;
@ -124,6 +124,8 @@ public:
size_t get_inputs_num() const override;
void emit_data() const override;
// Returns the emu_vcvtneps2bf16 helper emitter held by this object.
// The returned pointer may be null; check it before dereferencing.
std::shared_ptr<jit_emu_vcvtneps2bf16> get_emu_vcvtneps2bf16() const {
return emu_vcvtneps2bf16;
}

View File

@ -306,8 +306,8 @@ private:
inline void worker_tail_planar() {
Precision dst_prc = isFloatCompatible(jcp_.src_prc) ? Precision::FP32 : Precision::I32;
load_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_val.getIdx())},
std::make_shared<load_emitter_context>(jcp_.src_prc, dst_prc, tail_num, true, "zero"),
{}, {load_pool_gpr_idxs});
std::make_shared<load_emitter_context>(jcp_.src_prc, dst_prc, tail_num, 0, true),
{}, {load_pool_gpr_idxs});
if (jcp_.normalize_variance) {
if (!isFloatCompatible(jcp_.src_prc))
@ -477,8 +477,7 @@ struct jit_uni_mvn_kernel_f32 : public jit_uni_mvn_kernel, public jit_generator
this->postamble();
load_emitter->emit_data();
if (!mayiuse(avx512_core_bf16) && mayiuse(avx512_core) && store_emitter != nullptr && store_emitter->get_emu_vcvtneps2bf16() != nullptr)
store_emitter->get_emu_vcvtneps2bf16()->emit_data();
store_emitter->emit_data();
for (auto& inj : eltwise_injectors)
inj->prepare_table();

View File

@ -88,8 +88,7 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi
this->postamble();
load_emitter->emit_data();
if (!mayiuse(avx512_core_bf16) && mayiuse(avx512_core) && store_emitter != nullptr && store_emitter->get_emu_vcvtneps2bf16() != nullptr)
store_emitter->get_emu_vcvtneps2bf16()->emit_data();
store_emitter->emit_data();
}
private:
@ -155,7 +154,7 @@ private:
Vmm vmm_max = get_acc_reg(i);
load_emitter->emit_code({static_cast<size_t>(reg_input.getIdx())}, {static_cast<size_t>(vmm_max.getIdx())},
std::make_shared<load_emitter_context>(jpp_.src_prc, Precision::FP32, step, false, "zero", i * src_c_off),
std::make_shared<load_emitter_context>(jpp_.src_prc, Precision::FP32, step, i * src_c_off),
{}, load_pool_gpr_idxs);
}
@ -169,7 +168,7 @@ private:
Vmm vmm_src = get_src_reg(i);
load_emitter->emit_code({static_cast<size_t>(aux_reg_input1.getIdx())}, {static_cast<size_t>(vmm_src.getIdx())},
std::make_shared<load_emitter_context>(jpp_.src_prc, Precision::FP32, step, false, "zero", i * src_c_off),
std::make_shared<load_emitter_context>(jpp_.src_prc, Precision::FP32, step, i * src_c_off),
{}, load_pool_gpr_idxs);
if (isa == cpu::x64::sse41) {
@ -222,7 +221,7 @@ private:
for (int i = 0; i < c_blocks; i++) {
const int src_c_off = i * jpp_.ih * jpp_.iw * jpp_.c_block * jpp_.src_data_size;
const auto load_context = std::make_shared<load_emitter_context>(jpp_.src_prc, Precision::FP32, step, false, "zero", src_c_off);
const auto load_context = std::make_shared<load_emitter_context>(jpp_.src_prc, Precision::FP32, step, src_c_off);
mov(aux_reg_input, reg_input);