[CPU] Refactor load_emitter_context semantic. Update store emitter (#5824)
* [CPU] Refactor load_emitter_context semantic. Update store emitter. The new order of the load_emitter_context constructor arguments seems to be more convenient. The store emitter now emits the bf16 emulation (emu_vcvtneps2bf16) data itself.
This commit is contained in:
parent
b05977a536
commit
5b847fabe8
@ -510,6 +510,11 @@ size_t jit_store_emitter::aux_vecs_count() const {
|
||||
|
||||
size_t jit_store_emitter::get_inputs_num() const { return 1; }
|
||||
|
||||
void jit_store_emitter::emit_data() const {
|
||||
if (emu_vcvtneps2bf16)
|
||||
emu_vcvtneps2bf16->emit_data();
|
||||
}
|
||||
|
||||
void jit_store_emitter::emit_impl(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
|
||||
const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs,
|
||||
const emitter_context *emit_context) const {
|
||||
|
@ -18,8 +18,8 @@ struct load_emitter_context : public emitter_context {
|
||||
load_emitter_context() : src_prc_(Precision::FP32), dst_prc_(Precision::FP32), load_num_(8),
|
||||
offset_byte_(0), is_fill_(false), fill_value_("zero") {}
|
||||
|
||||
load_emitter_context(Precision src_prc, Precision dst_prc, int load_num, bool is_fill = false, std::string fill_value = "zero", int offset_byte = 0):
|
||||
src_prc_(src_prc), dst_prc_(dst_prc), load_num_(load_num), is_fill_(is_fill), fill_value_(fill_value), offset_byte_(offset_byte) {}
|
||||
load_emitter_context(Precision src_prc, Precision dst_prc, int load_num, int offset_byte = 0, bool is_fill = false, std::string fill_value = "zero"):
|
||||
src_prc_(src_prc), dst_prc_(dst_prc), load_num_(load_num), offset_byte_(offset_byte), is_fill_(is_fill), fill_value_(fill_value) {}
|
||||
|
||||
int offset_byte_;
|
||||
int load_num_;
|
||||
@ -124,6 +124,8 @@ public:
|
||||
|
||||
size_t get_inputs_num() const override;
|
||||
|
||||
void emit_data() const override;
|
||||
|
||||
std::shared_ptr<jit_emu_vcvtneps2bf16> get_emu_vcvtneps2bf16() const {
|
||||
return emu_vcvtneps2bf16;
|
||||
}
|
||||
|
@ -306,7 +306,7 @@ private:
|
||||
inline void worker_tail_planar() {
|
||||
Precision dst_prc = isFloatCompatible(jcp_.src_prc) ? Precision::FP32 : Precision::I32;
|
||||
load_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_val.getIdx())},
|
||||
std::make_shared<load_emitter_context>(jcp_.src_prc, dst_prc, tail_num, true, "zero"),
|
||||
std::make_shared<load_emitter_context>(jcp_.src_prc, dst_prc, tail_num, 0, true),
|
||||
{}, {load_pool_gpr_idxs});
|
||||
|
||||
if (jcp_.normalize_variance) {
|
||||
@ -477,8 +477,7 @@ struct jit_uni_mvn_kernel_f32 : public jit_uni_mvn_kernel, public jit_generator
|
||||
this->postamble();
|
||||
|
||||
load_emitter->emit_data();
|
||||
if (!mayiuse(avx512_core_bf16) && mayiuse(avx512_core) && store_emitter != nullptr && store_emitter->get_emu_vcvtneps2bf16() != nullptr)
|
||||
store_emitter->get_emu_vcvtneps2bf16()->emit_data();
|
||||
store_emitter->emit_data();
|
||||
|
||||
for (auto& inj : eltwise_injectors)
|
||||
inj->prepare_table();
|
||||
|
@ -88,8 +88,7 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi
|
||||
this->postamble();
|
||||
|
||||
load_emitter->emit_data();
|
||||
if (!mayiuse(avx512_core_bf16) && mayiuse(avx512_core) && store_emitter != nullptr && store_emitter->get_emu_vcvtneps2bf16() != nullptr)
|
||||
store_emitter->get_emu_vcvtneps2bf16()->emit_data();
|
||||
store_emitter->emit_data();
|
||||
}
|
||||
|
||||
private:
|
||||
@ -155,7 +154,7 @@ private:
|
||||
Vmm vmm_max = get_acc_reg(i);
|
||||
|
||||
load_emitter->emit_code({static_cast<size_t>(reg_input.getIdx())}, {static_cast<size_t>(vmm_max.getIdx())},
|
||||
std::make_shared<load_emitter_context>(jpp_.src_prc, Precision::FP32, step, false, "zero", i * src_c_off),
|
||||
std::make_shared<load_emitter_context>(jpp_.src_prc, Precision::FP32, step, i * src_c_off),
|
||||
{}, load_pool_gpr_idxs);
|
||||
}
|
||||
|
||||
@ -169,7 +168,7 @@ private:
|
||||
Vmm vmm_src = get_src_reg(i);
|
||||
|
||||
load_emitter->emit_code({static_cast<size_t>(aux_reg_input1.getIdx())}, {static_cast<size_t>(vmm_src.getIdx())},
|
||||
std::make_shared<load_emitter_context>(jpp_.src_prc, Precision::FP32, step, false, "zero", i * src_c_off),
|
||||
std::make_shared<load_emitter_context>(jpp_.src_prc, Precision::FP32, step, i * src_c_off),
|
||||
{}, load_pool_gpr_idxs);
|
||||
|
||||
if (isa == cpu::x64::sse41) {
|
||||
@ -222,7 +221,7 @@ private:
|
||||
|
||||
for (int i = 0; i < c_blocks; i++) {
|
||||
const int src_c_off = i * jpp_.ih * jpp_.iw * jpp_.c_block * jpp_.src_data_size;
|
||||
const auto load_context = std::make_shared<load_emitter_context>(jpp_.src_prc, Precision::FP32, step, false, "zero", src_c_off);
|
||||
const auto load_context = std::make_shared<load_emitter_context>(jpp_.src_prc, Precision::FP32, step, src_c_off);
|
||||
|
||||
mov(aux_reg_input, reg_input);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user