[CPU] Removed Contexts from load and store emitters (#12446)
This commit is contained in:
parent
97f3d84cf5
commit
6d6f52806b
@ -23,6 +23,11 @@ enum emitter_in_out_map {
|
|||||||
gpr_to_gpr,
|
gpr_to_gpr,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// structure for storage of emitter parameters to hash in map
|
||||||
|
struct emitter_params {
|
||||||
|
virtual size_t hash() const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
struct emitter_context {
|
struct emitter_context {
|
||||||
virtual ~emitter_context() = default;
|
virtual ~emitter_context() = default;
|
||||||
};
|
};
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -15,40 +15,37 @@ using namespace InferenceEngine;
|
|||||||
namespace ov {
|
namespace ov {
|
||||||
namespace intel_cpu {
|
namespace intel_cpu {
|
||||||
|
|
||||||
struct load_emitter_context : public emitter_context {
|
struct load_emitter_params : public emitter_params {
|
||||||
load_emitter_context() : src_prc_(Precision::FP32), dst_prc_(Precision::FP32), load_num_(8),
|
load_emitter_params(Precision src_prc, Precision dst_prc, int load_num, bool is_fill = false, std::string fill_value = "zero"):
|
||||||
offset_byte_(0), is_fill_(false), fill_value_("zero") {}
|
src_prc_(src_prc), dst_prc_(dst_prc), load_num_(load_num), is_fill_(is_fill), fill_value_(fill_value) {}
|
||||||
|
|
||||||
load_emitter_context(Precision src_prc, Precision dst_prc, int load_num, int offset_byte = 0, bool is_fill = false, std::string fill_value = "zero"):
|
size_t hash() const override;
|
||||||
src_prc_(src_prc), dst_prc_(dst_prc), load_num_(load_num), offset_byte_(offset_byte), is_fill_(is_fill), fill_value_(fill_value) {}
|
|
||||||
|
|
||||||
int offset_byte_;
|
|
||||||
int load_num_;
|
|
||||||
Precision src_prc_;
|
Precision src_prc_;
|
||||||
Precision dst_prc_;
|
Precision dst_prc_;
|
||||||
|
int load_num_;
|
||||||
bool is_fill_;
|
bool is_fill_;
|
||||||
std::string fill_value_;
|
std::string fill_value_;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct store_emitter_context : public emitter_context {
|
struct store_emitter_params : public emitter_params {
|
||||||
store_emitter_context() : src_prc_(Precision::FP32), dst_prc_(Precision::FP32),
|
store_emitter_params(Precision src_prc, Precision dst_prc, int store_num):
|
||||||
store_num_(8), offset_byte_(0) {}
|
src_prc_(src_prc), dst_prc_(dst_prc), store_num_(store_num) {}
|
||||||
|
|
||||||
store_emitter_context(Precision src_prc, Precision dst_prc, int store_num, int offset_byte = 0)
|
size_t hash() const override;
|
||||||
: src_prc_(src_prc), dst_prc_(dst_prc), store_num_(store_num), offset_byte_(offset_byte) {}
|
|
||||||
|
|
||||||
int offset_byte_;
|
|
||||||
int store_num_;
|
|
||||||
Precision src_prc_;
|
Precision src_prc_;
|
||||||
Precision dst_prc_;
|
Precision dst_prc_;
|
||||||
|
int store_num_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class jit_load_emitter : public jit_emitter {
|
class jit_load_emitter : public jit_emitter {
|
||||||
public:
|
public:
|
||||||
jit_load_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
|
jit_load_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, Precision src_prc, Precision dst_prc, int load_num,
|
||||||
InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32, emitter_in_out_map in_out_type = emitter_in_out_map::gpr_to_vec);
|
Precision exec_prc = Precision::FP32, bool is_fill = false, std::string fill_value = "zero",
|
||||||
|
emitter_in_out_map in_out_type = emitter_in_out_map::gpr_to_vec);
|
||||||
/**
|
/**
|
||||||
* load_num values with src_prc precision are loaded from ptr[Reg64(in_idxs[0]) + offset_byte] address to Vmm[out_idxs[0]] as dst_prc.
|
* load_num values with src_prc precision are loaded from ptr[Reg64(in_idxs[0]) + offset_byte] address to Vmm[out_idxs[0]] as dst_prc, where offset_byte is in_idxs[1]
|
||||||
* is_fill: when load_num can not fully fit in vector register, whether fill_value should be filled as default values.
|
* is_fill: when load_num can not fully fit in vector register, whether fill_value should be filled as default values.
|
||||||
* fill_value: when load_num can not fully fit in vector register, what values should be filled as default values.
|
* fill_value: when load_num can not fully fit in vector register, what values should be filled as default values.
|
||||||
* currently support "zero", "int_one", "float_one", "int32_min", "float_min", "int32_max" and "float_max".
|
* currently support "zero", "int_one", "float_one", "int32_min", "float_min", "int32_max" and "float_max".
|
||||||
@ -66,27 +63,23 @@ public:
|
|||||||
* dst_prc
|
* dst_prc
|
||||||
*/
|
*/
|
||||||
void emit_impl(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
|
void emit_impl(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
|
||||||
const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs,
|
const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs,
|
||||||
const emitter_context *emit_context) const override;
|
const emitter_context *emit_context) const override;
|
||||||
|
|
||||||
size_t get_inputs_num() const override;
|
size_t get_inputs_num() const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
template <dnnl::impl::cpu::x64::cpu_isa_t isa>
|
template <dnnl::impl::cpu::x64::cpu_isa_t isa>
|
||||||
void emit_isa(const Xbyak::Reg64 ®_src, int offset_byte, InferenceEngine::Precision src_prc,
|
void emit_isa(const Xbyak::Reg64 ®_src, const int out_vec_idx, const int offset) const;
|
||||||
const int out_vec_idx, InferenceEngine::Precision dst_prc, int load_num, bool is_fill = false, std::string fill_value = "zero") const;
|
|
||||||
|
|
||||||
template <typename Vmm>
|
template <typename Vmm>
|
||||||
void load_bytes(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, int load_size,
|
void load_bytes(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, int load_size) const;
|
||||||
bool is_fill = false, std::string fill_value = "zero") const;
|
|
||||||
|
|
||||||
template <typename Vmm>
|
template <typename Vmm>
|
||||||
void load_bytes_to_dword_extension(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, bool is_signed, int load_size,
|
void load_bytes_to_dword_extension(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, bool is_signed, int load_size) const;
|
||||||
bool is_fill = false, std::string fill_value = "zero") const;
|
|
||||||
|
|
||||||
template <typename Vmm>
|
template <typename Vmm>
|
||||||
void load_words_to_dword_extension(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, bool is_bf16, bool is_signed, int load_size,
|
void load_words_to_dword_extension(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, bool is_bf16, bool is_signed, int load_size) const;
|
||||||
bool is_fill = false, std::string fill_value = "zero") const;
|
|
||||||
|
|
||||||
template <typename Vmm>
|
template <typename Vmm>
|
||||||
void fill_with_default(const Vmm &vmm, std::string fill_value, const int &load_num) const;
|
void fill_with_default(const Vmm &vmm, std::string fill_value, const int &load_num) const;
|
||||||
@ -95,17 +88,23 @@ private:
|
|||||||
|
|
||||||
size_t aux_gprs_count() const override;
|
size_t aux_gprs_count() const override;
|
||||||
|
|
||||||
std::string name;
|
std::string name_;
|
||||||
int v_len_elt; // 4/8/16
|
int v_len_elt_; // 4/8/16
|
||||||
|
int load_num_;
|
||||||
|
int load_size_;
|
||||||
|
Precision src_prc_;
|
||||||
|
Precision dst_prc_;
|
||||||
|
bool is_fill_;
|
||||||
|
std::string fill_value_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class jit_store_emitter : public jit_emitter {
|
class jit_store_emitter : public jit_emitter {
|
||||||
public:
|
public:
|
||||||
jit_store_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
|
jit_store_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, Precision src_prc, Precision dst_prc, int store_num,
|
||||||
InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32, emitter_in_out_map in_out_type = emitter_in_out_map::vec_to_gpr);
|
Precision exec_prc = Precision::FP32, emitter_in_out_map in_out_type = emitter_in_out_map::vec_to_gpr);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* store_num values with src_prc in Vmm[in_vec_idx] are stored to ptr[reg_dst + offset_byte] address as dst_prc data.
|
* store_num values with src_prc in Vmm[in_vec_idx] are stored to ptr[reg_dst + offset_byte] address as dst_prc data, where offset_byte is in_idxs[1]
|
||||||
* supported src_prc and dst_prc pairs are as below (x indicates support):
|
* supported src_prc and dst_prc pairs are as below (x indicates support):
|
||||||
* FP32 I32 I16 U16 I8 U8 BF16 --> src_prc
|
* FP32 I32 I16 U16 I8 U8 BF16 --> src_prc
|
||||||
* FP32 x x
|
* FP32 x x
|
||||||
@ -120,21 +119,20 @@ public:
|
|||||||
* note: FP32/I32-->BF16(x*) is supported only on at least avx512-core platform
|
* note: FP32/I32-->BF16(x*) is supported only on at least avx512-core platform
|
||||||
*/
|
*/
|
||||||
void emit_impl(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
|
void emit_impl(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
|
||||||
const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs,
|
const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs,
|
||||||
const emitter_context *emit_context) const override;
|
const emitter_context *emit_context) const override;
|
||||||
|
|
||||||
size_t get_inputs_num() const override;
|
size_t get_inputs_num() const override;
|
||||||
|
|
||||||
void emit_data() const override;
|
void emit_data() const override;
|
||||||
|
|
||||||
std::shared_ptr<jit_emu_vcvtneps2bf16> get_emu_vcvtneps2bf16() const {
|
std::shared_ptr<jit_emu_vcvtneps2bf16> get_emu_vcvtneps2bf16() const {
|
||||||
return emu_vcvtneps2bf16;
|
return emu_vcvtneps2bf16_;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
template <dnnl::impl::cpu::x64::cpu_isa_t isa>
|
template <dnnl::impl::cpu::x64::cpu_isa_t isa>
|
||||||
void emit_isa(const int in_vec_idx, InferenceEngine::Precision src_prc,
|
void emit_isa(const int in_vec_idx, const Xbyak::Reg64 ®_dst, const int offset) const;
|
||||||
const Xbyak::Reg64 ®_dst, int offset_byte, InferenceEngine::Precision dst_prc, int store_num) const;
|
|
||||||
|
|
||||||
template <typename Vmm>
|
template <typename Vmm>
|
||||||
void store_bytes(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, int store_size) const;
|
void store_bytes(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, int store_size) const;
|
||||||
@ -148,9 +146,13 @@ private:
|
|||||||
size_t aux_gprs_count() const override;
|
size_t aux_gprs_count() const override;
|
||||||
size_t aux_vecs_count() const override;
|
size_t aux_vecs_count() const override;
|
||||||
|
|
||||||
std::string name;
|
std::string name_;
|
||||||
int v_len_elt; // 4/8/16
|
int v_len_elt_; // 4/8/16
|
||||||
std::shared_ptr<jit_emu_vcvtneps2bf16> emu_vcvtneps2bf16;
|
int store_num_;
|
||||||
|
int store_size_;
|
||||||
|
Precision src_prc_;
|
||||||
|
Precision dst_prc_;
|
||||||
|
std::shared_ptr<jit_emu_vcvtneps2bf16> emu_vcvtneps2bf16_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace intel_cpu
|
} // namespace intel_cpu
|
||||||
|
@ -58,9 +58,6 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi
|
|||||||
}
|
}
|
||||||
|
|
||||||
void generate() override {
|
void generate() override {
|
||||||
load_emitter.reset(new jit_load_emitter(this, isa));
|
|
||||||
store_emitter.reset(new jit_store_emitter(this, isa));
|
|
||||||
|
|
||||||
// dummy second reg_tmp_64 as no fill needed
|
// dummy second reg_tmp_64 as no fill needed
|
||||||
load_pool_gpr_idxs = {static_cast<size_t>(reg_tmp_64.getIdx()), static_cast<size_t>(reg_tmp_64.getIdx())};
|
load_pool_gpr_idxs = {static_cast<size_t>(reg_tmp_64.getIdx()), static_cast<size_t>(reg_tmp_64.getIdx())};
|
||||||
store_pool_gpr_idxs = {static_cast<size_t>(reg_tmp_64.getIdx())};
|
store_pool_gpr_idxs = {static_cast<size_t>(reg_tmp_64.getIdx())};
|
||||||
@ -162,8 +159,7 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi
|
|||||||
|
|
||||||
this->postamble();
|
this->postamble();
|
||||||
|
|
||||||
load_emitter->emit_data();
|
emit_emitters_data();
|
||||||
store_emitter->emit_data();
|
|
||||||
for (auto& inj : eltwise_injectors)
|
for (auto& inj : eltwise_injectors)
|
||||||
inj->prepare_table();
|
inj->prepare_table();
|
||||||
if ((jcp_.mode == InterpolateMode::cubic) && (jcp_.layout == InterpolateLayoutType::planar)) {
|
if ((jcp_.mode == InterpolateMode::cubic) && (jcp_.layout == InterpolateLayoutType::planar)) {
|
||||||
@ -176,6 +172,9 @@ private:
|
|||||||
Xbyak::Ymm, Xbyak::Zmm>::type;
|
Xbyak::Ymm, Xbyak::Zmm>::type;
|
||||||
|
|
||||||
const int vlen = cpu_isa_traits<isa>::vlen;
|
const int vlen = cpu_isa_traits<isa>::vlen;
|
||||||
|
const int vector_step = vlen / sizeof(float);
|
||||||
|
const int tail_step = jcp_.C % vector_step;
|
||||||
|
const int scalar_step = 1;
|
||||||
|
|
||||||
Xbyak::Reg64 reg_src = r8;
|
Xbyak::Reg64 reg_src = r8;
|
||||||
Xbyak::Reg64 reg_src_aux = r15;
|
Xbyak::Reg64 reg_src_aux = r15;
|
||||||
@ -246,8 +245,8 @@ private:
|
|||||||
Xbyak::Label l_table_constant;
|
Xbyak::Label l_table_constant;
|
||||||
Opmask k_mask = Xbyak::Opmask(1);
|
Opmask k_mask = Xbyak::Opmask(1);
|
||||||
|
|
||||||
std::unique_ptr<jit_load_emitter> load_emitter = nullptr;
|
std::unordered_map<size_t, std::unique_ptr<jit_emitter>> emitters;
|
||||||
std::unique_ptr<jit_store_emitter> store_emitter = nullptr;
|
|
||||||
std::vector<size_t> store_pool_gpr_idxs;
|
std::vector<size_t> store_pool_gpr_idxs;
|
||||||
std::vector<size_t> store_pool_vec_idxs;
|
std::vector<size_t> store_pool_vec_idxs;
|
||||||
std::vector<size_t> load_pool_gpr_idxs;
|
std::vector<size_t> load_pool_gpr_idxs;
|
||||||
@ -256,20 +255,44 @@ private:
|
|||||||
std::vector<std::shared_ptr<jit_uni_depthwise_injector_f32<isa>>> depthwise_injectors;
|
std::vector<std::shared_ptr<jit_uni_depthwise_injector_f32<isa>>> depthwise_injectors;
|
||||||
std::vector<std::shared_ptr<jit_uni_quantization_injector_f32<isa>>> quantization_injectors;
|
std::vector<std::shared_ptr<jit_uni_quantization_injector_f32<isa>>> quantization_injectors;
|
||||||
|
|
||||||
inline void load(const Xbyak::Reg64& reg_src, Vmm& vmm, const int& elt_num, const int& offset = 0) {
|
void emit_emitters_data() {
|
||||||
load_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm.getIdx())},
|
for (const auto& emitter : emitters) {
|
||||||
std::make_shared<load_emitter_context>(jcp_.src_prc, Precision::FP32, elt_num, offset),
|
if (emitter.second)
|
||||||
{}, {load_pool_gpr_idxs});
|
emitter.second->emit_data();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
inline void store(const Vmm& vmm, const Xbyak::Reg64& reg_dst, const int& elt_num, const int& offset = 0) {
|
|
||||||
store_emitter->emit_code({static_cast<size_t>(vmm.getIdx())}, {static_cast<size_t>(reg_dst.getIdx())},
|
inline void load(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num, const int offset = 0) {
|
||||||
std::make_shared<store_emitter_context>(Precision::FP32, jcp_.dst_prc, elt_num, offset),
|
emit_load(reg_src, vmm_src, jcp_.src_prc, Precision::FP32, elt_num, offset);
|
||||||
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
|
|
||||||
}
|
}
|
||||||
inline void load_weights(const Xbyak::Reg64& reg_weights, Vmm& vmm, const int& elt_num, const int& offset = 0) {
|
|
||||||
load_emitter->emit_code({static_cast<size_t>(reg_weights.getIdx())}, {static_cast<size_t>(vmm.getIdx())},
|
inline void load_weights(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num, const int offset = 0) {
|
||||||
std::make_shared<load_emitter_context>(Precision::FP32, Precision::FP32, elt_num, offset),
|
emit_load(reg_src, vmm_src, Precision::FP32, Precision::FP32, elt_num, offset);
|
||||||
{}, {load_pool_gpr_idxs});
|
}
|
||||||
|
|
||||||
|
inline void emit_load(Xbyak::Reg64 reg_src, Vmm vmm_src, Precision src_prc, Precision dst_prc, const int elt_num, const int offset = 0) {
|
||||||
|
const auto seed = load_emitter_params(src_prc, dst_prc, elt_num).hash();
|
||||||
|
if (!emitters[seed]) {
|
||||||
|
emitters[seed].reset(new jit_load_emitter(this, isa, src_prc, dst_prc, elt_num));
|
||||||
|
}
|
||||||
|
|
||||||
|
emitters[seed]->emit_code({static_cast<size_t>(reg_src.getIdx()), static_cast<size_t>(offset)},
|
||||||
|
{static_cast<size_t>(vmm_src.getIdx())}, {}, {load_pool_gpr_idxs});
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void store(Vmm vmm_dst, Xbyak::Reg64 reg_dst, const int elt_num, const int offset = 0) {
|
||||||
|
const auto seed = store_emitter_params(Precision::FP32, jcp_.dst_prc, elt_num).hash();
|
||||||
|
if (!emitters[seed]) {
|
||||||
|
emitters[seed].reset(new jit_store_emitter(this, isa, Precision::FP32, jcp_.dst_prc, elt_num));
|
||||||
|
}
|
||||||
|
|
||||||
|
// for cases when Store emitter need 2 aux vmm we can use vmm_dst as second aux vmm
|
||||||
|
std::vector<size_t> local_store_pool_vec_idxs = { static_cast<size_t>(vmm_dst.getIdx()) };
|
||||||
|
local_store_pool_vec_idxs.insert(local_store_pool_vec_idxs.begin(), store_pool_vec_idxs.begin(), store_pool_vec_idxs.end());
|
||||||
|
|
||||||
|
emitters[seed]->emit_code({static_cast<size_t>(vmm_dst.getIdx()), static_cast<size_t>(offset)},
|
||||||
|
{static_cast<size_t>(reg_dst.getIdx())},
|
||||||
|
{local_store_pool_vec_idxs}, {store_pool_gpr_idxs});
|
||||||
}
|
}
|
||||||
|
|
||||||
void nn_planar() {
|
void nn_planar() {
|
||||||
@ -303,7 +326,6 @@ private:
|
|||||||
|
|
||||||
// reset index_w, index_w * dataSize done when built to avoid redundant compute
|
// reset index_w, index_w * dataSize done when built to avoid redundant compute
|
||||||
mov(reg_index, reg_index_w);
|
mov(reg_index, reg_index_w);
|
||||||
int step = vlen / sizeof(float);
|
|
||||||
|
|
||||||
Xbyak::Label nn_loop_label;
|
Xbyak::Label nn_loop_label;
|
||||||
Xbyak::Label nn_loop_end_label;
|
Xbyak::Label nn_loop_end_label;
|
||||||
@ -312,7 +334,7 @@ private:
|
|||||||
|
|
||||||
L(nn_loop_label); // inner loop
|
L(nn_loop_label); // inner loop
|
||||||
{
|
{
|
||||||
cmp(reg_work_amount, step);
|
cmp(reg_work_amount, vector_step);
|
||||||
jl(nn_loop_end_label, T_NEAR);
|
jl(nn_loop_end_label, T_NEAR);
|
||||||
|
|
||||||
uni_vmovdqu(vmm_index, ptr[reg_index]);
|
uni_vmovdqu(vmm_index, ptr[reg_index]);
|
||||||
@ -320,17 +342,16 @@ private:
|
|||||||
vgatherdps(vmm_val, ptr[reg_src_h + vmm_index], vmm_mask);
|
vgatherdps(vmm_val, ptr[reg_src_h + vmm_index], vmm_mask);
|
||||||
if (attr_.post_ops_.len() != 0)
|
if (attr_.post_ops_.len() != 0)
|
||||||
apply_post_ops(jcp_.dst_prc, 1);
|
apply_post_ops(jcp_.dst_prc, 1);
|
||||||
store(vmm_val, reg_dst, step);
|
store(vmm_val, reg_dst, vector_step);
|
||||||
|
|
||||||
add(reg_dst, step * jcp_.dst_data_size);
|
add(reg_dst, vector_step * jcp_.dst_data_size);
|
||||||
add(reg_index, step * jcp_.indices_size);
|
add(reg_index, vector_step * jcp_.indices_size);
|
||||||
sub(reg_work_amount, step);
|
sub(reg_work_amount, vector_step);
|
||||||
|
|
||||||
jmp(nn_loop_label, T_NEAR);
|
jmp(nn_loop_label, T_NEAR);
|
||||||
}
|
}
|
||||||
L(nn_loop_end_label);
|
L(nn_loop_end_label);
|
||||||
|
|
||||||
step = 1;
|
|
||||||
L(nn_tail_loop_label);
|
L(nn_tail_loop_label);
|
||||||
{
|
{
|
||||||
cmp(reg_work_amount, 1);
|
cmp(reg_work_amount, 1);
|
||||||
@ -340,14 +361,14 @@ private:
|
|||||||
mov(reg_index_offset, dword[reg_index]);
|
mov(reg_index_offset, dword[reg_index]);
|
||||||
add(reg_src_aux, reg_index_offset);
|
add(reg_src_aux, reg_index_offset);
|
||||||
|
|
||||||
load(reg_src_aux, vmm_val, step);
|
load(reg_src_aux, vmm_val, scalar_step);
|
||||||
if (attr_.post_ops_.len() != 0)
|
if (attr_.post_ops_.len() != 0)
|
||||||
apply_post_ops(jcp_.dst_prc, 1);
|
apply_post_ops(jcp_.dst_prc, 1);
|
||||||
store(vmm_val, reg_dst, step);
|
store(vmm_val, reg_dst, scalar_step);
|
||||||
|
|
||||||
add(reg_dst, step * jcp_.dst_data_size);
|
add(reg_dst, scalar_step * jcp_.dst_data_size);
|
||||||
add(reg_index, step * jcp_.indices_size);
|
add(reg_index, scalar_step * jcp_.indices_size);
|
||||||
sub(reg_work_amount, step);
|
sub(reg_work_amount, scalar_step);
|
||||||
|
|
||||||
jmp(nn_tail_loop_label, T_NEAR);
|
jmp(nn_tail_loop_label, T_NEAR);
|
||||||
}
|
}
|
||||||
@ -363,8 +384,6 @@ private:
|
|||||||
}
|
}
|
||||||
|
|
||||||
void nn_blk() {
|
void nn_blk() {
|
||||||
int step = vlen / sizeof(float);
|
|
||||||
|
|
||||||
Xbyak::Label nn_loop_label;
|
Xbyak::Label nn_loop_label;
|
||||||
Xbyak::Label nn_loop_end_label;
|
Xbyak::Label nn_loop_end_label;
|
||||||
L(nn_loop_label);
|
L(nn_loop_label);
|
||||||
@ -376,22 +395,22 @@ private:
|
|||||||
mov(reg_index_offset, dword[reg_index]);
|
mov(reg_index_offset, dword[reg_index]);
|
||||||
add(reg_src_aux, reg_index_offset);
|
add(reg_src_aux, reg_index_offset);
|
||||||
|
|
||||||
load(reg_src_aux, vmm_val, step);
|
load(reg_src_aux, vmm_val, vector_step);
|
||||||
if (attr_.post_ops_.len() != 0)
|
if (attr_.post_ops_.len() != 0)
|
||||||
apply_post_ops(jcp_.dst_prc, 0);
|
apply_post_ops(jcp_.dst_prc, 0);
|
||||||
store(vmm_val, reg_dst, step);
|
store(vmm_val, reg_dst, vector_step);
|
||||||
add(reg_dst, step * jcp_.dst_data_size);
|
add(reg_dst, vector_step * jcp_.dst_data_size);
|
||||||
|
|
||||||
if (isa == cpu::x64::sse41) {
|
if (isa == cpu::x64::sse41) {
|
||||||
add(reg_src_aux, step * jcp_.src_data_size);
|
add(reg_src_aux, vector_step * jcp_.src_data_size);
|
||||||
load(reg_src_aux, vmm_val, step);
|
load(reg_src_aux, vmm_val, vector_step);
|
||||||
if (attr_.post_ops_.len() != 0) {
|
if (attr_.post_ops_.len() != 0) {
|
||||||
add(reg_oc_off, step * sizeof(float));
|
add(reg_oc_off, vector_step * sizeof(float));
|
||||||
apply_post_ops(jcp_.dst_prc, 0);
|
apply_post_ops(jcp_.dst_prc, 0);
|
||||||
sub(reg_oc_off, step * sizeof(float));
|
sub(reg_oc_off, vector_step * sizeof(float));
|
||||||
}
|
}
|
||||||
store(vmm_val, reg_dst, step);
|
store(vmm_val, reg_dst, vector_step);
|
||||||
add(reg_dst, step * jcp_.dst_data_size);
|
add(reg_dst, vector_step * jcp_.dst_data_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
add(reg_index, jcp_.indices_size);
|
add(reg_index, jcp_.indices_size);
|
||||||
@ -421,8 +440,6 @@ private:
|
|||||||
cmp(reg_work_amount_out, 1);
|
cmp(reg_work_amount_out, 1);
|
||||||
jl(out_loop_end, T_NEAR);
|
jl(out_loop_end, T_NEAR);
|
||||||
|
|
||||||
int step = vlen / sizeof(float);
|
|
||||||
|
|
||||||
//inner loop for C
|
//inner loop for C
|
||||||
Xbyak::Label nn_loop_label;
|
Xbyak::Label nn_loop_label;
|
||||||
Xbyak::Label nn_loop_end_label;
|
Xbyak::Label nn_loop_end_label;
|
||||||
@ -444,35 +461,34 @@ private:
|
|||||||
|
|
||||||
L(nn_loop_label);
|
L(nn_loop_label);
|
||||||
{
|
{
|
||||||
cmp(reg_work_amount, step);
|
cmp(reg_work_amount, vector_step);
|
||||||
jl(nn_loop_end_label, T_NEAR);
|
jl(nn_loop_end_label, T_NEAR);
|
||||||
|
|
||||||
load(reg_src_aux, vmm_val, step);
|
load(reg_src_aux, vmm_val, vector_step);
|
||||||
if (attr_.post_ops_.len() != 0)
|
if (attr_.post_ops_.len() != 0)
|
||||||
apply_post_ops(jcp_.dst_prc, 0);
|
apply_post_ops(jcp_.dst_prc, 0);
|
||||||
store(vmm_val, reg_dst, step);
|
store(vmm_val, reg_dst, vector_step);
|
||||||
|
|
||||||
add(reg_dst, step * jcp_.dst_data_size);
|
add(reg_dst, vector_step * jcp_.dst_data_size);
|
||||||
add(reg_src_aux, step * jcp_.src_data_size);
|
add(reg_src_aux, vector_step * jcp_.src_data_size);
|
||||||
add(reg_oc_off, step * sizeof(float));
|
add(reg_oc_off, vector_step * sizeof(float));
|
||||||
sub(reg_work_amount, step);
|
sub(reg_work_amount, vector_step);
|
||||||
|
|
||||||
jmp(nn_loop_label, T_NEAR);
|
jmp(nn_loop_label, T_NEAR);
|
||||||
}
|
}
|
||||||
L(nn_loop_end_label);
|
L(nn_loop_end_label);
|
||||||
|
|
||||||
int tail_num = jcp_.C % step;
|
if (tail_step != 0) {
|
||||||
if (tail_num != 0) {
|
load(reg_src_aux, vmm_val, tail_step);
|
||||||
load(reg_src_aux, vmm_val, tail_num);
|
|
||||||
if (attr_.post_ops_.len() != 0)
|
if (attr_.post_ops_.len() != 0)
|
||||||
apply_post_ops(jcp_.dst_prc, 0);
|
apply_post_ops(jcp_.dst_prc, 0);
|
||||||
store(vmm_val, reg_dst, tail_num);
|
store(vmm_val, reg_dst, tail_step);
|
||||||
|
|
||||||
// check to remove below
|
// check to remove below
|
||||||
add(reg_dst, tail_num * jcp_.dst_data_size);
|
add(reg_dst, tail_step * jcp_.dst_data_size);
|
||||||
add(reg_src_aux, tail_num * jcp_.src_data_size);
|
add(reg_src_aux, tail_step * jcp_.src_data_size);
|
||||||
add(reg_oc_off, tail_num * sizeof(float));
|
add(reg_oc_off, tail_step * sizeof(float));
|
||||||
sub(reg_work_amount, tail_num);
|
sub(reg_work_amount, tail_step);
|
||||||
}
|
}
|
||||||
add(reg_index, jcp_.indices_size);
|
add(reg_index, jcp_.indices_size);
|
||||||
sub(reg_work_amount_out, 1);
|
sub(reg_work_amount_out, 1);
|
||||||
@ -519,11 +535,10 @@ private:
|
|||||||
}
|
}
|
||||||
mov(reg_work_amount, ptr[reg_params + GET_OFF(work_amount)]);
|
mov(reg_work_amount, ptr[reg_params + GET_OFF(work_amount)]);
|
||||||
|
|
||||||
int step = vlen / sizeof(float);
|
int blk = (isa == cpu::x64::sse41) ? (2 * vector_step) : vector_step;
|
||||||
int blk = (isa == cpu::x64::sse41) ? (2 * step) : step;
|
int dst_stride = (jcp_.layout == InterpolateLayoutType::by_channel) ? (vector_step * jcp_.dst_data_size) :
|
||||||
int dst_stride = (jcp_.layout == InterpolateLayoutType::by_channel) ? (step * jcp_.dst_data_size) :
|
|
||||||
(blk * jcp_.OW * jcp_.OH * jcp_.OD * jcp_.dst_data_size);
|
(blk * jcp_.OW * jcp_.OH * jcp_.OD * jcp_.dst_data_size);
|
||||||
int src_stride = (jcp_.layout == InterpolateLayoutType::by_channel) ? (step * jcp_.src_data_size) :
|
int src_stride = (jcp_.layout == InterpolateLayoutType::by_channel) ? (vector_step * jcp_.src_data_size) :
|
||||||
(blk * jcp_.IW * jcp_.IH * jcp_.ID * jcp_.src_data_size);
|
(blk * jcp_.IW * jcp_.IH * jcp_.ID * jcp_.src_data_size);
|
||||||
|
|
||||||
Xbyak::Label main_loop_label;
|
Xbyak::Label main_loop_label;
|
||||||
@ -535,29 +550,29 @@ private:
|
|||||||
L(main_loop_label);
|
L(main_loop_label);
|
||||||
{
|
{
|
||||||
if (jcp_.layout == InterpolateLayoutType::by_channel) {
|
if (jcp_.layout == InterpolateLayoutType::by_channel) {
|
||||||
cmp(reg_work_amount, step);
|
cmp(reg_work_amount, vector_step);
|
||||||
jl(main_loop_end_label, T_NEAR);
|
jl(main_loop_end_label, T_NEAR);
|
||||||
} else {
|
} else {
|
||||||
cmp(reg_work_amount, 1);
|
cmp(reg_work_amount, 1);
|
||||||
jl(main_loop_end_label, T_NEAR);
|
jl(main_loop_end_label, T_NEAR);
|
||||||
}
|
}
|
||||||
// progressive manner
|
// progressive manner
|
||||||
load(reg_src, vmm_valTL, step);
|
load(reg_src, vmm_valTL, vector_step);
|
||||||
load(reg_src_aux, vmm_valTR, step);
|
load(reg_src_aux, vmm_valTR, vector_step);
|
||||||
if (jcp_.spatial_dim_size == 1) {
|
if (jcp_.spatial_dim_size == 1) {
|
||||||
linear_onnx_worker_1d();
|
linear_onnx_worker_1d();
|
||||||
}
|
}
|
||||||
if (jcp_.spatial_dim_size > 1) {
|
if (jcp_.spatial_dim_size > 1) {
|
||||||
load(reg_src_aux1, vmm_valBL, step);
|
load(reg_src_aux1, vmm_valBL, vector_step);
|
||||||
load(reg_src_aux2, vmm_valBR, step);
|
load(reg_src_aux2, vmm_valBR, vector_step);
|
||||||
linear_onnx_worker_2d();
|
linear_onnx_worker_2d();
|
||||||
}
|
}
|
||||||
if (jcp_.spatial_dim_size > 2) {
|
if (jcp_.spatial_dim_size > 2) {
|
||||||
uni_vmovups(vmm_d_bias, vmm_valTR); // temporally save front result to temp_vmm
|
uni_vmovups(vmm_d_bias, vmm_valTR); // temporally save front result to temp_vmm
|
||||||
load(reg_src_aux4, vmm_valTL, step);
|
load(reg_src_aux4, vmm_valTL, vector_step);
|
||||||
load(reg_src_aux5, vmm_valTR, step);
|
load(reg_src_aux5, vmm_valTR, vector_step);
|
||||||
load(reg_src_aux6, vmm_valBL, step);
|
load(reg_src_aux6, vmm_valBL, vector_step);
|
||||||
load(reg_src_aux7, vmm_valBR, step);
|
load(reg_src_aux7, vmm_valBR, vector_step);
|
||||||
|
|
||||||
// 2d for end depth
|
// 2d for end depth
|
||||||
linear_onnx_worker_2d();
|
linear_onnx_worker_2d();
|
||||||
@ -568,28 +583,28 @@ private:
|
|||||||
|
|
||||||
if (attr_.post_ops_.len() != 0) {
|
if (attr_.post_ops_.len() != 0) {
|
||||||
apply_post_ops(jcp_.dst_prc, false); // vmm_val is vmm_valTR
|
apply_post_ops(jcp_.dst_prc, false); // vmm_val is vmm_valTR
|
||||||
add(reg_oc_off, step * sizeof(float));
|
add(reg_oc_off, vector_step * sizeof(float));
|
||||||
}
|
}
|
||||||
store(vmm_valTR, reg_dst, step);
|
store(vmm_valTR, reg_dst, vector_step);
|
||||||
|
|
||||||
if ((isa == cpu::x64::sse41) && (jcp_.layout == InterpolateLayoutType::block)) {
|
if ((isa == cpu::x64::sse41) && (jcp_.layout == InterpolateLayoutType::block)) {
|
||||||
int offset_src = step * jcp_.src_data_size;
|
int offset_src = vector_step * jcp_.src_data_size;
|
||||||
load(reg_src, vmm_valTL, step, offset_src);
|
load(reg_src, vmm_valTL, vector_step, offset_src);
|
||||||
load(reg_src_aux, vmm_valTR, step, offset_src);
|
load(reg_src_aux, vmm_valTR, vector_step, offset_src);
|
||||||
if (jcp_.spatial_dim_size == 1) {
|
if (jcp_.spatial_dim_size == 1) {
|
||||||
linear_onnx_worker_1d();
|
linear_onnx_worker_1d();
|
||||||
}
|
}
|
||||||
if (jcp_.spatial_dim_size > 1) {
|
if (jcp_.spatial_dim_size > 1) {
|
||||||
load(reg_src_aux1, vmm_valBL, step, offset_src);
|
load(reg_src_aux1, vmm_valBL, vector_step, offset_src);
|
||||||
load(reg_src_aux2, vmm_valBR, step, offset_src);
|
load(reg_src_aux2, vmm_valBR, vector_step, offset_src);
|
||||||
linear_onnx_worker_2d();
|
linear_onnx_worker_2d();
|
||||||
}
|
}
|
||||||
if (jcp_.spatial_dim_size > 2) {
|
if (jcp_.spatial_dim_size > 2) {
|
||||||
uni_vmovups(vmm_d_bias, vmm_valTR); // temporally save front result to temp_vmm
|
uni_vmovups(vmm_d_bias, vmm_valTR); // temporally save front result to temp_vmm
|
||||||
load(reg_src_aux4, vmm_valTL, step, offset_src);
|
load(reg_src_aux4, vmm_valTL, vector_step, offset_src);
|
||||||
load(reg_src_aux5, vmm_valTR, step, offset_src);
|
load(reg_src_aux5, vmm_valTR, vector_step, offset_src);
|
||||||
load(reg_src_aux6, vmm_valBL, step, offset_src);
|
load(reg_src_aux6, vmm_valBL, vector_step, offset_src);
|
||||||
load(reg_src_aux7, vmm_valBR, step, offset_src);
|
load(reg_src_aux7, vmm_valBR, vector_step, offset_src);
|
||||||
// 2d for end depth
|
// 2d for end depth
|
||||||
linear_onnx_worker_2d();
|
linear_onnx_worker_2d();
|
||||||
// 3rd dimension
|
// 3rd dimension
|
||||||
@ -599,10 +614,10 @@ private:
|
|||||||
|
|
||||||
if (attr_.post_ops_.len() != 0) {
|
if (attr_.post_ops_.len() != 0) {
|
||||||
apply_post_ops(jcp_.dst_prc, false);
|
apply_post_ops(jcp_.dst_prc, false);
|
||||||
add(reg_oc_off, step * sizeof(float));
|
add(reg_oc_off, vector_step * sizeof(float));
|
||||||
}
|
}
|
||||||
int offset_dst = step * jcp_.dst_data_size;
|
int offset_dst = vector_step * jcp_.dst_data_size;
|
||||||
store(vmm_valTR, reg_dst, step, offset_dst);
|
store(vmm_valTR, reg_dst, vector_step, offset_dst);
|
||||||
}
|
}
|
||||||
add(reg_dst, dst_stride);
|
add(reg_dst, dst_stride);
|
||||||
add(reg_src, src_stride);
|
add(reg_src, src_stride);
|
||||||
@ -618,7 +633,7 @@ private:
|
|||||||
add(reg_src_aux7, src_stride);
|
add(reg_src_aux7, src_stride);
|
||||||
}
|
}
|
||||||
if (jcp_.layout == InterpolateLayoutType::by_channel) {
|
if (jcp_.layout == InterpolateLayoutType::by_channel) {
|
||||||
sub(reg_work_amount, step); // work_amount is c
|
sub(reg_work_amount, vector_step); // work_amount is c
|
||||||
} else {
|
} else {
|
||||||
sub(reg_work_amount, 1); // work_amount = div_up(c, blk), no tails
|
sub(reg_work_amount, 1); // work_amount = div_up(c, blk), no tails
|
||||||
}
|
}
|
||||||
@ -627,25 +642,24 @@ private:
|
|||||||
}
|
}
|
||||||
L(main_loop_end_label);
|
L(main_loop_end_label);
|
||||||
|
|
||||||
int tail_num = jcp_.C % step;
|
if ((jcp_.layout == InterpolateLayoutType::by_channel) && (tail_step != 0)) {
|
||||||
if ((jcp_.layout == InterpolateLayoutType::by_channel) && (tail_num != 0)) {
|
load(reg_src, vmm_valTL, tail_step);
|
||||||
load(reg_src, vmm_valTL, tail_num);
|
load(reg_src_aux, vmm_valTR, tail_step);
|
||||||
load(reg_src_aux, vmm_valTR, tail_num);
|
|
||||||
if (jcp_.spatial_dim_size == 1) {
|
if (jcp_.spatial_dim_size == 1) {
|
||||||
linear_onnx_worker_1d();
|
linear_onnx_worker_1d();
|
||||||
}
|
}
|
||||||
if (jcp_.spatial_dim_size > 1) {
|
if (jcp_.spatial_dim_size > 1) {
|
||||||
load(reg_src_aux1, vmm_valBL, tail_num);
|
load(reg_src_aux1, vmm_valBL, tail_step);
|
||||||
load(reg_src_aux2, vmm_valBR, tail_num);
|
load(reg_src_aux2, vmm_valBR, tail_step);
|
||||||
linear_onnx_worker_2d();
|
linear_onnx_worker_2d();
|
||||||
}
|
}
|
||||||
if (jcp_.spatial_dim_size > 2) {
|
if (jcp_.spatial_dim_size > 2) {
|
||||||
uni_vmovups(vmm_d_bias, vmm_valTR); // temporally save front result to temp_vmm
|
uni_vmovups(vmm_d_bias, vmm_valTR); // temporally save front result to temp_vmm
|
||||||
|
|
||||||
load(reg_src_aux4, vmm_valTL, tail_num);
|
load(reg_src_aux4, vmm_valTL, tail_step);
|
||||||
load(reg_src_aux5, vmm_valTR, tail_num);
|
load(reg_src_aux5, vmm_valTR, tail_step);
|
||||||
load(reg_src_aux6, vmm_valBL, tail_num);
|
load(reg_src_aux6, vmm_valBL, tail_step);
|
||||||
load(reg_src_aux7, vmm_valBR, tail_num);
|
load(reg_src_aux7, vmm_valBR, tail_step);
|
||||||
// 2d for end depth
|
// 2d for end depth
|
||||||
linear_onnx_worker_2d();
|
linear_onnx_worker_2d();
|
||||||
// 3th dimension
|
// 3th dimension
|
||||||
@ -655,10 +669,10 @@ private:
|
|||||||
|
|
||||||
if (attr_.post_ops_.len() != 0) {
|
if (attr_.post_ops_.len() != 0) {
|
||||||
apply_post_ops(jcp_.dst_prc, false); // vmm_val is vmm_valTR
|
apply_post_ops(jcp_.dst_prc, false); // vmm_val is vmm_valTR
|
||||||
add(reg_oc_off, tail_num * sizeof(float));
|
add(reg_oc_off, tail_step * sizeof(float));
|
||||||
}
|
}
|
||||||
|
|
||||||
store(vmm_valTR, reg_dst, tail_num);
|
store(vmm_valTR, reg_dst, tail_step);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -669,7 +683,6 @@ private:
|
|||||||
mov(reg_src_aux, ptr[reg_params + GET_OFF(weight_ptr[0])]);
|
mov(reg_src_aux, ptr[reg_params + GET_OFF(weight_ptr[0])]);
|
||||||
mov(reg_work_amount, ptr[reg_params + GET_OFF(work_amount)]);
|
mov(reg_work_amount, ptr[reg_params + GET_OFF(work_amount)]);
|
||||||
|
|
||||||
int step = vlen / sizeof(float);
|
|
||||||
int index_stride = jcp_.OW * jcp_.OH * jcp_.OD * jcp_.indices_size;
|
int index_stride = jcp_.OW * jcp_.OH * jcp_.OD * jcp_.indices_size;
|
||||||
int weight_stride = jcp_.OW * jcp_.OH * jcp_.OD * sizeof(float);
|
int weight_stride = jcp_.OW * jcp_.OH * jcp_.OD * sizeof(float);
|
||||||
|
|
||||||
@ -679,7 +692,7 @@ private:
|
|||||||
Xbyak::Label tail_loop_end_label;
|
Xbyak::Label tail_loop_end_label;
|
||||||
L(main_loop_label);
|
L(main_loop_label);
|
||||||
{
|
{
|
||||||
cmp(reg_work_amount, step);
|
cmp(reg_work_amount, vector_step);
|
||||||
jl(main_loop_end_label, T_NEAR);
|
jl(main_loop_end_label, T_NEAR);
|
||||||
|
|
||||||
uni_vmovdqu(vmm_index, ptr[reg_index]);
|
uni_vmovdqu(vmm_index, ptr[reg_index]);
|
||||||
@ -690,8 +703,8 @@ private:
|
|||||||
uni_vpcmpeqd(vmm_mask, vmm_mask, vmm_mask);
|
uni_vpcmpeqd(vmm_mask, vmm_mask, vmm_mask);
|
||||||
vgatherdps(vmm_valTR, ptr[reg_src + vmm_index], vmm_mask);
|
vgatherdps(vmm_valTR, ptr[reg_src + vmm_index], vmm_mask);
|
||||||
|
|
||||||
load_weights(reg_src_aux, vmm_weightL, step);
|
load_weights(reg_src_aux, vmm_weightL, vector_step);
|
||||||
load_weights(reg_src_aux, vmm_weightR, step, weight_stride);
|
load_weights(reg_src_aux, vmm_weightR, vector_step, weight_stride);
|
||||||
|
|
||||||
// progressive manner
|
// progressive manner
|
||||||
if (jcp_.spatial_dim_size == 1) {
|
if (jcp_.spatial_dim_size == 1) {
|
||||||
@ -706,8 +719,8 @@ private:
|
|||||||
uni_vpcmpeqd(vmm_mask, vmm_mask, vmm_mask);
|
uni_vpcmpeqd(vmm_mask, vmm_mask, vmm_mask);
|
||||||
vgatherdps(vmm_valBR, ptr[reg_src + vmm_index], vmm_mask);
|
vgatherdps(vmm_valBR, ptr[reg_src + vmm_index], vmm_mask);
|
||||||
|
|
||||||
load_weights(reg_src_aux, vmm_weightT, step, 2 * weight_stride);
|
load_weights(reg_src_aux, vmm_weightT, vector_step, 2 * weight_stride);
|
||||||
load_weights(reg_src_aux, vmm_weightB, step, 3 * weight_stride);
|
load_weights(reg_src_aux, vmm_weightB, vector_step, 3 * weight_stride);
|
||||||
|
|
||||||
linear_onnx_worker_2d();
|
linear_onnx_worker_2d();
|
||||||
}
|
}
|
||||||
@ -733,8 +746,8 @@ private:
|
|||||||
|
|
||||||
linear_onnx_worker_2d();
|
linear_onnx_worker_2d();
|
||||||
|
|
||||||
load_weights(reg_src_aux, vmm_weightE, step, 5 * weight_stride);
|
load_weights(reg_src_aux, vmm_weightE, vector_step, 5 * weight_stride);
|
||||||
load_weights(reg_src_aux, vmm_weightF, step, 4 * weight_stride);
|
load_weights(reg_src_aux, vmm_weightF, vector_step, 4 * weight_stride);
|
||||||
|
|
||||||
uni_vmulps(vmm_valTR, vmm_valTR, vmm_weightE); // end_value * end_weight
|
uni_vmulps(vmm_valTR, vmm_valTR, vmm_weightE); // end_value * end_weight
|
||||||
uni_vfmadd231ps(vmm_valTR, vmm_d_bias, vmm_weightF); // start_value * start_weight + end_value * end_weight
|
uni_vfmadd231ps(vmm_valTR, vmm_d_bias, vmm_weightF); // start_value * start_weight + end_value * end_weight
|
||||||
@ -743,18 +756,17 @@ private:
|
|||||||
if (attr_.post_ops_.len() != 0) {
|
if (attr_.post_ops_.len() != 0) {
|
||||||
apply_post_ops(jcp_.dst_prc, true); // vmm_val is vmm_valTR, broadcase is true
|
apply_post_ops(jcp_.dst_prc, true); // vmm_val is vmm_valTR, broadcase is true
|
||||||
}
|
}
|
||||||
store(vmm_valTR, reg_dst, step);
|
store(vmm_valTR, reg_dst, vector_step);
|
||||||
|
|
||||||
add(reg_dst, step * jcp_.dst_data_size);
|
add(reg_dst, vector_step * jcp_.dst_data_size);
|
||||||
add(reg_src_aux, step * sizeof(float));
|
add(reg_src_aux, vector_step * sizeof(float));
|
||||||
add(reg_index, step * jcp_.indices_size);
|
add(reg_index, vector_step * jcp_.indices_size);
|
||||||
sub(reg_work_amount, step);
|
sub(reg_work_amount, vector_step);
|
||||||
|
|
||||||
jmp(main_loop_label, T_NEAR);
|
jmp(main_loop_label, T_NEAR);
|
||||||
}
|
}
|
||||||
L(main_loop_end_label);
|
L(main_loop_end_label);
|
||||||
|
|
||||||
step = 1;
|
|
||||||
L(tail_loop_label);
|
L(tail_loop_label);
|
||||||
{
|
{
|
||||||
cmp(reg_work_amount, 1);
|
cmp(reg_work_amount, 1);
|
||||||
@ -763,15 +775,15 @@ private:
|
|||||||
mov(reg_src_aux1, reg_src);
|
mov(reg_src_aux1, reg_src);
|
||||||
mov(reg_index_offset, dword[reg_index]);
|
mov(reg_index_offset, dword[reg_index]);
|
||||||
add(reg_src_aux1, reg_index_offset);
|
add(reg_src_aux1, reg_index_offset);
|
||||||
load(reg_src_aux1, vmm_valTL, step);
|
load(reg_src_aux1, vmm_valTL, scalar_step);
|
||||||
|
|
||||||
mov(reg_src_aux1, reg_src);
|
mov(reg_src_aux1, reg_src);
|
||||||
mov(reg_index_offset, dword[reg_index + index_stride]);
|
mov(reg_index_offset, dword[reg_index + index_stride]);
|
||||||
add(reg_src_aux1, reg_index_offset);
|
add(reg_src_aux1, reg_index_offset);
|
||||||
load(reg_src_aux1, vmm_valTR, step);
|
load(reg_src_aux1, vmm_valTR, scalar_step);
|
||||||
|
|
||||||
load_weights(reg_src_aux, vmm_weightL, step, 0);
|
load_weights(reg_src_aux, vmm_weightL, scalar_step, 0);
|
||||||
load_weights(reg_src_aux, vmm_weightR, step, weight_stride);
|
load_weights(reg_src_aux, vmm_weightR, scalar_step, weight_stride);
|
||||||
|
|
||||||
if (jcp_.spatial_dim_size == 1) {
|
if (jcp_.spatial_dim_size == 1) {
|
||||||
linear_onnx_worker_1d();
|
linear_onnx_worker_1d();
|
||||||
@ -780,15 +792,15 @@ private:
|
|||||||
mov(reg_src_aux1, reg_src);
|
mov(reg_src_aux1, reg_src);
|
||||||
mov(reg_index_offset, dword[reg_index + 2 * index_stride]);
|
mov(reg_index_offset, dword[reg_index + 2 * index_stride]);
|
||||||
add(reg_src_aux1, reg_index_offset);
|
add(reg_src_aux1, reg_index_offset);
|
||||||
load(reg_src_aux1, vmm_valBL, step);
|
load(reg_src_aux1, vmm_valBL, scalar_step);
|
||||||
|
|
||||||
mov(reg_src_aux1, reg_src);
|
mov(reg_src_aux1, reg_src);
|
||||||
mov(reg_index_offset, dword[reg_index + 3 * index_stride]);
|
mov(reg_index_offset, dword[reg_index + 3 * index_stride]);
|
||||||
add(reg_src_aux1, reg_index_offset);
|
add(reg_src_aux1, reg_index_offset);
|
||||||
load(reg_src_aux1, vmm_valBR, step);
|
load(reg_src_aux1, vmm_valBR, scalar_step);
|
||||||
|
|
||||||
load_weights(reg_src_aux, vmm_weightT, step, 2 * weight_stride);
|
load_weights(reg_src_aux, vmm_weightT, scalar_step, 2 * weight_stride);
|
||||||
load_weights(reg_src_aux, vmm_weightB, step, 3 * weight_stride);
|
load_weights(reg_src_aux, vmm_weightB, scalar_step, 3 * weight_stride);
|
||||||
|
|
||||||
linear_onnx_worker_2d();
|
linear_onnx_worker_2d();
|
||||||
}
|
}
|
||||||
@ -799,27 +811,27 @@ private:
|
|||||||
mov(reg_src_aux1, reg_src);
|
mov(reg_src_aux1, reg_src);
|
||||||
mov(reg_index_offset, dword[reg_index + 4 * index_stride]);
|
mov(reg_index_offset, dword[reg_index + 4 * index_stride]);
|
||||||
add(reg_src_aux1, reg_index_offset);
|
add(reg_src_aux1, reg_index_offset);
|
||||||
load(reg_src_aux1, vmm_valTL, step);
|
load(reg_src_aux1, vmm_valTL, scalar_step);
|
||||||
|
|
||||||
mov(reg_src_aux1, reg_src);
|
mov(reg_src_aux1, reg_src);
|
||||||
mov(reg_index_offset, dword[reg_index + 5 * index_stride]);
|
mov(reg_index_offset, dword[reg_index + 5 * index_stride]);
|
||||||
add(reg_src_aux1, reg_index_offset);
|
add(reg_src_aux1, reg_index_offset);
|
||||||
load(reg_src_aux1, vmm_valTR, step);
|
load(reg_src_aux1, vmm_valTR, scalar_step);
|
||||||
|
|
||||||
mov(reg_src_aux1, reg_src);
|
mov(reg_src_aux1, reg_src);
|
||||||
mov(reg_index_offset, dword[reg_index + 6 * index_stride]);
|
mov(reg_index_offset, dword[reg_index + 6 * index_stride]);
|
||||||
add(reg_src_aux1, reg_index_offset);
|
add(reg_src_aux1, reg_index_offset);
|
||||||
load(reg_src_aux1, vmm_valBL, step);
|
load(reg_src_aux1, vmm_valBL, scalar_step);
|
||||||
|
|
||||||
mov(reg_src_aux1, reg_src);
|
mov(reg_src_aux1, reg_src);
|
||||||
mov(reg_index_offset, dword[reg_index + 7 * index_stride]);
|
mov(reg_index_offset, dword[reg_index + 7 * index_stride]);
|
||||||
add(reg_src_aux1, reg_index_offset);
|
add(reg_src_aux1, reg_index_offset);
|
||||||
load(reg_src_aux1, vmm_valBR, step);
|
load(reg_src_aux1, vmm_valBR, scalar_step);
|
||||||
|
|
||||||
linear_onnx_worker_2d();
|
linear_onnx_worker_2d();
|
||||||
|
|
||||||
load_weights(reg_src_aux, vmm_weightE, step, 5 * weight_stride);
|
load_weights(reg_src_aux, vmm_weightE, scalar_step, 5 * weight_stride);
|
||||||
load_weights(reg_src_aux, vmm_weightF, step, 4 * weight_stride);
|
load_weights(reg_src_aux, vmm_weightF, scalar_step, 4 * weight_stride);
|
||||||
|
|
||||||
uni_vmulps(vmm_valTR, vmm_valTR, vmm_weightE); // end_value * end_weight
|
uni_vmulps(vmm_valTR, vmm_valTR, vmm_weightE); // end_value * end_weight
|
||||||
uni_vfmadd231ps(vmm_valTR, vmm_d_bias, vmm_weightF); // start_value * start_weight + end_value * end_weight
|
uni_vfmadd231ps(vmm_valTR, vmm_d_bias, vmm_weightF); // start_value * start_weight + end_value * end_weight
|
||||||
@ -828,12 +840,12 @@ private:
|
|||||||
if (attr_.post_ops_.len() != 0) {
|
if (attr_.post_ops_.len() != 0) {
|
||||||
apply_post_ops(jcp_.dst_prc, true); // process on vmm_val, vmm_val is vmm_valTR, and bc
|
apply_post_ops(jcp_.dst_prc, true); // process on vmm_val, vmm_val is vmm_valTR, and bc
|
||||||
}
|
}
|
||||||
store(vmm_valTR, reg_dst, step);
|
store(vmm_valTR, reg_dst, scalar_step);
|
||||||
|
|
||||||
add(reg_dst, step * jcp_.dst_data_size);
|
add(reg_dst, scalar_step * jcp_.dst_data_size);
|
||||||
add(reg_src_aux, step * sizeof(float));
|
add(reg_src_aux, scalar_step * sizeof(float));
|
||||||
add(reg_index, step * jcp_.indices_size);
|
add(reg_index, scalar_step * jcp_.indices_size);
|
||||||
sub(reg_work_amount, step);
|
sub(reg_work_amount, scalar_step);
|
||||||
|
|
||||||
jmp(tail_loop_label, T_NEAR);
|
jmp(tail_loop_label, T_NEAR);
|
||||||
}
|
}
|
||||||
@ -876,8 +888,7 @@ private:
|
|||||||
uni_vbroadcastss(vmm_weightY2, ptr[reg_src_aux1 + 2 * sizeof(float)]);
|
uni_vbroadcastss(vmm_weightY2, ptr[reg_src_aux1 + 2 * sizeof(float)]);
|
||||||
uni_vbroadcastss(vmm_weightY3, ptr[reg_src_aux1 + 3 * sizeof(float)]);
|
uni_vbroadcastss(vmm_weightY3, ptr[reg_src_aux1 + 3 * sizeof(float)]);
|
||||||
|
|
||||||
int step = vlen / sizeof(float);
|
int blk = (isa == cpu::x64::sse41) ? (2 * vector_step) : vector_step;
|
||||||
int blk = (isa == cpu::x64::sse41) ? (2 * step) : step;
|
|
||||||
|
|
||||||
Xbyak::Label main_loop_label;
|
Xbyak::Label main_loop_label;
|
||||||
Xbyak::Label main_loop_end_label;
|
Xbyak::Label main_loop_end_label;
|
||||||
@ -886,7 +897,7 @@ private:
|
|||||||
L(main_loop_label);
|
L(main_loop_label);
|
||||||
{
|
{
|
||||||
if (jcp_.layout == InterpolateLayoutType::by_channel) {
|
if (jcp_.layout == InterpolateLayoutType::by_channel) {
|
||||||
cmp(reg_work_amount, step);
|
cmp(reg_work_amount, vector_step);
|
||||||
jl(main_loop_end_label, T_NEAR);
|
jl(main_loop_end_label, T_NEAR);
|
||||||
} else {
|
} else {
|
||||||
cmp(reg_work_amount, 1);
|
cmp(reg_work_amount, 1);
|
||||||
@ -899,14 +910,14 @@ private:
|
|||||||
|
|
||||||
if (attr_.post_ops_.len() != 0) {
|
if (attr_.post_ops_.len() != 0) {
|
||||||
apply_post_ops(jcp_.dst_prc, false); // vmm_val is default dst value to post_ops and store
|
apply_post_ops(jcp_.dst_prc, false); // vmm_val is default dst value to post_ops and store
|
||||||
add(reg_oc_off, step * sizeof(float));
|
add(reg_oc_off, vector_step * sizeof(float));
|
||||||
}
|
}
|
||||||
store(vmm_val, reg_dst, step);
|
store(vmm_val, reg_dst, vector_step);
|
||||||
|
|
||||||
if ((isa == cpu::x64::sse41) && (jcp_.layout == InterpolateLayoutType::block)) {
|
if ((isa == cpu::x64::sse41) && (jcp_.layout == InterpolateLayoutType::block)) {
|
||||||
// vmm is xmm here
|
// vmm is xmm here
|
||||||
add(reg_src, step * jcp_.src_data_size);
|
add(reg_src, vector_step * jcp_.src_data_size);
|
||||||
add(reg_dst, step * jcp_.dst_data_size);
|
add(reg_dst, vector_step * jcp_.dst_data_size);
|
||||||
|
|
||||||
uni_vpxor(vmm_val, vmm_val, vmm_val);
|
uni_vpxor(vmm_val, vmm_val, vmm_val);
|
||||||
|
|
||||||
@ -914,19 +925,19 @@ private:
|
|||||||
|
|
||||||
if (attr_.post_ops_.len() != 0) {
|
if (attr_.post_ops_.len() != 0) {
|
||||||
apply_post_ops(jcp_.dst_prc, false);
|
apply_post_ops(jcp_.dst_prc, false);
|
||||||
add(reg_oc_off, step * sizeof(float)); // second step for one blk
|
add(reg_oc_off, vector_step * sizeof(float)); // second vector_step for one blk
|
||||||
}
|
}
|
||||||
store(vmm_val, reg_dst, step);
|
store(vmm_val, reg_dst, vector_step);
|
||||||
|
|
||||||
sub(reg_src, step * jcp_.src_data_size);
|
sub(reg_src, vector_step * jcp_.src_data_size);
|
||||||
sub(reg_dst, step * jcp_.dst_data_size);
|
sub(reg_dst, vector_step * jcp_.dst_data_size);
|
||||||
}
|
}
|
||||||
if (jcp_.layout == InterpolateLayoutType::by_channel) {
|
if (jcp_.layout == InterpolateLayoutType::by_channel) {
|
||||||
int dst_stride = step * jcp_.dst_data_size;
|
int dst_stride = vector_step * jcp_.dst_data_size;
|
||||||
int src_stride = step * jcp_.src_data_size;
|
int src_stride = vector_step * jcp_.src_data_size;
|
||||||
add(reg_dst, dst_stride);
|
add(reg_dst, dst_stride);
|
||||||
add(reg_src, src_stride);
|
add(reg_src, src_stride);
|
||||||
sub(reg_work_amount, step); // work_amount is c
|
sub(reg_work_amount, vector_step); // work_amount is c
|
||||||
} else {
|
} else {
|
||||||
int dst_stride = blk * jcp_.OW * jcp_.OH * jcp_.dst_data_size;
|
int dst_stride = blk * jcp_.OW * jcp_.OH * jcp_.dst_data_size;
|
||||||
int src_stride = blk * jcp_.IW * jcp_.IH * jcp_.src_data_size;
|
int src_stride = blk * jcp_.IW * jcp_.IH * jcp_.src_data_size;
|
||||||
@ -940,7 +951,6 @@ private:
|
|||||||
L(main_loop_end_label);
|
L(main_loop_end_label);
|
||||||
|
|
||||||
// only for by_channel layout for tails.
|
// only for by_channel layout for tails.
|
||||||
step = 1;
|
|
||||||
L(tail_loop_label);
|
L(tail_loop_label);
|
||||||
{
|
{
|
||||||
cmp(reg_work_amount, 1);
|
cmp(reg_work_amount, 1);
|
||||||
@ -953,15 +963,15 @@ private:
|
|||||||
|
|
||||||
if (attr_.post_ops_.len() != 0) {
|
if (attr_.post_ops_.len() != 0) {
|
||||||
apply_post_ops(jcp_.dst_prc, false); // vmm_val is default dst value
|
apply_post_ops(jcp_.dst_prc, false); // vmm_val is default dst value
|
||||||
add(reg_oc_off, step * sizeof(float));
|
add(reg_oc_off, scalar_step * sizeof(float));
|
||||||
}
|
}
|
||||||
store(vmm_val, reg_dst, step);
|
store(vmm_val, reg_dst, scalar_step);
|
||||||
|
|
||||||
int dst_stride = step * jcp_.dst_data_size;
|
int dst_stride = scalar_step * jcp_.dst_data_size;
|
||||||
int src_stride = step * jcp_.src_data_size;
|
int src_stride = scalar_step * jcp_.src_data_size;
|
||||||
add(reg_dst, dst_stride);
|
add(reg_dst, dst_stride);
|
||||||
add(reg_src, src_stride);
|
add(reg_src, src_stride);
|
||||||
sub(reg_work_amount, step); // work_amount is c
|
sub(reg_work_amount, scalar_step); // work_amount is c
|
||||||
|
|
||||||
jmp(tail_loop_label, T_NEAR);
|
jmp(tail_loop_label, T_NEAR);
|
||||||
}
|
}
|
||||||
@ -1020,7 +1030,6 @@ private:
|
|||||||
mov(reg_weight_y, ptr[reg_params + GET_OFF(weight_ptr[0]) + sizeof(size_t)]);
|
mov(reg_weight_y, ptr[reg_params + GET_OFF(weight_ptr[0]) + sizeof(size_t)]);
|
||||||
mov(reg_work_amount, ptr[reg_params + GET_OFF(work_amount)]);
|
mov(reg_work_amount, ptr[reg_params + GET_OFF(work_amount)]);
|
||||||
|
|
||||||
int step = vlen / sizeof(float);
|
|
||||||
int grid_len = 4;
|
int grid_len = 4;
|
||||||
|
|
||||||
// 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
|
// 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
|
||||||
@ -1035,7 +1044,7 @@ private:
|
|||||||
Xbyak::Label tail_loop_end_label;
|
Xbyak::Label tail_loop_end_label;
|
||||||
L(main_loop_label);
|
L(main_loop_label);
|
||||||
{
|
{
|
||||||
cmp(reg_work_amount, step);
|
cmp(reg_work_amount, vector_step);
|
||||||
jl(main_loop_end_label, T_NEAR);
|
jl(main_loop_end_label, T_NEAR);
|
||||||
|
|
||||||
// vmm_tbl_y: (0 0 0 0 1 1 1 1 * index_size) --> (0 0 0 0 4 4 4 4)
|
// vmm_tbl_y: (0 0 0 0 1 1 1 1 * index_size) --> (0 0 0 0 4 4 4 4)
|
||||||
@ -1111,19 +1120,18 @@ private:
|
|||||||
if (attr_.post_ops_.len() != 0) {
|
if (attr_.post_ops_.len() != 0) {
|
||||||
apply_post_ops(jcp_.dst_prc, true); // oc_off is broadcast and always the same value for this channel
|
apply_post_ops(jcp_.dst_prc, true); // oc_off is broadcast and always the same value for this channel
|
||||||
}
|
}
|
||||||
store(vmm_val, reg_dst, step);
|
store(vmm_val, reg_dst, vector_step);
|
||||||
|
|
||||||
add(reg_tbl_y, step * sizeof(int)); // sizeof(int): sequence by dd()
|
add(reg_tbl_y, vector_step * sizeof(int)); // sizeof(int): sequence by dd()
|
||||||
add(reg_tbl_x, step * sizeof(int));
|
add(reg_tbl_x, vector_step * sizeof(int));
|
||||||
add(reg_dst, step * jcp_.dst_data_size);
|
add(reg_dst, vector_step * jcp_.dst_data_size);
|
||||||
|
|
||||||
sub(reg_work_amount, step);
|
sub(reg_work_amount, vector_step);
|
||||||
|
|
||||||
jmp(main_loop_label, T_NEAR);
|
jmp(main_loop_label, T_NEAR);
|
||||||
}
|
}
|
||||||
L(main_loop_end_label);
|
L(main_loop_end_label);
|
||||||
|
|
||||||
step = 1;
|
|
||||||
L(tail_loop_label);
|
L(tail_loop_label);
|
||||||
{
|
{
|
||||||
cmp(reg_work_amount, 1);
|
cmp(reg_work_amount, 1);
|
||||||
@ -1182,13 +1190,13 @@ private:
|
|||||||
if (attr_.post_ops_.len() != 0) {
|
if (attr_.post_ops_.len() != 0) {
|
||||||
apply_post_ops(jcp_.dst_prc, true); // oc_off is broadcast and always the same value for this channel
|
apply_post_ops(jcp_.dst_prc, true); // oc_off is broadcast and always the same value for this channel
|
||||||
}
|
}
|
||||||
store(vmm_val, reg_dst, step);
|
store(vmm_val, reg_dst, scalar_step);
|
||||||
|
|
||||||
add(reg_tbl_y, step * sizeof(int)); // sizeof(int): sequence with dd()
|
add(reg_tbl_y, scalar_step * sizeof(int)); // sizeof(int): sequence with dd()
|
||||||
add(reg_tbl_x, step * sizeof(int));
|
add(reg_tbl_x, scalar_step * sizeof(int));
|
||||||
add(reg_dst, step * jcp_.dst_data_size);
|
add(reg_dst, scalar_step * jcp_.dst_data_size);
|
||||||
|
|
||||||
sub(reg_work_amount, step);
|
sub(reg_work_amount, scalar_step);
|
||||||
|
|
||||||
jmp(tail_loop_label, T_NEAR);
|
jmp(tail_loop_label, T_NEAR);
|
||||||
}
|
}
|
||||||
@ -1264,7 +1272,7 @@ private:
|
|||||||
return ptr[reg_table + index * vlen];
|
return ptr[reg_table + index * vlen];
|
||||||
}
|
}
|
||||||
|
|
||||||
// always gather to Vmm, compute with Vmm, store with Xmm if scalar
|
// always gather to Vmm, compute with Vmm, store with Xmm if scalar_step
|
||||||
inline void gather_i32_indices(Vmm vmm_src, const Xbyak::Reg64 &base, int offset, Vmm vmm_indices, int scale,
|
inline void gather_i32_indices(Vmm vmm_src, const Xbyak::Reg64 &base, int offset, Vmm vmm_indices, int scale,
|
||||||
Precision src_prc, bool is_scalar) {
|
Precision src_prc, bool is_scalar) {
|
||||||
Xbyak::Address table_idx = ptr[base + offset + vmm_indices * scale];
|
Xbyak::Address table_idx = ptr[base + offset + vmm_indices * scale];
|
||||||
|
@ -110,7 +110,13 @@ struct jit_uni_mvn_mean_variance_kernel_f32 : public jit_uni_mvn_mean_variance_k
|
|||||||
}
|
}
|
||||||
|
|
||||||
void generate() override {
|
void generate() override {
|
||||||
load_emitter.reset(new jit_load_emitter(this, isa));
|
tail_step = jcp_.planar_layout ? (jcp_.D * jcp_.H * jcp_.W) - ((jcp_.D * jcp_.H * jcp_.W) / vector_step) * vector_step :
|
||||||
|
jcp_.C - (jcp_.C / vector_step) * vector_step;
|
||||||
|
|
||||||
|
Precision dst_prc = isFloatCompatible(jcp_.src_prc) ? Precision::FP32 : Precision::I32;
|
||||||
|
load_vector_emitter.reset(new jit_load_emitter(this, isa, jcp_.src_prc, dst_prc, vector_step));
|
||||||
|
load_tail_emitter.reset(new jit_load_emitter(this, isa, jcp_.src_prc, dst_prc, tail_step));
|
||||||
|
load_tail_with_fill_emitter.reset(new jit_load_emitter(this, isa, jcp_.src_prc, dst_prc, tail_step, Precision::FP32, true));
|
||||||
|
|
||||||
this->preamble();
|
this->preamble();
|
||||||
mov(reg_src, ptr[reg_params + GET_OFF(src)]);
|
mov(reg_src, ptr[reg_params + GET_OFF(src)]);
|
||||||
@ -134,14 +140,11 @@ struct jit_uni_mvn_mean_variance_kernel_f32 : public jit_uni_mvn_mean_variance_k
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tail_num = jcp_.planar_layout ? (jcp_.D * jcp_.H * jcp_.W) - ((jcp_.D * jcp_.H * jcp_.W) / step) * step :
|
|
||||||
jcp_.C - (jcp_.C / step) * step;
|
|
||||||
|
|
||||||
load_pool_gpr_idxs = {static_cast<size_t>(reg_load_store_mask.getIdx()), static_cast<size_t>(reg_load_table.getIdx())};
|
load_pool_gpr_idxs = {static_cast<size_t>(reg_load_store_mask.getIdx()), static_cast<size_t>(reg_load_table.getIdx())};
|
||||||
|
|
||||||
if (jcp_.planar_layout) {
|
if (jcp_.planar_layout) {
|
||||||
worker_unroll();
|
worker_unroll();
|
||||||
if (tail_num != 0) {
|
if (tail_step != 0) {
|
||||||
worker_tail_planar();
|
worker_tail_planar();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -198,7 +201,7 @@ struct jit_uni_mvn_mean_variance_kernel_f32 : public jit_uni_mvn_mean_variance_k
|
|||||||
}
|
}
|
||||||
|
|
||||||
Xbyak::Label label_empty_2half_sse42;
|
Xbyak::Label label_empty_2half_sse42;
|
||||||
if (tail_num == 0) {
|
if (tail_step == 0) {
|
||||||
cmp(reg_oc_off, static_cast<int>(jcp_.C * sizeof(float)));
|
cmp(reg_oc_off, static_cast<int>(jcp_.C * sizeof(float)));
|
||||||
jae(label_empty_2half_sse42, T_NEAR);
|
jae(label_empty_2half_sse42, T_NEAR);
|
||||||
|
|
||||||
@ -210,7 +213,7 @@ struct jit_uni_mvn_mean_variance_kernel_f32 : public jit_uni_mvn_mean_variance_k
|
|||||||
|
|
||||||
Xbyak::Label label_full_size;
|
Xbyak::Label label_full_size;
|
||||||
Xbyak::Label label_size_end;
|
Xbyak::Label label_size_end;
|
||||||
cmp(reg_oc_off, static_cast<int>((jcp_.C - step) * sizeof(float)));
|
cmp(reg_oc_off, static_cast<int>((jcp_.C - vector_step) * sizeof(float)));
|
||||||
jle(label_full_size, T_NEAR);
|
jle(label_full_size, T_NEAR);
|
||||||
|
|
||||||
// no need care and fill rest
|
// no need care and fill rest
|
||||||
@ -251,7 +254,9 @@ struct jit_uni_mvn_mean_variance_kernel_f32 : public jit_uni_mvn_mean_variance_k
|
|||||||
|
|
||||||
this->postamble();
|
this->postamble();
|
||||||
|
|
||||||
load_emitter->emit_data();
|
load_vector_emitter->emit_data();
|
||||||
|
load_tail_emitter->emit_data();
|
||||||
|
load_tail_with_fill_emitter->emit_data();
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -259,8 +264,8 @@ private:
|
|||||||
Xbyak::Ymm, Xbyak::Zmm>::type;
|
Xbyak::Ymm, Xbyak::Zmm>::type;
|
||||||
|
|
||||||
const int vlen = cpu_isa_traits<isa>::vlen;
|
const int vlen = cpu_isa_traits<isa>::vlen;
|
||||||
const int step = vlen / sizeof(float);
|
const int vector_step = vlen / sizeof(float);
|
||||||
int tail_num = 0;
|
int tail_step = 0;
|
||||||
|
|
||||||
Xbyak::Reg64 reg_src = r8;
|
Xbyak::Reg64 reg_src = r8;
|
||||||
Xbyak::Reg64 reg_mean = r9;
|
Xbyak::Reg64 reg_mean = r9;
|
||||||
@ -286,15 +291,15 @@ private:
|
|||||||
|
|
||||||
Xbyak::Opmask k_mask = Xbyak::Opmask(7);
|
Xbyak::Opmask k_mask = Xbyak::Opmask(7);
|
||||||
|
|
||||||
std::unique_ptr<jit_load_emitter> load_emitter = nullptr;
|
std::unique_ptr<jit_load_emitter> load_vector_emitter = nullptr;
|
||||||
|
std::unique_ptr<jit_load_emitter> load_tail_emitter = nullptr;
|
||||||
|
std::unique_ptr<jit_load_emitter> load_tail_with_fill_emitter = nullptr;
|
||||||
|
|
||||||
std::vector<size_t> load_pool_gpr_idxs;
|
std::vector<size_t> load_pool_gpr_idxs;
|
||||||
|
|
||||||
inline void worker_full_size() {
|
inline void worker_full_size() {
|
||||||
Precision dst_prc = isFloatCompatible(jcp_.src_prc) ? Precision::FP32 : Precision::I32;
|
load_vector_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_val.getIdx())},
|
||||||
load_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_val.getIdx())},
|
{}, {load_pool_gpr_idxs});
|
||||||
std::make_shared<load_emitter_context>(jcp_.src_prc, dst_prc, step),
|
|
||||||
{}, {load_pool_gpr_idxs});
|
|
||||||
|
|
||||||
if (jcp_.normalize_variance) {
|
if (jcp_.normalize_variance) {
|
||||||
// all with float
|
// all with float
|
||||||
@ -313,9 +318,7 @@ private:
|
|||||||
}
|
}
|
||||||
|
|
||||||
inline void worker_tail_blk() {
|
inline void worker_tail_blk() {
|
||||||
Precision dst_prc = isFloatCompatible(jcp_.src_prc) ? Precision::FP32 : Precision::I32;
|
load_tail_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_val.getIdx())},
|
||||||
load_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_val.getIdx())},
|
|
||||||
std::make_shared<load_emitter_context>(jcp_.src_prc, dst_prc, tail_num),
|
|
||||||
{}, {load_pool_gpr_idxs});
|
{}, {load_pool_gpr_idxs});
|
||||||
|
|
||||||
if (jcp_.normalize_variance) {
|
if (jcp_.normalize_variance) {
|
||||||
@ -357,10 +360,8 @@ private:
|
|||||||
}
|
}
|
||||||
|
|
||||||
inline void worker_tail_planar() {
|
inline void worker_tail_planar() {
|
||||||
Precision dst_prc = isFloatCompatible(jcp_.src_prc) ? Precision::FP32 : Precision::I32;
|
load_tail_with_fill_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_val.getIdx())},
|
||||||
load_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_val.getIdx())},
|
{}, {load_pool_gpr_idxs});
|
||||||
std::make_shared<load_emitter_context>(jcp_.src_prc, dst_prc, tail_num, 0, true),
|
|
||||||
{}, {load_pool_gpr_idxs});
|
|
||||||
|
|
||||||
if (jcp_.normalize_variance) {
|
if (jcp_.normalize_variance) {
|
||||||
if (!isFloatCompatible(jcp_.src_prc))
|
if (!isFloatCompatible(jcp_.src_prc))
|
||||||
@ -371,15 +372,15 @@ private:
|
|||||||
uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
|
uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
|
||||||
if (isa == cpu::x64::sse41) {
|
if (isa == cpu::x64::sse41) {
|
||||||
uint8 imm = 1;
|
uint8 imm = 1;
|
||||||
imm = ~((imm << tail_num) - imm);
|
imm = ~((imm << tail_step) - imm);
|
||||||
blendps(vmm_val, vmm_zero, imm);
|
blendps(vmm_val, vmm_zero, imm);
|
||||||
} else if (isa == cpu::x64::avx2) {
|
} else if (isa == cpu::x64::avx2) {
|
||||||
uint8 imm = 1;
|
uint8 imm = 1;
|
||||||
imm = ~((imm << tail_num) - imm);
|
imm = ~((imm << tail_step) - imm);
|
||||||
vblendps(vmm_val, vmm_val, vmm_zero, imm);
|
vblendps(vmm_val, vmm_val, vmm_zero, imm);
|
||||||
} else if (isa == cpu::x64::avx512_core) {
|
} else if (isa == cpu::x64::avx512_core) {
|
||||||
uint64_t tail_mask = 1;
|
uint64_t tail_mask = 1;
|
||||||
tail_mask = ~((tail_mask << tail_num) - tail_mask);
|
tail_mask = ~((tail_mask << tail_step) - tail_mask);
|
||||||
mov(reg_aux, tail_mask);
|
mov(reg_aux, tail_mask);
|
||||||
kmovq(k_mask, reg_aux);
|
kmovq(k_mask, reg_aux);
|
||||||
vblendmps(vmm_val | k_mask, vmm_val, vmm_zero);
|
vblendmps(vmm_val | k_mask, vmm_val, vmm_zero);
|
||||||
@ -435,8 +436,13 @@ struct jit_uni_mvn_kernel_f32 : public jit_uni_mvn_kernel, public jit_generator
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
load_emitter.reset(new jit_load_emitter(this, isa));
|
tail_step = jcp_.planar_layout ? (jcp_.D * jcp_.H * jcp_.W) - ((jcp_.D * jcp_.H * jcp_.W) / vector_step) * vector_step :
|
||||||
store_emitter.reset(new jit_store_emitter(this, isa));
|
jcp_.C - (jcp_.C / vector_step) * vector_step;
|
||||||
|
|
||||||
|
load_vector_emitter.reset(new jit_load_emitter(this, isa, jcp_.src_prc, Precision::FP32, vector_step));
|
||||||
|
load_tail_emitter.reset(new jit_load_emitter(this, isa, jcp_.src_prc, Precision::FP32, tail_step));
|
||||||
|
store_vector_emitter.reset(new jit_store_emitter(this, isa, Precision::FP32, jcp_.dst_prc, vector_step));
|
||||||
|
store_tail_emitter.reset(new jit_store_emitter(this, isa, Precision::FP32, jcp_.dst_prc, tail_step));
|
||||||
|
|
||||||
this->preamble();
|
this->preamble();
|
||||||
|
|
||||||
@ -463,16 +469,13 @@ struct jit_uni_mvn_kernel_f32 : public jit_uni_mvn_kernel, public jit_generator
|
|||||||
|
|
||||||
uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
|
uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
|
||||||
|
|
||||||
tail_num = jcp_.planar_layout ? (jcp_.D * jcp_.H * jcp_.W) - ((jcp_.D * jcp_.H * jcp_.W) / step) * step :
|
|
||||||
jcp_.C - (jcp_.C / step) * step;
|
|
||||||
|
|
||||||
load_pool_gpr_idxs = {static_cast<size_t>(reg_load_store_mask.getIdx()), static_cast<size_t>(reg_load_table.getIdx())};
|
load_pool_gpr_idxs = {static_cast<size_t>(reg_load_store_mask.getIdx()), static_cast<size_t>(reg_load_table.getIdx())};
|
||||||
store_pool_gpr_idxs = {static_cast<size_t>(reg_load_store_mask.getIdx())};
|
store_pool_gpr_idxs = {static_cast<size_t>(reg_load_store_mask.getIdx())};
|
||||||
store_pool_vec_idxs = {static_cast<size_t>(vmm_zero.getIdx())};
|
store_pool_vec_idxs = {static_cast<size_t>(vmm_zero.getIdx()), static_cast<size_t>(vmm_val.getIdx())};
|
||||||
|
|
||||||
if (jcp_.planar_layout) {
|
if (jcp_.planar_layout) {
|
||||||
worker_mvn_unroll();
|
worker_mvn_unroll();
|
||||||
if (tail_num != 0) {
|
if (tail_step != 0) {
|
||||||
worker_mvn(true);
|
worker_mvn(true);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -501,7 +504,7 @@ struct jit_uni_mvn_kernel_f32 : public jit_uni_mvn_kernel, public jit_generator
|
|||||||
}
|
}
|
||||||
|
|
||||||
Xbyak::Label label_empty_2half_sse42;
|
Xbyak::Label label_empty_2half_sse42;
|
||||||
if (tail_num == 0) {
|
if (tail_step == 0) {
|
||||||
cmp(reg_oc_off, static_cast<int>(jcp_.C * sizeof(float)));
|
cmp(reg_oc_off, static_cast<int>(jcp_.C * sizeof(float)));
|
||||||
jae(label_empty_2half_sse42, T_NEAR);
|
jae(label_empty_2half_sse42, T_NEAR);
|
||||||
worker_mvn_unroll();
|
worker_mvn_unroll();
|
||||||
@ -512,7 +515,7 @@ struct jit_uni_mvn_kernel_f32 : public jit_uni_mvn_kernel, public jit_generator
|
|||||||
Xbyak::Label label_full_size_block;
|
Xbyak::Label label_full_size_block;
|
||||||
Xbyak::Label label_size_end;
|
Xbyak::Label label_size_end;
|
||||||
|
|
||||||
cmp(reg_oc_off, static_cast<int>((jcp_.C - step) * sizeof(float)));
|
cmp(reg_oc_off, static_cast<int>((jcp_.C - vector_step) * sizeof(float)));
|
||||||
jle(label_full_size_block, T_NEAR);
|
jle(label_full_size_block, T_NEAR);
|
||||||
|
|
||||||
worker_mvn_unroll(true);
|
worker_mvn_unroll(true);
|
||||||
@ -530,8 +533,10 @@ struct jit_uni_mvn_kernel_f32 : public jit_uni_mvn_kernel, public jit_generator
|
|||||||
|
|
||||||
this->postamble();
|
this->postamble();
|
||||||
|
|
||||||
load_emitter->emit_data();
|
load_vector_emitter->emit_data();
|
||||||
store_emitter->emit_data();
|
load_tail_emitter->emit_data();
|
||||||
|
store_vector_emitter->emit_data();
|
||||||
|
store_tail_emitter->emit_data();
|
||||||
|
|
||||||
for (auto& inj : eltwise_injectors)
|
for (auto& inj : eltwise_injectors)
|
||||||
inj->prepare_table();
|
inj->prepare_table();
|
||||||
@ -542,8 +547,8 @@ private:
|
|||||||
Xbyak::Ymm, Xbyak::Zmm>::type;
|
Xbyak::Ymm, Xbyak::Zmm>::type;
|
||||||
|
|
||||||
const int vlen = cpu_isa_traits<isa>::vlen;
|
const int vlen = cpu_isa_traits<isa>::vlen;
|
||||||
const int step = vlen / sizeof(float);
|
const int vector_step = vlen / sizeof(float);
|
||||||
int tail_num = 0;
|
int tail_step = 0;
|
||||||
|
|
||||||
Xbyak::Reg64 reg_src = r8;
|
Xbyak::Reg64 reg_src = r8;
|
||||||
Xbyak::Reg64 reg_mean = r9;
|
Xbyak::Reg64 reg_mean = r9;
|
||||||
@ -570,8 +575,10 @@ private:
|
|||||||
Vmm vmm_d_weights = Vmm(5);
|
Vmm vmm_d_weights = Vmm(5);
|
||||||
Vmm vmm_d_bias = Vmm(6);
|
Vmm vmm_d_bias = Vmm(6);
|
||||||
|
|
||||||
std::unique_ptr<jit_load_emitter> load_emitter = nullptr;
|
std::unique_ptr<jit_load_emitter> load_vector_emitter = nullptr;
|
||||||
std::unique_ptr<jit_store_emitter> store_emitter = nullptr;
|
std::unique_ptr<jit_load_emitter> load_tail_emitter = nullptr;
|
||||||
|
std::unique_ptr<jit_store_emitter> store_vector_emitter = nullptr;
|
||||||
|
std::unique_ptr<jit_store_emitter> store_tail_emitter = nullptr;
|
||||||
|
|
||||||
std::vector<std::shared_ptr<jit_uni_eltwise_injector_f32<isa>>> eltwise_injectors;
|
std::vector<std::shared_ptr<jit_uni_eltwise_injector_f32<isa>>> eltwise_injectors;
|
||||||
std::vector<std::shared_ptr<jit_uni_depthwise_injector_f32<isa>>> depthwise_injectors;
|
std::vector<std::shared_ptr<jit_uni_depthwise_injector_f32<isa>>> depthwise_injectors;
|
||||||
@ -582,9 +589,10 @@ private:
|
|||||||
std::vector<size_t> load_pool_gpr_idxs;
|
std::vector<size_t> load_pool_gpr_idxs;
|
||||||
|
|
||||||
inline void worker_mvn(bool is_tail) {
|
inline void worker_mvn(bool is_tail) {
|
||||||
int elt_num = is_tail ? tail_num : step;
|
const auto& load_emitter = is_tail ? load_tail_emitter : load_vector_emitter;
|
||||||
|
const auto& store_emitter = is_tail ? store_tail_emitter : store_vector_emitter;
|
||||||
|
|
||||||
load_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_val.getIdx())},
|
load_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_val.getIdx())},
|
||||||
std::make_shared<load_emitter_context>(jcp_.src_prc, Precision::FP32, elt_num),
|
|
||||||
{}, {load_pool_gpr_idxs});
|
{}, {load_pool_gpr_idxs});
|
||||||
|
|
||||||
uni_vsubps(vmm_val, vmm_val, vmm_mean);
|
uni_vsubps(vmm_val, vmm_val, vmm_mean);
|
||||||
@ -594,7 +602,6 @@ private:
|
|||||||
apply_post_ops(jcp_.dst_prc, jcp_.planar_layout);
|
apply_post_ops(jcp_.dst_prc, jcp_.planar_layout);
|
||||||
|
|
||||||
store_emitter->emit_code({static_cast<size_t>(vmm_val.getIdx())}, {static_cast<size_t>(reg_dst.getIdx())},
|
store_emitter->emit_code({static_cast<size_t>(vmm_val.getIdx())}, {static_cast<size_t>(reg_dst.getIdx())},
|
||||||
std::make_shared<store_emitter_context>(Precision::FP32, jcp_.dst_prc, elt_num),
|
|
||||||
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
|
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -44,8 +44,9 @@ struct jit_uni_nms_kernel_f32 : public jit_uni_nms_kernel, public jit_generator
|
|||||||
}
|
}
|
||||||
|
|
||||||
void generate() override {
|
void generate() override {
|
||||||
load_emitter.reset(new jit_load_emitter(this, isa));
|
load_vector_emitter.reset(new jit_load_emitter(this, isa, Precision::FP32, Precision::FP32, vector_step));
|
||||||
store_emitter.reset(new jit_store_emitter(this, isa));
|
load_scalar_emitter.reset(new jit_load_emitter(this, isa, Precision::FP32, Precision::FP32, scalar_step));
|
||||||
|
|
||||||
exp_injector.reset(new jit_uni_eltwise_injector_f32<isa>(this, dnnl::impl::alg_kind::eltwise_exp, 0.f, 0.f, 1.0f));
|
exp_injector.reset(new jit_uni_eltwise_injector_f32<isa>(this, dnnl::impl::alg_kind::eltwise_exp, 0.f, 0.f, 1.0f));
|
||||||
|
|
||||||
this->preamble();
|
this->preamble();
|
||||||
@ -137,8 +138,8 @@ struct jit_uni_nms_kernel_f32 : public jit_uni_nms_kernel, public jit_generator
|
|||||||
|
|
||||||
this->postamble();
|
this->postamble();
|
||||||
|
|
||||||
load_emitter->emit_data();
|
load_vector_emitter->emit_data();
|
||||||
store_emitter->emit_data();
|
load_scalar_emitter->emit_data();
|
||||||
|
|
||||||
prepare_table();
|
prepare_table();
|
||||||
exp_injector->prepare_table();
|
exp_injector->prepare_table();
|
||||||
@ -147,6 +148,8 @@ struct jit_uni_nms_kernel_f32 : public jit_uni_nms_kernel, public jit_generator
|
|||||||
private:
|
private:
|
||||||
using Vmm = typename conditional3<isa == cpu::x64::sse41, Xbyak::Xmm, isa == cpu::x64::avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
|
using Vmm = typename conditional3<isa == cpu::x64::sse41, Xbyak::Xmm, isa == cpu::x64::avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
|
||||||
uint32_t vlen = cpu_isa_traits<isa>::vlen;
|
uint32_t vlen = cpu_isa_traits<isa>::vlen;
|
||||||
|
const int vector_step = vlen / sizeof(float);
|
||||||
|
const int scalar_step = 1;
|
||||||
|
|
||||||
Xbyak::Reg64 reg_boxes_coord0 = r8;
|
Xbyak::Reg64 reg_boxes_coord0 = r8;
|
||||||
Xbyak::Reg64 reg_boxes_coord1 = r9;
|
Xbyak::Reg64 reg_boxes_coord1 = r9;
|
||||||
@ -172,8 +175,8 @@ private:
|
|||||||
|
|
||||||
Xbyak::Reg64 reg_params = abi_param1;
|
Xbyak::Reg64 reg_params = abi_param1;
|
||||||
|
|
||||||
std::unique_ptr<jit_load_emitter> load_emitter = nullptr;
|
std::unique_ptr<jit_load_emitter> load_vector_emitter = nullptr;
|
||||||
std::unique_ptr<jit_store_emitter> store_emitter = nullptr;
|
std::unique_ptr<jit_load_emitter> load_scalar_emitter = nullptr;
|
||||||
|
|
||||||
std::vector<size_t> store_pool_gpr_idxs;
|
std::vector<size_t> store_pool_gpr_idxs;
|
||||||
std::vector<size_t> store_pool_vec_idxs;
|
std::vector<size_t> store_pool_vec_idxs;
|
||||||
@ -205,25 +208,24 @@ private:
|
|||||||
std::shared_ptr<jit_uni_eltwise_injector_f32<isa>> exp_injector;
|
std::shared_ptr<jit_uni_eltwise_injector_f32<isa>> exp_injector;
|
||||||
|
|
||||||
inline void hard_nms() {
|
inline void hard_nms() {
|
||||||
int step = vlen / sizeof(float);
|
|
||||||
Xbyak::Label main_loop_label_hard;
|
Xbyak::Label main_loop_label_hard;
|
||||||
Xbyak::Label main_loop_end_label_hard;
|
Xbyak::Label main_loop_end_label_hard;
|
||||||
Xbyak::Label tail_loop_label_hard;
|
Xbyak::Label tail_loop_label_hard;
|
||||||
Xbyak::Label terminate_label_hard;
|
Xbyak::Label terminate_label_hard;
|
||||||
L(main_loop_label_hard);
|
L(main_loop_label_hard);
|
||||||
{
|
{
|
||||||
cmp(reg_boxes_num, step);
|
cmp(reg_boxes_num, vector_step);
|
||||||
jl(main_loop_end_label_hard, T_NEAR);
|
jl(main_loop_end_label_hard, T_NEAR);
|
||||||
|
|
||||||
sub(reg_boxes_coord0, step * sizeof(float));
|
sub(reg_boxes_coord0, vector_step * sizeof(float));
|
||||||
sub(reg_boxes_coord1, step * sizeof(float));
|
sub(reg_boxes_coord1, vector_step * sizeof(float));
|
||||||
sub(reg_boxes_coord2, step * sizeof(float));
|
sub(reg_boxes_coord2, vector_step * sizeof(float));
|
||||||
sub(reg_boxes_coord3, step * sizeof(float));
|
sub(reg_boxes_coord3, vector_step * sizeof(float));
|
||||||
|
|
||||||
// iou result is in vmm_temp3
|
// iou result is in vmm_temp3
|
||||||
iou(step);
|
iou(vector_step);
|
||||||
|
|
||||||
sub(reg_boxes_num, step);
|
sub(reg_boxes_num, vector_step);
|
||||||
|
|
||||||
suppressed_by_iou(false);
|
suppressed_by_iou(false);
|
||||||
|
|
||||||
@ -236,21 +238,20 @@ private:
|
|||||||
}
|
}
|
||||||
L(main_loop_end_label_hard);
|
L(main_loop_end_label_hard);
|
||||||
|
|
||||||
step = 1;
|
|
||||||
L(tail_loop_label_hard);
|
L(tail_loop_label_hard);
|
||||||
{
|
{
|
||||||
cmp(reg_boxes_num, 1);
|
cmp(reg_boxes_num, 1);
|
||||||
jl(terminate_label_hard, T_NEAR);
|
jl(terminate_label_hard, T_NEAR);
|
||||||
|
|
||||||
sub(reg_boxes_coord0, step * sizeof(float));
|
sub(reg_boxes_coord0, scalar_step * sizeof(float));
|
||||||
sub(reg_boxes_coord1, step * sizeof(float));
|
sub(reg_boxes_coord1, scalar_step * sizeof(float));
|
||||||
sub(reg_boxes_coord2, step * sizeof(float));
|
sub(reg_boxes_coord2, scalar_step * sizeof(float));
|
||||||
sub(reg_boxes_coord3, step * sizeof(float));
|
sub(reg_boxes_coord3, scalar_step * sizeof(float));
|
||||||
|
|
||||||
// iou result is in vmm_temp3
|
// iou result is in vmm_temp3
|
||||||
iou(step);
|
iou(scalar_step);
|
||||||
|
|
||||||
sub(reg_boxes_num, step);
|
sub(reg_boxes_num, scalar_step);
|
||||||
|
|
||||||
suppressed_by_iou(true);
|
suppressed_by_iou(true);
|
||||||
|
|
||||||
@ -267,7 +268,6 @@ private:
|
|||||||
inline void soft_nms() {
|
inline void soft_nms() {
|
||||||
uni_vbroadcastss(vmm_scale, ptr[reg_scale]);
|
uni_vbroadcastss(vmm_scale, ptr[reg_scale]);
|
||||||
|
|
||||||
int step = vlen / sizeof(float);
|
|
||||||
Xbyak::Label main_loop_label;
|
Xbyak::Label main_loop_label;
|
||||||
Xbyak::Label main_loop_end_label;
|
Xbyak::Label main_loop_end_label;
|
||||||
Xbyak::Label tail_loop_label;
|
Xbyak::Label tail_loop_label;
|
||||||
@ -277,17 +277,17 @@ private:
|
|||||||
Xbyak::Label tail_loop_label_soft;
|
Xbyak::Label tail_loop_label_soft;
|
||||||
L(main_loop_label);
|
L(main_loop_label);
|
||||||
{
|
{
|
||||||
cmp(reg_boxes_num, step);
|
cmp(reg_boxes_num, vector_step);
|
||||||
jl(main_loop_end_label, T_NEAR);
|
jl(main_loop_end_label, T_NEAR);
|
||||||
|
|
||||||
sub(reg_boxes_coord0, step * sizeof(float));
|
sub(reg_boxes_coord0, vector_step * sizeof(float));
|
||||||
sub(reg_boxes_coord1, step * sizeof(float));
|
sub(reg_boxes_coord1, vector_step * sizeof(float));
|
||||||
sub(reg_boxes_coord2, step * sizeof(float));
|
sub(reg_boxes_coord2, vector_step * sizeof(float));
|
||||||
sub(reg_boxes_coord3, step * sizeof(float));
|
sub(reg_boxes_coord3, vector_step * sizeof(float));
|
||||||
|
|
||||||
// result(iou and weight) is in vmm_temp3
|
// result(iou and weight) is in vmm_temp3
|
||||||
iou(step);
|
iou(vector_step);
|
||||||
sub(reg_boxes_num, step);
|
sub(reg_boxes_num, vector_step);
|
||||||
|
|
||||||
// soft suppressed by iou_threshold
|
// soft suppressed by iou_threshold
|
||||||
if (jcp.is_soft_suppressed_by_iou) {
|
if (jcp.is_soft_suppressed_by_iou) {
|
||||||
@ -327,19 +327,18 @@ private:
|
|||||||
}
|
}
|
||||||
L(main_loop_end_label);
|
L(main_loop_end_label);
|
||||||
|
|
||||||
step = 1;
|
|
||||||
L(tail_loop_label);
|
L(tail_loop_label);
|
||||||
{
|
{
|
||||||
cmp(reg_boxes_num, 1);
|
cmp(reg_boxes_num, 1);
|
||||||
jl(terminate_label, T_NEAR);
|
jl(terminate_label, T_NEAR);
|
||||||
|
|
||||||
sub(reg_boxes_coord0, step * sizeof(float));
|
sub(reg_boxes_coord0, scalar_step * sizeof(float));
|
||||||
sub(reg_boxes_coord1, step * sizeof(float));
|
sub(reg_boxes_coord1, scalar_step * sizeof(float));
|
||||||
sub(reg_boxes_coord2, step * sizeof(float));
|
sub(reg_boxes_coord2, scalar_step * sizeof(float));
|
||||||
sub(reg_boxes_coord3, step * sizeof(float));
|
sub(reg_boxes_coord3, scalar_step * sizeof(float));
|
||||||
|
|
||||||
iou(step);
|
iou(scalar_step);
|
||||||
sub(reg_boxes_num, step);
|
sub(reg_boxes_num, scalar_step);
|
||||||
|
|
||||||
// soft suppressed by iou_threshold
|
// soft suppressed by iou_threshold
|
||||||
if (jcp.is_soft_suppressed_by_iou) {
|
if (jcp.is_soft_suppressed_by_iou) {
|
||||||
@ -427,8 +426,11 @@ private:
|
|||||||
|
|
||||||
inline void iou(int ele_num) {
|
inline void iou(int ele_num) {
|
||||||
auto load = [&](Xbyak::Reg64 reg_src, Vmm vmm_dst) {
|
auto load = [&](Xbyak::Reg64 reg_src, Vmm vmm_dst) {
|
||||||
|
if (ele_num != scalar_step && ele_num != vector_step)
|
||||||
|
IE_THROW() << "NMS JIT implementation supports load emitter with only element count scalar_step or vector_step! Get: " << ele_num;
|
||||||
|
|
||||||
|
const auto& load_emitter = ele_num == 1 ? load_scalar_emitter : load_vector_emitter;
|
||||||
load_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_dst.getIdx())},
|
load_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_dst.getIdx())},
|
||||||
std::make_shared<load_emitter_context>(Precision::FP32, Precision::FP32, ele_num),
|
|
||||||
{}, {load_pool_gpr_idxs});
|
{}, {load_pool_gpr_idxs});
|
||||||
};
|
};
|
||||||
load(reg_boxes_coord0, vmm_boxes_coord0);
|
load(reg_boxes_coord0, vmm_boxes_coord0);
|
||||||
|
@ -46,9 +46,6 @@ struct jit_uni_roi_align_kernel_f32 : public jit_uni_roi_align_kernel, public ji
|
|||||||
};
|
};
|
||||||
|
|
||||||
void generate() override {
|
void generate() override {
|
||||||
load_emitter.reset(new jit_load_emitter(this, isa));
|
|
||||||
store_emitter.reset(new jit_store_emitter(this, isa));
|
|
||||||
|
|
||||||
this->preamble();
|
this->preamble();
|
||||||
|
|
||||||
uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
|
uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
|
||||||
@ -65,8 +62,7 @@ struct jit_uni_roi_align_kernel_f32 : public jit_uni_roi_align_kernel, public ji
|
|||||||
|
|
||||||
this->postamble();
|
this->postamble();
|
||||||
|
|
||||||
load_emitter->emit_data();
|
emit_emitters_data();
|
||||||
store_emitter->emit_data();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -107,10 +103,9 @@ private:
|
|||||||
// [1] for reg_dst
|
// [1] for reg_dst
|
||||||
Xmm xmm_args_pool = Xmm(15);
|
Xmm xmm_args_pool = Xmm(15);
|
||||||
|
|
||||||
std::unique_ptr<jit_load_emitter> load_emitter = nullptr;
|
std::unordered_map<size_t, std::unique_ptr<jit_emitter>> emitters;
|
||||||
std::vector<size_t> load_pool_gpr_idxs;
|
|
||||||
|
|
||||||
std::unique_ptr<jit_store_emitter> store_emitter = nullptr;
|
std::vector<size_t> load_pool_gpr_idxs;
|
||||||
std::vector<size_t> store_pool_gpr_idxs;
|
std::vector<size_t> store_pool_gpr_idxs;
|
||||||
std::vector<size_t> store_pool_vec_idxs;
|
std::vector<size_t> store_pool_vec_idxs;
|
||||||
|
|
||||||
@ -157,6 +152,57 @@ private:
|
|||||||
|
|
||||||
reg64_t reg_params = abi_param1;
|
reg64_t reg_params = abi_param1;
|
||||||
|
|
||||||
|
// Calls emit_data() on every emitter currently stored in the `emitters` cache.
// Entries are added to the cache by emit_load()/emit_store().
void emit_emitters_data() {
|
||||||
|
for (const auto& emitter : emitters) {
|
||||||
|
emitter.second->emit_data();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forwards to emit_load() with source precision jcp_.data_prc and destination precision FP32.
// reg_src, vmm_src, elt_num and offset are passed through unchanged.
// NOTE(review): offset is presumably a byte displacement from reg_src - confirm against jit_load_emitter.
inline void load(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num, const int offset = 0) {
|
||||||
|
emit_load(reg_src, vmm_src, jcp_.data_prc, Precision::FP32, elt_num, offset);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forwards to emit_load() with FP32 as both source and destination precision.
// reg_src, vmm_src, elt_num and offset are passed through unchanged.
inline void load_buffer(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num, const int offset = 0) {
|
||||||
|
emit_load(reg_src, vmm_src, Precision::FP32, Precision::FP32, elt_num, offset);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forwards to emit_load() with I32 as both source and destination precision.
// reg_src, vmm_src, elt_num and offset are passed through unchanged.
inline void load_idx(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num, const int offset = 0) {
|
||||||
|
emit_load(reg_src, vmm_src, Precision::I32, Precision::I32, elt_num, offset);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forwards to emit_store() with source precision FP32 and destination precision jcp_.data_prc.
// vmm_dst, reg_dst, elt_num and offset are passed through unchanged.
// NOTE(review): offset is presumably a byte displacement from reg_dst - confirm against jit_store_emitter.
inline void store(Vmm vmm_dst, Xbyak::Reg64 reg_dst, const int elt_num, const int offset = 0) {
|
||||||
|
emit_store(vmm_dst, reg_dst, Precision::FP32, jcp_.data_prc, elt_num, offset);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forwards to emit_store() with FP32 as both source and destination precision.
// vmm_dst, reg_dst, elt_num and offset are passed through unchanged.
inline void store_buffer(Vmm vmm_dst, Xbyak::Reg64 reg_dst, const int elt_num, const int offset = 0) {
|
||||||
|
emit_store(vmm_dst, reg_dst, Precision::FP32, Precision::FP32, elt_num, offset);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Emits load code through a jit_load_emitter that is created on first use and cached in `emitters`.
// The cache key is the hash of load_emitter_params(src_prc, dst_prc, elt_num).
// emit_code() receives {reg_src index, offset} as inputs, {vmm_src index} as output,
// no auxiliary vector registers, and load_pool_gpr_idxs as the auxiliary GPR pool.
// NOTE(review): the cache is keyed by the hash value alone and is shared with emit_store();
// if two different parameter sets ever hash to the same value, the emitter built for the first
// one would be reused for the second - confirm that hash() keeps load/store parameter sets distinct.
// NOTE(review): emitters[seed] is looked up three times; a single reference would avoid the repeated lookups.
inline void emit_load(Xbyak::Reg64 reg_src, Vmm vmm_src, Precision src_prc, Precision dst_prc, const int elt_num, const int offset = 0) {
|
||||||
|
const auto seed = load_emitter_params(src_prc, dst_prc, elt_num).hash();
|
||||||
|
// operator[] default-inserts an empty pointer for an unseen key, which is then filled in here
if (!emitters[seed]) {
|
||||||
|
emitters[seed].reset(new jit_load_emitter(this, isa, src_prc, dst_prc, elt_num));
|
||||||
|
}
|
||||||
|
|
||||||
|
emitters[seed]->emit_code({static_cast<size_t>(reg_src.getIdx()), static_cast<size_t>(offset)},
|
||||||
|
{static_cast<size_t>(vmm_src.getIdx())}, {}, {load_pool_gpr_idxs});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Emits store code through a jit_store_emitter that is created on first use and cached in `emitters`.
// The cache key is the hash of store_emitter_params(src_prc, dst_prc, elt_num).
// emit_code() receives {vmm_dst index, offset} as inputs, {reg_dst index} as output,
// a per-call auxiliary vector pool (see below), and store_pool_gpr_idxs as the auxiliary GPR pool.
// NOTE(review): the cache is keyed by the hash value alone and is shared with emit_load();
// if two different parameter sets ever hash to the same value, the emitter built for the first
// one would be reused for the second - confirm that hash() keeps load/store parameter sets distinct.
inline void emit_store(Vmm vmm_dst, Xbyak::Reg64 reg_dst, Precision src_prc, Precision dst_prc, const int elt_num, const int offset = 0) {
|
||||||
|
const auto seed = store_emitter_params(src_prc, dst_prc, elt_num).hash();
|
||||||
|
// operator[] default-inserts an empty pointer for an unseen key, which is then filled in here
if (!emitters[seed]) {
|
||||||
|
emitters[seed].reset(new jit_store_emitter(this, isa, src_prc, dst_prc, elt_num));
|
||||||
|
}
|
||||||
|
|
||||||
|
// When the store emitter needs two auxiliary vector registers, vmm_dst can serve as the second one.
// The pool is built as store_pool_vec_idxs followed by vmm_dst, so vmm_dst is the last entry.
|
||||||
|
std::vector<size_t> local_store_pool_vec_idxs = { static_cast<size_t>(vmm_dst.getIdx()) };
|
||||||
|
local_store_pool_vec_idxs.insert(local_store_pool_vec_idxs.begin(), store_pool_vec_idxs.begin(), store_pool_vec_idxs.end());
|
||||||
|
|
||||||
|
emitters[seed]->emit_code({static_cast<size_t>(vmm_dst.getIdx()), static_cast<size_t>(offset)},
|
||||||
|
{static_cast<size_t>(reg_dst.getIdx())},
|
||||||
|
{local_store_pool_vec_idxs}, {store_pool_gpr_idxs});
|
||||||
|
}
|
||||||
|
|
||||||
void roi_align_cgather() {
|
void roi_align_cgather() {
|
||||||
mov(reg_src_address, ptr[reg_params + GET_OFF(src)]);
|
mov(reg_src_address, ptr[reg_params + GET_OFF(src)]);
|
||||||
mov(reg_weights, ptr[reg_params + GET_OFF(weights)]);
|
mov(reg_weights, ptr[reg_params + GET_OFF(weights)]);
|
||||||
@ -180,23 +226,6 @@ private:
|
|||||||
imul(reg_src_stride, reg_src_stride, jcp_.data_size);
|
imul(reg_src_stride, reg_src_stride, jcp_.data_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto store = [&](Vmm vmm_dst, Xbyak::Reg64 reg_dst, int elt_num) {
|
|
||||||
store_emitter->emit_code({static_cast<size_t>(vmm_dst.getIdx())}, {static_cast<size_t>(reg_dst.getIdx())},
|
|
||||||
std::make_shared<store_emitter_context>(Precision::FP32, jcp_.data_prc, elt_num),
|
|
||||||
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
|
|
||||||
};
|
|
||||||
|
|
||||||
auto load_buf = [&](Xbyak::Reg64 reg_src, Vmm vmm_src, int elt_num) {
|
|
||||||
load_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_src.getIdx())},
|
|
||||||
std::make_shared<load_emitter_context>(Precision::FP32, Precision::FP32, elt_num),
|
|
||||||
{}, {load_pool_gpr_idxs});
|
|
||||||
};
|
|
||||||
auto store_buf = [&](Vmm vmm_dst, Xbyak::Reg64 reg_dst, int elt_num) {
|
|
||||||
store_emitter->emit_code({static_cast<size_t>(vmm_dst.getIdx())}, {static_cast<size_t>(reg_dst.getIdx())},
|
|
||||||
std::make_shared<store_emitter_context>(Precision::FP32, Precision::FP32, elt_num),
|
|
||||||
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
|
|
||||||
};
|
|
||||||
|
|
||||||
// out loop for samples in bin
|
// out loop for samples in bin
|
||||||
Xbyak::Label out_loop_label;
|
Xbyak::Label out_loop_label;
|
||||||
Xbyak::Label out_loop_end_label;
|
Xbyak::Label out_loop_end_label;
|
||||||
@ -228,13 +257,13 @@ private:
|
|||||||
generate_samples(v_step);
|
generate_samples(v_step);
|
||||||
// now this sample value across channel reside in vmm_sample
|
// now this sample value across channel reside in vmm_sample
|
||||||
// compute with other samples in vmm_buf
|
// compute with other samples in vmm_buf
|
||||||
load_buf(reg_buf, vmm_buf, v_step);
|
load_buffer(reg_buf, vmm_buf, v_step);
|
||||||
if (jcp_.alg == Algorithm::ROIAlignAvg) {
|
if (jcp_.alg == Algorithm::ROIAlignAvg) {
|
||||||
uni_vaddps(vmm_buf, vmm_buf, vmm_sample);
|
uni_vaddps(vmm_buf, vmm_buf, vmm_sample);
|
||||||
} else {
|
} else {
|
||||||
uni_vmaxps(vmm_buf, vmm_buf, vmm_sample);
|
uni_vmaxps(vmm_buf, vmm_buf, vmm_sample);
|
||||||
}
|
}
|
||||||
store_buf(vmm_buf, reg_buf, v_step);
|
store_buffer(vmm_buf, reg_buf, v_step);
|
||||||
|
|
||||||
if ((isa == cpu::x64::sse41) && (jcp_.layout == ROIAlignLayoutType::blk)) {
|
if ((isa == cpu::x64::sse41) && (jcp_.layout == ROIAlignLayoutType::blk)) {
|
||||||
add(reg_src0, x_step * jcp_.data_size);
|
add(reg_src0, x_step * jcp_.data_size);
|
||||||
@ -244,13 +273,13 @@ private:
|
|||||||
add(reg_buf, x_step * sizeof(float));
|
add(reg_buf, x_step * sizeof(float));
|
||||||
|
|
||||||
generate_samples(x_step);
|
generate_samples(x_step);
|
||||||
load_buf(reg_buf, vmm_buf, x_step);
|
load_buffer(reg_buf, vmm_buf, x_step);
|
||||||
if (jcp_.alg == Algorithm::ROIAlignAvg) {
|
if (jcp_.alg == Algorithm::ROIAlignAvg) {
|
||||||
uni_vaddps(vmm_buf, vmm_buf, vmm_sample);
|
uni_vaddps(vmm_buf, vmm_buf, vmm_sample);
|
||||||
} else {
|
} else {
|
||||||
uni_vmaxps(vmm_buf, vmm_buf, vmm_sample);
|
uni_vmaxps(vmm_buf, vmm_buf, vmm_sample);
|
||||||
}
|
}
|
||||||
store_buf(vmm_buf, reg_buf, x_step);
|
store_buffer(vmm_buf, reg_buf, x_step);
|
||||||
|
|
||||||
sub(reg_src0, x_step * jcp_.data_size);
|
sub(reg_src0, x_step * jcp_.data_size);
|
||||||
sub(reg_src1, x_step * jcp_.data_size);
|
sub(reg_src1, x_step * jcp_.data_size);
|
||||||
@ -280,13 +309,13 @@ private:
|
|||||||
jl(in_loop_tail_end_label, T_NEAR);
|
jl(in_loop_tail_end_label, T_NEAR);
|
||||||
|
|
||||||
generate_samples(tail_step);
|
generate_samples(tail_step);
|
||||||
load_buf(reg_buf, vmm_buf, tail_step);
|
load_buffer(reg_buf, vmm_buf, tail_step);
|
||||||
if (jcp_.alg == Algorithm::ROIAlignAvg) {
|
if (jcp_.alg == Algorithm::ROIAlignAvg) {
|
||||||
uni_vaddps(vmm_buf, vmm_buf, vmm_sample);
|
uni_vaddps(vmm_buf, vmm_buf, vmm_sample);
|
||||||
} else {
|
} else {
|
||||||
uni_vmaxps(vmm_buf, vmm_buf, vmm_sample);
|
uni_vmaxps(vmm_buf, vmm_buf, vmm_sample);
|
||||||
}
|
}
|
||||||
store_buf(vmm_buf, reg_buf, tail_step);
|
store_buffer(vmm_buf, reg_buf, tail_step);
|
||||||
|
|
||||||
int tail_src_stride = tail_step * jcp_.data_size;
|
int tail_src_stride = tail_step * jcp_.data_size;
|
||||||
add(reg_src0, tail_src_stride);
|
add(reg_src0, tail_src_stride);
|
||||||
@ -333,7 +362,7 @@ private:
|
|||||||
cmp(reg_work_amount, v_step);
|
cmp(reg_work_amount, v_step);
|
||||||
jl(store_loop_main_end_label, T_NEAR);
|
jl(store_loop_main_end_label, T_NEAR);
|
||||||
|
|
||||||
load_buf(reg_buf, vmm_buf, v_step);
|
load_buffer(reg_buf, vmm_buf, v_step);
|
||||||
if (jcp_.alg == Algorithm::ROIAlignAvg) {
|
if (jcp_.alg == Algorithm::ROIAlignAvg) {
|
||||||
uni_vmulps(vmm_buf, vmm_buf, vmm_scale);
|
uni_vmulps(vmm_buf, vmm_buf, vmm_scale);
|
||||||
}
|
}
|
||||||
@ -343,7 +372,7 @@ private:
|
|||||||
add(reg_buf, x_step * sizeof(float));
|
add(reg_buf, x_step * sizeof(float));
|
||||||
add(reg_dst, x_step * jcp_.data_size);
|
add(reg_dst, x_step * jcp_.data_size);
|
||||||
|
|
||||||
load_buf(reg_buf, vmm_buf, x_step);
|
load_buffer(reg_buf, vmm_buf, x_step);
|
||||||
if (jcp_.alg == Algorithm::ROIAlignAvg) {
|
if (jcp_.alg == Algorithm::ROIAlignAvg) {
|
||||||
uni_vmulps(vmm_buf, vmm_buf, vmm_scale);
|
uni_vmulps(vmm_buf, vmm_buf, vmm_scale);
|
||||||
}
|
}
|
||||||
@ -369,7 +398,7 @@ private:
|
|||||||
cmp(reg_work_amount, tail_step);
|
cmp(reg_work_amount, tail_step);
|
||||||
jl(store_loop_tail_end_label, T_NEAR);
|
jl(store_loop_tail_end_label, T_NEAR);
|
||||||
|
|
||||||
load_buf(reg_buf, vmm_buf, tail_step);
|
load_buffer(reg_buf, vmm_buf, tail_step);
|
||||||
if (jcp_.alg == Algorithm::ROIAlignAvg) {
|
if (jcp_.alg == Algorithm::ROIAlignAvg) {
|
||||||
uni_vmulps(vmm_buf, vmm_buf, vmm_scale);
|
uni_vmulps(vmm_buf, vmm_buf, vmm_scale);
|
||||||
}
|
}
|
||||||
@ -402,12 +431,6 @@ private:
|
|||||||
}
|
}
|
||||||
|
|
||||||
void generate_samples(int num) {
|
void generate_samples(int num) {
|
||||||
auto load = [&](Xbyak::Reg64 reg_src, Vmm vmm_src, int elt_num) {
|
|
||||||
load_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_src.getIdx())},
|
|
||||||
std::make_shared<load_emitter_context>(jcp_.data_prc, Precision::FP32, elt_num),
|
|
||||||
{}, {load_pool_gpr_idxs});
|
|
||||||
};
|
|
||||||
|
|
||||||
uni_vpxor(vmm_sample, vmm_sample, vmm_sample);
|
uni_vpxor(vmm_sample, vmm_sample, vmm_sample);
|
||||||
load(reg_src0, vmm_src0, num);
|
load(reg_src0, vmm_src0, num);
|
||||||
uni_vfmadd231ps(vmm_sample, vmm_src0, vmm_weights0);
|
uni_vfmadd231ps(vmm_sample, vmm_src0, vmm_weights0);
|
||||||
@ -432,12 +455,6 @@ private:
|
|||||||
uni_vbroadcastss(vmm_scale, ptr[reg_tmp_64]);
|
uni_vbroadcastss(vmm_scale, ptr[reg_tmp_64]);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto load_idx = [&](Xbyak::Reg64 reg_idx, Vmm vmm_idx, int elt_num) {
|
|
||||||
load_emitter->emit_code({static_cast<size_t>(reg_idx.getIdx())}, {static_cast<size_t>(vmm_idx.getIdx())},
|
|
||||||
std::make_shared<load_emitter_context>(Precision::I32, Precision::I32, elt_num),
|
|
||||||
{}, {load_pool_gpr_idxs});
|
|
||||||
};
|
|
||||||
|
|
||||||
Xbyak::Label main_loop_label;
|
Xbyak::Label main_loop_label;
|
||||||
Xbyak::Label main_loop_end_label;
|
Xbyak::Label main_loop_end_label;
|
||||||
Xbyak::Label tail_loop_label;
|
Xbyak::Label tail_loop_label;
|
||||||
|
@ -48,8 +48,9 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi
|
|||||||
};
|
};
|
||||||
|
|
||||||
void generate() override {
|
void generate() override {
|
||||||
load_emitter.reset(new jit_load_emitter(this, isa));
|
load_emitter.reset(new jit_load_emitter(this, isa, jpp_.src_prc, Precision::FP32, step));
|
||||||
store_emitter.reset(new jit_store_emitter(this, isa));
|
store_emitter.reset(new jit_store_emitter(this, isa, Precision::FP32, jpp_.dst_prc, step));
|
||||||
|
store_empty_roi_emitter.reset(new jit_store_emitter(this, isa, jpp_.src_prc, jpp_.dst_prc, step));
|
||||||
|
|
||||||
this->preamble();
|
this->preamble();
|
||||||
|
|
||||||
@ -93,6 +94,7 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi
|
|||||||
|
|
||||||
load_emitter->emit_data();
|
load_emitter->emit_data();
|
||||||
store_emitter->emit_data();
|
store_emitter->emit_data();
|
||||||
|
store_empty_roi_emitter->emit_data();
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -114,6 +116,7 @@ private:
|
|||||||
std::vector<size_t> load_pool_gpr_idxs;
|
std::vector<size_t> load_pool_gpr_idxs;
|
||||||
|
|
||||||
std::unique_ptr<jit_store_emitter> store_emitter = nullptr;
|
std::unique_ptr<jit_store_emitter> store_emitter = nullptr;
|
||||||
|
std::unique_ptr<jit_store_emitter> store_empty_roi_emitter = nullptr;
|
||||||
std::vector<size_t> store_pool_gpr_idxs;
|
std::vector<size_t> store_pool_gpr_idxs;
|
||||||
std::vector<size_t> store_pool_vec_idxs;
|
std::vector<size_t> store_pool_vec_idxs;
|
||||||
|
|
||||||
@ -147,6 +150,12 @@ private:
|
|||||||
Xbyak::Reg64 reg_load_table = r15;
|
Xbyak::Reg64 reg_load_table = r15;
|
||||||
Xbyak::Reg64 reg_load_store_mask = abi_param1;
|
Xbyak::Reg64 reg_load_store_mask = abi_param1;
|
||||||
|
|
||||||
|
// Builds the auxiliary vector-register pool passed to a store emitter call.
// Returns a new vector holding every index from store_pool_vec_idxs followed by the index of `vmm`.
// store_pool_vec_idxs itself is not modified.
std::vector<size_t> get_local_store_pool_vec_idxs(Vmm vmm) const {
|
||||||
|
std::vector<size_t> local_store_pool_vec_idxs = { static_cast<size_t>(vmm.getIdx()) };
|
||||||
|
// inserting at begin() places the shared pool entries ahead of vmm's index
local_store_pool_vec_idxs.insert(local_store_pool_vec_idxs.begin(), store_pool_vec_idxs.begin(), store_pool_vec_idxs.end());
|
||||||
|
return local_store_pool_vec_idxs;
|
||||||
|
}
|
||||||
|
|
||||||
void roi_pool_max(int c_blocks) {
|
void roi_pool_max(int c_blocks) {
|
||||||
Label h_loop_label;
|
Label h_loop_label;
|
||||||
Label w_loop_label;
|
Label w_loop_label;
|
||||||
@ -157,8 +166,7 @@ private:
|
|||||||
for (int i = 0; i < c_blocks; i++) {
|
for (int i = 0; i < c_blocks; i++) {
|
||||||
Vmm vmm_max = get_acc_reg(i);
|
Vmm vmm_max = get_acc_reg(i);
|
||||||
|
|
||||||
load_emitter->emit_code({static_cast<size_t>(reg_input.getIdx())}, {static_cast<size_t>(vmm_max.getIdx())},
|
load_emitter->emit_code({static_cast<size_t>(reg_input.getIdx()), static_cast<size_t>(i * src_c_off)}, {static_cast<size_t>(vmm_max.getIdx())},
|
||||||
std::make_shared<load_emitter_context>(jpp_.src_prc, Precision::FP32, step, i * src_c_off),
|
|
||||||
{}, load_pool_gpr_idxs);
|
{}, load_pool_gpr_idxs);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -171,9 +179,8 @@ private:
|
|||||||
Vmm vmm_max = get_acc_reg(i);
|
Vmm vmm_max = get_acc_reg(i);
|
||||||
Vmm vmm_src = get_src_reg(i);
|
Vmm vmm_src = get_src_reg(i);
|
||||||
|
|
||||||
load_emitter->emit_code({static_cast<size_t>(aux_reg_input1.getIdx())}, {static_cast<size_t>(vmm_src.getIdx())},
|
load_emitter->emit_code({static_cast<size_t>(aux_reg_input1.getIdx()), static_cast<size_t>(i * src_c_off)},
|
||||||
std::make_shared<load_emitter_context>(jpp_.src_prc, Precision::FP32, step, i * src_c_off),
|
{static_cast<size_t>(vmm_src.getIdx())}, {}, load_pool_gpr_idxs);
|
||||||
{}, load_pool_gpr_idxs);
|
|
||||||
|
|
||||||
if (isa == cpu::x64::sse41) {
|
if (isa == cpu::x64::sse41) {
|
||||||
movups(vmm_mask, vmm_max);
|
movups(vmm_mask, vmm_max);
|
||||||
@ -206,9 +213,8 @@ private:
|
|||||||
for (int i = 0; i < c_blocks; i++) {
|
for (int i = 0; i < c_blocks; i++) {
|
||||||
Vmm vmm_dst = get_acc_reg(i);
|
Vmm vmm_dst = get_acc_reg(i);
|
||||||
|
|
||||||
store_emitter->emit_code({static_cast<size_t>(vmm_dst.getIdx())}, {static_cast<size_t>(reg_output.getIdx())},
|
store_emitter->emit_code({static_cast<size_t>(vmm_dst.getIdx()), static_cast<size_t>(i * dst_c_off)}, {static_cast<size_t>(reg_output.getIdx())},
|
||||||
std::make_shared<store_emitter_context>(Precision::FP32, jpp_.dst_prc, step, i * dst_c_off),
|
get_local_store_pool_vec_idxs(vmm_dst), store_pool_gpr_idxs);
|
||||||
store_pool_vec_idxs, store_pool_gpr_idxs);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -225,27 +231,22 @@ private:
|
|||||||
|
|
||||||
for (int i = 0; i < c_blocks; i++) {
|
for (int i = 0; i < c_blocks; i++) {
|
||||||
const int src_c_off = i * jpp_.ih * jpp_.iw * jpp_.c_block * jpp_.src_prc.size();
|
const int src_c_off = i * jpp_.ih * jpp_.iw * jpp_.c_block * jpp_.src_prc.size();
|
||||||
const auto load_context = std::make_shared<load_emitter_context>(jpp_.src_prc, Precision::FP32, step, src_c_off);
|
|
||||||
|
|
||||||
mov(aux_reg_input, reg_input);
|
mov(aux_reg_input, reg_input);
|
||||||
|
|
||||||
load_emitter->emit_code({static_cast<size_t>(aux_reg_input.getIdx())}, {static_cast<size_t>(vmm_src00.getIdx())},
|
load_emitter->emit_code({static_cast<size_t>(aux_reg_input.getIdx()), static_cast<size_t>(src_c_off)}, {static_cast<size_t>(vmm_src00.getIdx())},
|
||||||
load_context,
|
|
||||||
{}, load_pool_gpr_idxs);
|
{}, load_pool_gpr_idxs);
|
||||||
add(aux_reg_input, reg_xoff);
|
add(aux_reg_input, reg_xoff);
|
||||||
|
|
||||||
load_emitter->emit_code({static_cast<size_t>(aux_reg_input.getIdx())}, {static_cast<size_t>(vmm_src01.getIdx())},
|
load_emitter->emit_code({static_cast<size_t>(aux_reg_input.getIdx()), static_cast<size_t>(src_c_off)}, {static_cast<size_t>(vmm_src01.getIdx())},
|
||||||
load_context,
|
|
||||||
{}, load_pool_gpr_idxs);
|
{}, load_pool_gpr_idxs);
|
||||||
|
|
||||||
add(aux_reg_input, reg_yoff);
|
add(aux_reg_input, reg_yoff);
|
||||||
load_emitter->emit_code({static_cast<size_t>(aux_reg_input.getIdx())}, {static_cast<size_t>(vmm_src11.getIdx())},
|
load_emitter->emit_code({static_cast<size_t>(aux_reg_input.getIdx()), static_cast<size_t>(src_c_off)}, {static_cast<size_t>(vmm_src11.getIdx())},
|
||||||
load_context,
|
|
||||||
{}, load_pool_gpr_idxs);
|
{}, load_pool_gpr_idxs);
|
||||||
sub(aux_reg_input, reg_xoff);
|
sub(aux_reg_input, reg_xoff);
|
||||||
|
|
||||||
load_emitter->emit_code({static_cast<size_t>(aux_reg_input.getIdx())}, {static_cast<size_t>(vmm_src10.getIdx())},
|
load_emitter->emit_code({static_cast<size_t>(aux_reg_input.getIdx()), static_cast<size_t>(src_c_off)}, {static_cast<size_t>(vmm_src10.getIdx())},
|
||||||
load_context,
|
|
||||||
{}, load_pool_gpr_idxs);
|
{}, load_pool_gpr_idxs);
|
||||||
|
|
||||||
uni_vsubps(vmm_src01, vmm_src01, vmm_src00);
|
uni_vsubps(vmm_src01, vmm_src01, vmm_src00);
|
||||||
@ -259,9 +260,8 @@ private:
|
|||||||
|
|
||||||
const int dst_c_off = i * jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_prc.size();
|
const int dst_c_off = i * jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_prc.size();
|
||||||
|
|
||||||
store_emitter->emit_code({static_cast<size_t>(vmm_src11.getIdx())}, {static_cast<size_t>(reg_output.getIdx())},
|
store_emitter->emit_code({static_cast<size_t>(vmm_src11.getIdx()), static_cast<size_t>(dst_c_off)}, {static_cast<size_t>(reg_output.getIdx())},
|
||||||
std::make_shared<store_emitter_context>(Precision::FP32, jpp_.dst_prc, step, dst_c_off),
|
get_local_store_pool_vec_idxs(vmm_src11), store_pool_gpr_idxs);
|
||||||
store_pool_vec_idxs, store_pool_gpr_idxs);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -270,9 +270,8 @@ private:
|
|||||||
|
|
||||||
const int dst_c_off = jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_prc.size();
|
const int dst_c_off = jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_prc.size();
|
||||||
for (int i = 0; i < c_blocks; i++) {
|
for (int i = 0; i < c_blocks; i++) {
|
||||||
store_emitter->emit_code({static_cast<size_t>(vmm_zero.getIdx())}, {static_cast<size_t>(reg_output.getIdx())},
|
store_empty_roi_emitter->emit_code({static_cast<size_t>(vmm_zero.getIdx()), static_cast<size_t>(i * dst_c_off)},
|
||||||
std::make_shared<store_emitter_context>(jpp_.src_prc, jpp_.dst_prc, step, i * dst_c_off),
|
{static_cast<size_t>(reg_output.getIdx())}, store_pool_vec_idxs, store_pool_gpr_idxs);
|
||||||
store_pool_vec_idxs, store_pool_gpr_idxs);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -82,9 +82,6 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generato
|
|||||||
}
|
}
|
||||||
|
|
||||||
void generate() override {
|
void generate() override {
|
||||||
load_emitter.reset(new jit_load_emitter(this, isa));
|
|
||||||
store_emitter.reset(new jit_store_emitter(this, isa));
|
|
||||||
|
|
||||||
this->preamble();
|
this->preamble();
|
||||||
|
|
||||||
mov(reg_src, ptr[reg_params + GET_OFF(src)]);
|
mov(reg_src, ptr[reg_params + GET_OFF(src)]);
|
||||||
@ -123,8 +120,7 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generato
|
|||||||
|
|
||||||
this->postamble();
|
this->postamble();
|
||||||
|
|
||||||
load_emitter->emit_data();
|
emit_emitters_data();
|
||||||
store_emitter->emit_data();
|
|
||||||
|
|
||||||
if (!shape_agnostic_alg)
|
if (!shape_agnostic_alg)
|
||||||
prepare_idx_table();
|
prepare_idx_table();
|
||||||
@ -207,9 +203,8 @@ private:
|
|||||||
Vmm vmm_zero = Vmm(0); // vmm_zero represents Vmm(0) when isa is avx512_core, otherwise vmm_mask represents Vmm(0)
|
Vmm vmm_zero = Vmm(0); // vmm_zero represents Vmm(0) when isa is avx512_core, otherwise vmm_mask represents Vmm(0)
|
||||||
|
|
||||||
const Xbyak::Opmask k_mask = Xbyak::Opmask(1);
|
const Xbyak::Opmask k_mask = Xbyak::Opmask(1);
|
||||||
const int step = vlen / sizeof(float);
|
const int vector_step = vlen / sizeof(float);
|
||||||
const int tail = jcp_.work_amount % step;
|
const int tail_step = jcp_.work_amount % vector_step;
|
||||||
const int topk_tail = jcp_.top_k % step;
|
|
||||||
|
|
||||||
int blk_stride = 0; // stride of channel blocks at the same space coordinate, only used in blocked layout with topk on channel
|
int blk_stride = 0; // stride of channel blocks at the same space coordinate, only used in blocked layout with topk on channel
|
||||||
unsigned char cmp_flg;
|
unsigned char cmp_flg;
|
||||||
@ -217,13 +212,67 @@ private:
|
|||||||
|
|
||||||
Xbyak::Label l_table;
|
Xbyak::Label l_table;
|
||||||
|
|
||||||
std::unique_ptr<jit_load_emitter> load_emitter = nullptr;
|
std::unordered_map<size_t, std::unique_ptr<jit_emitter>> emitters;
|
||||||
std::unique_ptr<jit_store_emitter> store_emitter = nullptr;
|
|
||||||
|
|
||||||
std::vector<size_t> store_pool_gpr_idxs;
|
std::vector<size_t> store_pool_gpr_idxs;
|
||||||
std::vector<size_t> load_pool_gpr_idxs;
|
std::vector<size_t> load_pool_gpr_idxs;
|
||||||
std::vector<size_t> store_pool_vec_idxs;
|
std::vector<size_t> store_pool_vec_idxs;
|
||||||
|
|
||||||
|
// Calls emit_data() on every emitter currently stored in the `emitters` map.
// The visible call site invokes this right after this->postamble(), where the
// pre-refactor code called load_emitter->emit_data() / store_emitter->emit_data().
// NOTE(review): `emitters` is a std::unordered_map, so the order in which the
// individual emitters' data is emitted is unspecified - confirm nothing relies on it.
void emit_emitters_data() {
|
||||||
|
for (const auto& emitter : emitters) {
|
||||||
|
emitter.second->emit_data();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convenience wrapper over emit_load(): source precision is jcp_.precision,
// destination precision is FP32.
//   reg_src - register forwarded to the load emitter as its first input index
//   vmm_src - vector register receiving the loaded data (the emitter's output index)
//   elt_num - number of elements the emitter is created for
//   offset  - forwarded unchanged to emit_load(); defaults to 0
inline void load(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num, const int offset = 0) {
|
||||||
|
emit_load(reg_src, vmm_src, jcp_.precision, Precision::FP32, elt_num, offset);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convenience wrapper over emit_load(): source precision I32, destination precision FP32.
// Parameters have the same meaning as in load(); offset defaults to 0.
inline void load_i32_f32(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num, const int offset = 0) {
|
||||||
|
emit_load(reg_src, vmm_src, Precision::I32, Precision::FP32, elt_num, offset);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convenience wrapper over emit_load(): both source and destination precision are I32.
// Parameters have the same meaning as in load(); offset defaults to 0.
inline void load_i32(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num, const int offset = 0) {
|
||||||
|
emit_load(reg_src, vmm_src, Precision::I32, Precision::I32, elt_num, offset);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convenience wrapper over emit_store(): source precision is FP32,
// destination precision is jcp_.precision.
//   vmm_dst - vector register forwarded to the store emitter as its first input index
//   reg_dst - register forwarded as the emitter's output index
//   elt_num - number of elements the emitter is created for
//   offset  - forwarded unchanged to emit_store(); defaults to 0
inline void store(Vmm vmm_dst, Xbyak::Reg64 reg_dst, const int elt_num, const int offset = 0) {
|
||||||
|
emit_store(vmm_dst, reg_dst, Precision::FP32, jcp_.precision, elt_num, offset);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convenience wrapper over emit_store(): source precision FP32, destination precision I32.
// Parameters have the same meaning as in store(); offset defaults to 0.
inline void store_f32_i32(Vmm vmm_dst, Xbyak::Reg64 reg_dst, const int elt_num, const int offset = 0) {
|
||||||
|
emit_store(vmm_dst, reg_dst, Precision::FP32, Precision::I32, elt_num, offset);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convenience wrapper over emit_store(): both source and destination precision are I32.
// Parameters have the same meaning as in store(); offset defaults to 0.
inline void store_i32(Vmm vmm_dst, Xbyak::Reg64 reg_dst, const int elt_num, const int offset = 0) {
|
||||||
|
emit_store(vmm_dst, reg_dst, Precision::I32, Precision::I32, elt_num, offset);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Emits a load through a jit_load_emitter that is created on first use and then
// reused for every later call with the same (src_prc, dst_prc, elt_num).
//   reg_src - register whose index is passed as the first input index
//   vmm_src - vector register whose index is passed as the output index
//   src_prc / dst_prc / elt_num - parameters the emitter is constructed with
//   offset  - passed as the second input index (it replaces the offset_byte_ field
//             of the removed load_emitter_context); defaults to 0
// NOTE(review): emit_load() and emit_store() share the single `emitters` map and
// key it only by the hash() value. If two different parameter sets (or a load and
// a store parameter set) ever produce the same hash, the previously cached emitter
// would be reused - hash() is not visible here, confirm it cannot collide.
inline void emit_load(Xbyak::Reg64 reg_src, Vmm vmm_src, Precision src_prc, Precision dst_prc, const int elt_num, const int offset = 0) {
|
||||||
|
const auto seed = load_emitter_params(src_prc, dst_prc, elt_num).hash();
|
||||||
|
// operator[] default-inserts a null unique_ptr for an unseen key, so a null
// entry means "emitter not created yet".
if (!emitters[seed]) {
|
||||||
|
emitters[seed].reset(new jit_load_emitter(this, isa, src_prc, dst_prc, elt_num));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Third argument is an empty list; load_pool_gpr_idxs is passed as the fourth.
emitters[seed]->emit_code({static_cast<size_t>(reg_src.getIdx()), static_cast<size_t>(offset)},
|
||||||
|
{static_cast<size_t>(vmm_src.getIdx())}, {}, {load_pool_gpr_idxs});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Emits a store through a jit_store_emitter that is created on first use and then
// reused for every later call with the same (src_prc, dst_prc, elt_num).
//   vmm_dst - vector register whose index is passed as the first input index
//   reg_dst - register whose index is passed as the output index
//   src_prc / dst_prc / elt_num - parameters the emitter is constructed with
//   offset  - passed as the second input index; defaults to 0
// NOTE(review): emit_store() and emit_load() share the single `emitters` map and
// key it only by the hash() value. If a store and a load parameter set ever
// produce the same hash, the wrong cached emitter would be used - hash() is not
// visible here, confirm it cannot collide.
inline void emit_store(Vmm vmm_dst, Xbyak::Reg64 reg_dst, Precision src_prc, Precision dst_prc, const int elt_num, const int offset = 0) {
|
||||||
|
const auto seed = store_emitter_params(src_prc, dst_prc, elt_num).hash();
|
||||||
|
// operator[] default-inserts a null unique_ptr for an unseen key, so a null
// entry means "emitter not created yet".
if (!emitters[seed]) {
|
||||||
|
emitters[seed].reset(new jit_store_emitter(this, isa, src_prc, dst_prc, elt_num));
|
||||||
|
}
|
||||||
|
|
||||||
|
// for cases when the store emitter needs 2 aux vmm we can use vmm_dst as the second aux vmm
|
||||||
|
std::vector<size_t> local_store_pool_vec_idxs = { static_cast<size_t>(vmm_dst.getIdx()) };
|
||||||
|
// Inserting at begin() puts the shared store_pool_vec_idxs entries first, so
// vmm_dst ends up as the last element of the local pool.
local_store_pool_vec_idxs.insert(local_store_pool_vec_idxs.begin(), store_pool_vec_idxs.begin(), store_pool_vec_idxs.end());
|
||||||
|
|
||||||
|
emitters[seed]->emit_code({static_cast<size_t>(vmm_dst.getIdx()), static_cast<size_t>(offset)},
|
||||||
|
{static_cast<size_t>(reg_dst.getIdx())},
|
||||||
|
{local_store_pool_vec_idxs}, {store_pool_gpr_idxs});
|
||||||
|
}
|
||||||
|
|
||||||
inline void topk_loop() {
|
inline void topk_loop() {
|
||||||
if (jcp_.algorithm == TopKAlgorithm::topk_bubble_sort) {
|
if (jcp_.algorithm == TopKAlgorithm::topk_bubble_sort) {
|
||||||
if (jcp_.layout == TopKLayoutType::topk_blocked && jcp_.topk_innermost) {
|
if (jcp_.layout == TopKLayoutType::topk_blocked && jcp_.topk_innermost) {
|
||||||
@ -253,27 +302,27 @@ private:
|
|||||||
Xbyak::Label topk_main_loop_end_label;
|
Xbyak::Label topk_main_loop_end_label;
|
||||||
L(topk_main_loop_label);
|
L(topk_main_loop_label);
|
||||||
{
|
{
|
||||||
cmp(reg_work_amount, step);
|
cmp(reg_work_amount, vector_step);
|
||||||
jl(topk_main_loop_end_label, T_NEAR);
|
jl(topk_main_loop_end_label, T_NEAR);
|
||||||
|
|
||||||
topk_bitonic(step);
|
topk_bitonic(vector_step);
|
||||||
|
|
||||||
add(reg_src, step * jcp_.data_size);
|
add(reg_src, vector_step * jcp_.data_size);
|
||||||
add(reg_dst, step * jcp_.data_size);
|
add(reg_dst, vector_step * jcp_.data_size);
|
||||||
add(reg_dst_idx, step * sizeof(int));
|
add(reg_dst_idx, vector_step * sizeof(int));
|
||||||
sub(reg_work_amount, step);
|
sub(reg_work_amount, vector_step);
|
||||||
|
|
||||||
jmp(topk_main_loop_label, T_NEAR);
|
jmp(topk_main_loop_label, T_NEAR);
|
||||||
}
|
}
|
||||||
L(topk_main_loop_end_label);
|
L(topk_main_loop_end_label);
|
||||||
|
|
||||||
// tail
|
// tail
|
||||||
if (tail) {
|
if (tail_step) {
|
||||||
Xbyak::Label topk_tail_loop_end_label;
|
Xbyak::Label topk_tail_loop_end_label;
|
||||||
cmp(reg_work_amount, tail);
|
cmp(reg_work_amount, tail_step);
|
||||||
jl(topk_tail_loop_end_label, T_NEAR);
|
jl(topk_tail_loop_end_label, T_NEAR);
|
||||||
|
|
||||||
topk_bitonic(tail);
|
topk_bitonic(tail_step);
|
||||||
|
|
||||||
L(topk_tail_loop_end_label);
|
L(topk_tail_loop_end_label);
|
||||||
}
|
}
|
||||||
@ -282,19 +331,11 @@ private:
|
|||||||
inline void topk_bitonic(int elt_num) {
|
inline void topk_bitonic(int elt_num) {
|
||||||
// src => prc
|
// src => prc
|
||||||
for (int i = 0; i < jcp_.axis_dim; i++) {
|
for (int i = 0; i < jcp_.axis_dim; i++) {
|
||||||
load_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_tmp.getIdx())},
|
load(reg_src, vmm_tmp, elt_num, i * jcp_.sort_stride * jcp_.data_size);
|
||||||
std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, elt_num, i * jcp_.sort_stride * jcp_.data_size),
|
store(vmm_tmp, reg_prc, elt_num, i * jcp_.sort_stride * jcp_.data_size);
|
||||||
{}, {load_pool_gpr_idxs});
|
|
||||||
store_emitter->emit_code({static_cast<size_t>(vmm_tmp.getIdx())}, {static_cast<size_t>(reg_prc.getIdx())},
|
|
||||||
std::make_shared<store_emitter_context>(Precision::FP32, jcp_.precision, elt_num, i * jcp_.sort_stride * jcp_.data_size),
|
|
||||||
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
|
|
||||||
|
|
||||||
load_emitter->emit_code({static_cast<size_t>(reg_table.getIdx())}, {static_cast<size_t>(vmm_tmp.getIdx())},
|
load_i32(reg_table, vmm_tmp, elt_num, i * vlen);
|
||||||
std::make_shared<load_emitter_context>(Precision::I32, Precision::I32, elt_num, i * vlen),
|
store_i32(vmm_tmp, reg_prc_idx, elt_num, i * jcp_.sort_stride * sizeof(int));
|
||||||
{}, {load_pool_gpr_idxs});
|
|
||||||
store_emitter->emit_code({static_cast<size_t>(vmm_tmp.getIdx())}, {static_cast<size_t>(reg_prc_idx.getIdx())},
|
|
||||||
std::make_shared<store_emitter_context>(Precision::I32, Precision::I32, elt_num, i * jcp_.sort_stride * sizeof(int)),
|
|
||||||
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// sort
|
// sort
|
||||||
@ -305,19 +346,11 @@ private:
|
|||||||
|
|
||||||
// prc => dst
|
// prc => dst
|
||||||
for (int i = 0; i < jcp_.top_k; i++) {
|
for (int i = 0; i < jcp_.top_k; i++) {
|
||||||
load_emitter->emit_code({static_cast<size_t>(reg_prc.getIdx())}, {static_cast<size_t>(vmm_tmp.getIdx())},
|
load(reg_prc, vmm_tmp, elt_num, i * jcp_.sort_stride * jcp_.data_size);
|
||||||
std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, elt_num, i * jcp_.sort_stride * jcp_.data_size),
|
store(vmm_tmp, reg_dst, elt_num, i * jcp_.sort_stride * jcp_.data_size);
|
||||||
{}, {load_pool_gpr_idxs});
|
|
||||||
store_emitter->emit_code({static_cast<size_t>(vmm_tmp.getIdx())}, {static_cast<size_t>(reg_dst.getIdx())},
|
|
||||||
std::make_shared<store_emitter_context>(Precision::FP32, jcp_.precision, elt_num, i * jcp_.sort_stride * jcp_.data_size),
|
|
||||||
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
|
|
||||||
|
|
||||||
load_emitter->emit_code({static_cast<size_t>(reg_prc_idx.getIdx())}, {static_cast<size_t>(vmm_tmp.getIdx())},
|
load_i32(reg_prc_idx, vmm_tmp, elt_num, i * jcp_.sort_stride * sizeof(int));
|
||||||
std::make_shared<load_emitter_context>(Precision::I32, Precision::I32, elt_num, i * jcp_.sort_stride * sizeof(int)),
|
store_i32(vmm_tmp, reg_dst_idx, elt_num, i * jcp_.sort_stride * sizeof(int));
|
||||||
{}, {load_pool_gpr_idxs});
|
|
||||||
store_emitter->emit_code({static_cast<size_t>(vmm_tmp.getIdx())}, {static_cast<size_t>(reg_dst_idx.getIdx())},
|
|
||||||
std::make_shared<store_emitter_context>(Precision::I32, Precision::I32, elt_num, i * jcp_.sort_stride * sizeof(int)),
|
|
||||||
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -330,46 +363,46 @@ private:
|
|||||||
Xbyak::Label topk_main_loop_end_label;
|
Xbyak::Label topk_main_loop_end_label;
|
||||||
L(topk_main_loop_label);
|
L(topk_main_loop_label);
|
||||||
{
|
{
|
||||||
cmp(reg_work_amount, step);
|
cmp(reg_work_amount, vector_step);
|
||||||
jl(topk_main_loop_end_label, T_NEAR);
|
jl(topk_main_loop_end_label, T_NEAR);
|
||||||
|
|
||||||
// src => prc
|
// src => prc
|
||||||
bitonic_BLK_on_channel_load(step);
|
bitonic_BLK_on_channel_load(vector_step);
|
||||||
|
|
||||||
// sort
|
// sort
|
||||||
bitonic_sort_vector(step);
|
bitonic_sort_vector(vector_step);
|
||||||
if (jcp_.sort_index) {
|
if (jcp_.sort_index) {
|
||||||
bitonic_sort_vector(step, false);
|
bitonic_sort_vector(vector_step, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
// prc => dst
|
// prc => dst
|
||||||
bitonic_BLK_on_channel_store(step);
|
bitonic_BLK_on_channel_store(vector_step);
|
||||||
|
|
||||||
add(reg_src, step * jcp_.blk_size * jcp_.data_size);
|
add(reg_src, vector_step * jcp_.blk_size * jcp_.data_size);
|
||||||
add(reg_dst, step * jcp_.blk_size * jcp_.data_size);
|
add(reg_dst, vector_step * jcp_.blk_size * jcp_.data_size);
|
||||||
add(reg_dst_idx, step * jcp_.blk_size * sizeof(int));
|
add(reg_dst_idx, vector_step * jcp_.blk_size * sizeof(int));
|
||||||
sub(reg_work_amount, step);
|
sub(reg_work_amount, vector_step);
|
||||||
|
|
||||||
jmp(topk_main_loop_label, T_NEAR);
|
jmp(topk_main_loop_label, T_NEAR);
|
||||||
}
|
}
|
||||||
L(topk_main_loop_end_label);
|
L(topk_main_loop_end_label);
|
||||||
|
|
||||||
// tail exists because working buffer has planar layout, though source buffer has blocked layout)
|
// tail exists because working buffer has planar layout, though source buffer has blocked layout)
|
||||||
if (tail) {
|
if (tail_step) {
|
||||||
Xbyak::Label topk_tail_loop_end_label;
|
Xbyak::Label topk_tail_loop_end_label;
|
||||||
cmp(reg_work_amount, tail);
|
cmp(reg_work_amount, tail_step);
|
||||||
jl(topk_tail_loop_end_label, T_NEAR);
|
jl(topk_tail_loop_end_label, T_NEAR);
|
||||||
|
|
||||||
// src => prc
|
// src => prc
|
||||||
bitonic_BLK_on_channel_load(tail);
|
bitonic_BLK_on_channel_load(tail_step);
|
||||||
|
|
||||||
bitonic_sort_vector(tail);
|
bitonic_sort_vector(tail_step);
|
||||||
if (jcp_.sort_index) {
|
if (jcp_.sort_index) {
|
||||||
bitonic_sort_vector(tail, false);
|
bitonic_sort_vector(tail_step, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
// prc => dst
|
// prc => dst
|
||||||
bitonic_BLK_on_channel_store(tail);
|
bitonic_BLK_on_channel_store(tail_step);
|
||||||
|
|
||||||
L(topk_tail_loop_end_label);
|
L(topk_tail_loop_end_label);
|
||||||
}
|
}
|
||||||
@ -437,40 +470,30 @@ private:
|
|||||||
|
|
||||||
inline void bitonic_swap_vector(int elt_num, bool cmp_val = true) {
|
inline void bitonic_swap_vector(int elt_num, bool cmp_val = true) {
|
||||||
bitonic_get_addr(reg_prc, jcp_.data_size, 0);
|
bitonic_get_addr(reg_prc, jcp_.data_size, 0);
|
||||||
load_emitter->emit_code({static_cast<size_t>(reg_aux_idx.getIdx())}, {static_cast<size_t>(vmm_val_l.getIdx())},
|
load(reg_aux_idx, vmm_val_l, elt_num);
|
||||||
std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, elt_num),
|
|
||||||
{}, {load_pool_gpr_idxs});
|
|
||||||
bitonic_get_addr(reg_prc, jcp_.data_size, sizeof(int));
|
bitonic_get_addr(reg_prc, jcp_.data_size, sizeof(int));
|
||||||
load_emitter->emit_code({static_cast<size_t>(reg_aux_idx.getIdx())}, {static_cast<size_t>(vmm_val_r.getIdx())},
|
load(reg_aux_idx, vmm_val_r, elt_num);
|
||||||
std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, elt_num),
|
|
||||||
{}, {load_pool_gpr_idxs});
|
|
||||||
bitonic_get_addr(reg_prc_idx, sizeof(int), 0);
|
bitonic_get_addr(reg_prc_idx, sizeof(int), 0);
|
||||||
load_emitter->emit_code({static_cast<size_t>(reg_aux_idx.getIdx())}, {static_cast<size_t>(vmm_idx_l.getIdx())},
|
load_i32_f32(reg_aux_idx, vmm_idx_l, elt_num);
|
||||||
std::make_shared<load_emitter_context>(Precision::I32, Precision::FP32, elt_num),
|
|
||||||
{}, {load_pool_gpr_idxs});
|
|
||||||
bitonic_get_addr(reg_prc_idx, sizeof(int), sizeof(int));
|
bitonic_get_addr(reg_prc_idx, sizeof(int), sizeof(int));
|
||||||
load_emitter->emit_code({static_cast<size_t>(reg_aux_idx.getIdx())}, {static_cast<size_t>(vmm_idx_r.getIdx())},
|
load_i32_f32(reg_aux_idx, vmm_idx_r, elt_num);
|
||||||
std::make_shared<load_emitter_context>(Precision::I32, Precision::FP32, elt_num),
|
|
||||||
{}, {load_pool_gpr_idxs});
|
|
||||||
|
|
||||||
swap_vector(vmm_val_l, vmm_idx_l, vmm_val_r, vmm_idx_r, cmp_val);
|
swap_vector(vmm_val_l, vmm_idx_l, vmm_val_r, vmm_idx_r, cmp_val);
|
||||||
|
|
||||||
bitonic_get_addr(reg_prc, jcp_.data_size, 0);
|
bitonic_get_addr(reg_prc, jcp_.data_size, 0);
|
||||||
store_emitter->emit_code({static_cast<size_t>(vmm_val_l.getIdx())}, {static_cast<size_t>(reg_aux_idx.getIdx())},
|
store(vmm_val_l, reg_aux_idx, elt_num);
|
||||||
std::make_shared<store_emitter_context>(Precision::FP32, jcp_.precision, elt_num),
|
|
||||||
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
|
|
||||||
bitonic_get_addr(reg_prc, jcp_.data_size, sizeof(int));
|
bitonic_get_addr(reg_prc, jcp_.data_size, sizeof(int));
|
||||||
store_emitter->emit_code({static_cast<size_t>(vmm_val_r.getIdx())}, {static_cast<size_t>(reg_aux_idx.getIdx())},
|
store(vmm_val_r, reg_aux_idx, elt_num);
|
||||||
std::make_shared<store_emitter_context>(Precision::FP32, jcp_.precision, elt_num),
|
|
||||||
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
|
|
||||||
bitonic_get_addr(reg_prc_idx, sizeof(int), 0);
|
bitonic_get_addr(reg_prc_idx, sizeof(int), 0);
|
||||||
store_emitter->emit_code({static_cast<size_t>(vmm_idx_l.getIdx())}, {static_cast<size_t>(reg_aux_idx.getIdx())},
|
store_f32_i32(vmm_idx_l, reg_aux_idx, elt_num);
|
||||||
std::make_shared<store_emitter_context>(Precision::FP32, Precision::I32, elt_num),
|
|
||||||
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
|
|
||||||
bitonic_get_addr(reg_prc_idx, sizeof(int), sizeof(int));
|
bitonic_get_addr(reg_prc_idx, sizeof(int), sizeof(int));
|
||||||
store_emitter->emit_code({static_cast<size_t>(vmm_idx_r.getIdx())}, {static_cast<size_t>(reg_aux_idx.getIdx())},
|
store_f32_i32(vmm_idx_r, reg_aux_idx, elt_num);
|
||||||
std::make_shared<store_emitter_context>(Precision::FP32, Precision::I32, elt_num),
|
|
||||||
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void topk_heap_sorting() {
|
inline void topk_heap_sorting() {
|
||||||
@ -480,9 +503,9 @@ private:
|
|||||||
|
|
||||||
// init dst
|
// init dst
|
||||||
mov(reg_i, 0);
|
mov(reg_i, 0);
|
||||||
sub(reg_heap_top_k, step);
|
sub(reg_heap_top_k, vector_step);
|
||||||
topk_heap_load(reg_heap_k_sub_step, step);
|
topk_heap_load(reg_heap_k_sub_step, vector_step);
|
||||||
add(reg_heap_top_k, step);
|
add(reg_heap_top_k, vector_step);
|
||||||
topk_heap_load(reg_heap_top_k, 1);
|
topk_heap_load(reg_heap_top_k, 1);
|
||||||
mov(reg_zero, 0);
|
mov(reg_zero, 0);
|
||||||
|
|
||||||
@ -579,7 +602,7 @@ private:
|
|||||||
Xbyak::Label topk_init_loop_end_label;
|
Xbyak::Label topk_init_loop_end_label;
|
||||||
L(topk_init_loop_label);
|
L(topk_init_loop_label);
|
||||||
{
|
{
|
||||||
if (s == step) {
|
if (s == vector_step) {
|
||||||
cmp(reg_i, reg_end);
|
cmp(reg_i, reg_end);
|
||||||
jg(topk_init_loop_end_label, T_NEAR);
|
jg(topk_init_loop_end_label, T_NEAR);
|
||||||
} else {
|
} else {
|
||||||
@ -588,25 +611,18 @@ private:
|
|||||||
}
|
}
|
||||||
|
|
||||||
get_addr_by_reg_idx(reg_heap_outer_aux, reg_src, reg_i, jcp_.data_size);
|
get_addr_by_reg_idx(reg_heap_outer_aux, reg_src, reg_i, jcp_.data_size);
|
||||||
load_emitter->emit_code({static_cast<size_t>(reg_heap_outer_aux.getIdx())}, {static_cast<size_t>(vmm_tmp.getIdx())},
|
load(reg_heap_outer_aux, vmm_tmp, s);
|
||||||
std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, s),
|
|
||||||
{}, {load_pool_gpr_idxs});
|
|
||||||
get_addr_by_reg_idx(reg_heap_outer_aux, reg_dst, reg_i, jcp_.data_size);
|
get_addr_by_reg_idx(reg_heap_outer_aux, reg_dst, reg_i, jcp_.data_size);
|
||||||
store_emitter->emit_code({static_cast<size_t>(vmm_tmp.getIdx())}, {static_cast<size_t>(reg_heap_outer_aux.getIdx())},
|
store(vmm_tmp, reg_heap_outer_aux, s);
|
||||||
std::make_shared<store_emitter_context>(Precision::FP32, jcp_.precision, s),
|
if (s == vector_step) {
|
||||||
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
|
|
||||||
if (s == step) {
|
|
||||||
table_to_vmm(vmm_tmp, reg_heap_seq_idx, reg_i, 0, sizeof(int));
|
table_to_vmm(vmm_tmp, reg_heap_seq_idx, reg_i, 0, sizeof(int));
|
||||||
} else {
|
} else {
|
||||||
get_addr_by_reg_idx(reg_heap_outer_aux, reg_heap_seq_idx, reg_i, sizeof(int));
|
get_addr_by_reg_idx(reg_heap_outer_aux, reg_heap_seq_idx, reg_i, sizeof(int));
|
||||||
load_emitter->emit_code({static_cast<size_t>(reg_heap_outer_aux.getIdx())}, {static_cast<size_t>(vmm_tmp.getIdx())},
|
load_i32(reg_heap_outer_aux, vmm_tmp, 1);
|
||||||
std::make_shared<load_emitter_context>(Precision::I32, Precision::I32, 1),
|
|
||||||
{}, {load_pool_gpr_idxs});
|
|
||||||
}
|
}
|
||||||
get_addr_by_reg_idx(reg_heap_outer_aux, reg_dst_idx, reg_i, sizeof(int));
|
get_addr_by_reg_idx(reg_heap_outer_aux, reg_dst_idx, reg_i, sizeof(int));
|
||||||
store_emitter->emit_code({static_cast<size_t>(vmm_tmp.getIdx())}, {static_cast<size_t>(reg_heap_outer_aux.getIdx())},
|
store_i32(vmm_tmp, reg_heap_outer_aux, s);
|
||||||
std::make_shared<store_emitter_context>(Precision::I32, Precision::I32, s),
|
|
||||||
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
|
|
||||||
|
|
||||||
add(reg_i, s);
|
add(reg_i, s);
|
||||||
jmp(topk_init_loop_label, T_NEAR);
|
jmp(topk_init_loop_label, T_NEAR);
|
||||||
@ -822,19 +838,19 @@ private:
|
|||||||
Xbyak::Label topk_main_loop_end_label;
|
Xbyak::Label topk_main_loop_end_label;
|
||||||
L(topk_main_loop_label);
|
L(topk_main_loop_label);
|
||||||
{
|
{
|
||||||
cmp(reg_work_amount, step);
|
cmp(reg_work_amount, vector_step);
|
||||||
jl(topk_main_loop_end_label, T_NEAR);
|
jl(topk_main_loop_end_label, T_NEAR);
|
||||||
|
|
||||||
if (jcp_.bubble_inplace) {
|
if (jcp_.bubble_inplace) {
|
||||||
topk_bubble_inplace(step);
|
topk_bubble_inplace(vector_step);
|
||||||
} else {
|
} else {
|
||||||
topk_bubble(step);
|
topk_bubble(vector_step);
|
||||||
}
|
}
|
||||||
|
|
||||||
add(reg_src, step * jcp_.data_size);
|
add(reg_src, vector_step * jcp_.data_size);
|
||||||
add(reg_dst, step * jcp_.data_size);
|
add(reg_dst, vector_step * jcp_.data_size);
|
||||||
add(reg_dst_idx, step * sizeof(int));
|
add(reg_dst_idx, vector_step * sizeof(int));
|
||||||
sub(reg_work_amount, step);
|
sub(reg_work_amount, vector_step);
|
||||||
|
|
||||||
jmp(topk_main_loop_label, T_NEAR);
|
jmp(topk_main_loop_label, T_NEAR);
|
||||||
}
|
}
|
||||||
@ -842,12 +858,12 @@ private:
|
|||||||
|
|
||||||
// tail
|
// tail
|
||||||
if (jcp_.bubble_inplace) {
|
if (jcp_.bubble_inplace) {
|
||||||
if (tail) {
|
if (tail_step) {
|
||||||
Xbyak::Label topk_tail_loop_end_label;
|
Xbyak::Label topk_tail_loop_end_label;
|
||||||
cmp(reg_work_amount, tail);
|
cmp(reg_work_amount, tail_step);
|
||||||
jl(topk_tail_loop_end_label, T_NEAR);
|
jl(topk_tail_loop_end_label, T_NEAR);
|
||||||
|
|
||||||
topk_bubble_inplace(tail);
|
topk_bubble_inplace(tail_step);
|
||||||
|
|
||||||
L(topk_tail_loop_end_label);
|
L(topk_tail_loop_end_label);
|
||||||
}
|
}
|
||||||
@ -1025,19 +1041,13 @@ private:
|
|||||||
je(topk_init_loop_end_label, T_NEAR);
|
je(topk_init_loop_end_label, T_NEAR);
|
||||||
|
|
||||||
get_addr_by_reg_idx(reg_tmp, reg_src, reg_block_sort_stride_byte, reg_i);
|
get_addr_by_reg_idx(reg_tmp, reg_src, reg_block_sort_stride_byte, reg_i);
|
||||||
load_emitter->emit_code({static_cast<size_t>(reg_tmp.getIdx())}, {static_cast<size_t>(vmm_tmp.getIdx())},
|
load(reg_tmp, vmm_tmp, elt_num);
|
||||||
std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, elt_num),
|
|
||||||
{}, {load_pool_gpr_idxs});
|
|
||||||
get_addr_by_reg_idx(reg_tmp, reg_dst, reg_block_sort_stride_byte, reg_i);
|
get_addr_by_reg_idx(reg_tmp, reg_dst, reg_block_sort_stride_byte, reg_i);
|
||||||
store_emitter->emit_code({static_cast<size_t>(vmm_tmp.getIdx())}, {static_cast<size_t>(reg_tmp.getIdx())},
|
store(vmm_tmp, reg_tmp, elt_num);
|
||||||
std::make_shared<store_emitter_context>(Precision::FP32, jcp_.precision, elt_num),
|
|
||||||
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
|
|
||||||
|
|
||||||
table_to_vmm(vmm_tmp, reg_bubble_block_idx, reg_i, 0, vlen);
|
table_to_vmm(vmm_tmp, reg_bubble_block_idx, reg_i, 0, vlen);
|
||||||
get_addr_by_reg_idx(reg_tmp, reg_dst_idx, reg_block_sort_stride_byte, sizeof(int) / jcp_.data_size, reg_i);
|
get_addr_by_reg_idx(reg_tmp, reg_dst_idx, reg_block_sort_stride_byte, sizeof(int) / jcp_.data_size, reg_i);
|
||||||
store_emitter->emit_code({static_cast<size_t>(vmm_tmp.getIdx())}, {static_cast<size_t>(reg_tmp.getIdx())},
|
store_i32(vmm_tmp, reg_tmp, elt_num);
|
||||||
std::make_shared<store_emitter_context>(Precision::I32, Precision::I32, elt_num),
|
|
||||||
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
|
|
||||||
|
|
||||||
add(reg_i, 1);
|
add(reg_i, 1);
|
||||||
jmp(topk_init_loop_label, T_NEAR);
|
jmp(topk_init_loop_label, T_NEAR);
|
||||||
@ -1057,9 +1067,7 @@ private:
|
|||||||
je(topk_update_loop_end_label, T_NEAR);
|
je(topk_update_loop_end_label, T_NEAR);
|
||||||
|
|
||||||
get_addr_by_reg_idx(reg_tmp, reg_src, reg_block_sort_stride_byte, reg_i);
|
get_addr_by_reg_idx(reg_tmp, reg_src, reg_block_sort_stride_byte, reg_i);
|
||||||
load_emitter->emit_code({static_cast<size_t>(reg_tmp.getIdx())}, {static_cast<size_t>(vmm_val_r.getIdx())},
|
load(reg_tmp, vmm_val_r, elt_num);
|
||||||
std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, elt_num),
|
|
||||||
{}, {load_pool_gpr_idxs});
|
|
||||||
|
|
||||||
table_to_vmm(vmm_idx_r, reg_bubble_block_idx, reg_i, 0, vlen);
|
table_to_vmm(vmm_idx_r, reg_bubble_block_idx, reg_i, 0, vlen);
|
||||||
uni_vcvtdq2ps(vmm_idx_r, vmm_idx_r);
|
uni_vcvtdq2ps(vmm_idx_r, vmm_idx_r);
|
||||||
@ -1152,9 +1160,7 @@ private:
|
|||||||
inline void topk_bubble_inplace(int elt_num) {
|
inline void topk_bubble_inplace(int elt_num) {
|
||||||
// load
|
// load
|
||||||
for (int i = 0; i < jcp_.top_k; i++) {
|
for (int i = 0; i < jcp_.top_k; i++) {
|
||||||
load_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_val(i).getIdx())},
|
load(reg_src, vmm_val(i), elt_num, i * jcp_.sort_stride * jcp_.data_size);
|
||||||
std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, elt_num, i * jcp_.sort_stride * jcp_.data_size),
|
|
||||||
{}, {load_pool_gpr_idxs});
|
|
||||||
uni_vmovdqu(vmm_idx(i), table_val(i));
|
uni_vmovdqu(vmm_idx(i), table_val(i));
|
||||||
uni_vcvtdq2ps(vmm_idx(i), vmm_idx(i));
|
uni_vcvtdq2ps(vmm_idx(i), vmm_idx(i));
|
||||||
}
|
}
|
||||||
@ -1165,9 +1171,7 @@ private:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (int i = jcp_.top_k; i < jcp_.axis_dim; i++) {
|
for (int i = jcp_.top_k; i < jcp_.axis_dim; i++) {
|
||||||
load_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_val(jcp_.top_k).getIdx())},
|
load(reg_src, vmm_val(jcp_.top_k), elt_num, i * jcp_.sort_stride * jcp_.data_size);
|
||||||
std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, elt_num, i * jcp_.sort_stride * jcp_.data_size),
|
|
||||||
{}, {load_pool_gpr_idxs});
|
|
||||||
uni_vmovdqu(vmm_idx(jcp_.top_k), table_val(i));
|
uni_vmovdqu(vmm_idx(jcp_.top_k), table_val(i));
|
||||||
uni_vcvtdq2ps(vmm_idx(jcp_.top_k), vmm_idx(jcp_.top_k));
|
uni_vcvtdq2ps(vmm_idx(jcp_.top_k), vmm_idx(jcp_.top_k));
|
||||||
for (int j = jcp_.top_k; j > 0; j--) {
|
for (int j = jcp_.top_k; j > 0; j--) {
|
||||||
@ -1183,12 +1187,8 @@ private:
|
|||||||
}
|
}
|
||||||
// store
|
// store
|
||||||
for (int i = 0; i < jcp_.top_k; i++) {
|
for (int i = 0; i < jcp_.top_k; i++) {
|
||||||
store_emitter->emit_code({static_cast<size_t>(vmm_val(i).getIdx())}, {static_cast<size_t>(reg_dst.getIdx())},
|
store(vmm_val(i), reg_dst, elt_num, i * jcp_.sort_stride * jcp_.data_size);
|
||||||
std::make_shared<store_emitter_context>(Precision::FP32, jcp_.precision, elt_num, i * jcp_.sort_stride * jcp_.data_size),
|
store_f32_i32(vmm_idx(i), reg_dst_idx, elt_num, i * jcp_.sort_stride * sizeof(int));
|
||||||
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
|
|
||||||
store_emitter->emit_code({static_cast<size_t>(vmm_idx(i).getIdx())}, {static_cast<size_t>(reg_dst_idx.getIdx())},
|
|
||||||
std::make_shared<store_emitter_context>(Precision::FP32, Precision::I32, elt_num, i * jcp_.sort_stride * sizeof(int)),
|
|
||||||
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1211,15 +1211,11 @@ private:
|
|||||||
|
|
||||||
L(topk_load_sort_label);
|
L(topk_load_sort_label);
|
||||||
{
|
{
|
||||||
load_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_val(0).getIdx())},
|
load(reg_src, vmm_val(0), vector_step, 0);
|
||||||
std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, step, 0),
|
|
||||||
{}, {load_pool_gpr_idxs});
|
|
||||||
uni_vmovdqu(vmm_idx(0), table_bubble_seq_idx(0));
|
uni_vmovdqu(vmm_idx(0), table_bubble_seq_idx(0));
|
||||||
uni_vcvtdq2ps(vmm_idx(0), vmm_idx(0));
|
uni_vcvtdq2ps(vmm_idx(0), vmm_idx(0));
|
||||||
if (isa == cpu::x64::sse41) {
|
if (isa == cpu::x64::sse41) {
|
||||||
load_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_val(1).getIdx())},
|
load(reg_src, vmm_val(1), vector_step, 4 * jcp_.data_size);
|
||||||
std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, step, 4 * jcp_.data_size),
|
|
||||||
{}, {load_pool_gpr_idxs});
|
|
||||||
uni_vmovdqu(vmm_idx(1), table_bubble_seq_idx(4));
|
uni_vmovdqu(vmm_idx(1), table_bubble_seq_idx(4));
|
||||||
uni_vcvtdq2ps(vmm_idx(1), vmm_idx(1));
|
uni_vcvtdq2ps(vmm_idx(1), vmm_idx(1));
|
||||||
swap_vector(vmm_val(0), vmm_idx(0), vmm_val(1), vmm_idx(1));
|
swap_vector(vmm_val(0), vmm_idx(0), vmm_val(1), vmm_idx(1));
|
||||||
@ -1235,17 +1231,13 @@ private:
|
|||||||
jg(topk_iter_end_label, T_NEAR);
|
jg(topk_iter_end_label, T_NEAR);
|
||||||
|
|
||||||
get_addr_by_reg_idx(reg_aux, reg_src, reg_i, jcp_.data_size, reg_seq_sort_stride);
|
get_addr_by_reg_idx(reg_aux, reg_src, reg_i, jcp_.data_size, reg_seq_sort_stride);
|
||||||
load_emitter->emit_code({static_cast<size_t>(reg_aux.getIdx())}, {static_cast<size_t>(vmm_val(1).getIdx())},
|
load(reg_aux, vmm_val(1), vector_step);
|
||||||
std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, step),
|
|
||||||
{}, {load_pool_gpr_idxs});
|
|
||||||
table_to_vmm(vmm_idx(1), reg_bubble_seq_idx, reg_i, 0, sizeof(int));
|
table_to_vmm(vmm_idx(1), reg_bubble_seq_idx, reg_i, 0, sizeof(int));
|
||||||
uni_vcvtdq2ps(vmm_idx(1), vmm_idx(1));
|
uni_vcvtdq2ps(vmm_idx(1), vmm_idx(1));
|
||||||
swap_vector(vmm_val(0), vmm_idx(0), vmm_val(1), vmm_idx(1));
|
swap_vector(vmm_val(0), vmm_idx(0), vmm_val(1), vmm_idx(1));
|
||||||
if (isa == cpu::x64::sse41) {
|
if (isa == cpu::x64::sse41) {
|
||||||
add(reg_aux, 4 * jcp_.data_size);
|
add(reg_aux, 4 * jcp_.data_size);
|
||||||
load_emitter->emit_code({static_cast<size_t>(reg_aux.getIdx())}, {static_cast<size_t>(vmm_val(1).getIdx())},
|
load(reg_aux, vmm_val(1), vector_step);
|
||||||
std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, step),
|
|
||||||
{}, {load_pool_gpr_idxs});
|
|
||||||
table_to_vmm(vmm_idx(1), reg_bubble_seq_idx, reg_i, 4, sizeof(int));
|
table_to_vmm(vmm_idx(1), reg_bubble_seq_idx, reg_i, 4, sizeof(int));
|
||||||
uni_vcvtdq2ps(vmm_idx(1), vmm_idx(1));
|
uni_vcvtdq2ps(vmm_idx(1), vmm_idx(1));
|
||||||
swap_vector(vmm_val(0), vmm_idx(0), vmm_val(1), vmm_idx(1));
|
swap_vector(vmm_val(0), vmm_idx(0), vmm_val(1), vmm_idx(1));
|
||||||
@ -1538,16 +1530,13 @@ private:
|
|||||||
// load l
|
// load l
|
||||||
mov(reg_tmp, reg_tmp_64);
|
mov(reg_tmp, reg_tmp_64);
|
||||||
add(reg_tmp, reg_dst);
|
add(reg_tmp, reg_dst);
|
||||||
load_emitter->emit_code({static_cast<size_t>(reg_tmp.getIdx())}, {static_cast<size_t>(vmm_val_l.getIdx())},
|
load(reg_tmp, vmm_val_l, elt_num);
|
||||||
std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, elt_num),
|
|
||||||
{}, {load_pool_gpr_idxs});
|
|
||||||
reg_shl(reg_tmp_64, sizeof(int) / jcp_.data_size);
|
reg_shl(reg_tmp_64, sizeof(int) / jcp_.data_size);
|
||||||
mov(reg_tmp, reg_tmp_64);
|
mov(reg_tmp, reg_tmp_64);
|
||||||
add(reg_tmp, reg_dst_idx);
|
add(reg_tmp, reg_dst_idx);
|
||||||
reg_shr(reg_tmp_64, sizeof(int) / jcp_.data_size);
|
reg_shr(reg_tmp_64, sizeof(int) / jcp_.data_size);
|
||||||
load_emitter->emit_code({static_cast<size_t>(reg_tmp.getIdx())}, {static_cast<size_t>(vmm_idx_l.getIdx())},
|
load_i32_f32(reg_tmp, vmm_idx_l, elt_num);
|
||||||
std::make_shared<load_emitter_context>(Precision::I32, Precision::FP32, elt_num),
|
|
||||||
{}, {load_pool_gpr_idxs});
|
|
||||||
|
|
||||||
// load r
|
// load r
|
||||||
Xbyak::Label topk_load_jmp_label;
|
Xbyak::Label topk_load_jmp_label;
|
||||||
@ -1557,16 +1546,14 @@ private:
|
|||||||
add(reg_tmp_64, reg_block_sort_stride_byte);
|
add(reg_tmp_64, reg_block_sort_stride_byte);
|
||||||
mov(reg_tmp, reg_tmp_64);
|
mov(reg_tmp, reg_tmp_64);
|
||||||
add(reg_tmp, reg_dst);
|
add(reg_tmp, reg_dst);
|
||||||
load_emitter->emit_code({static_cast<size_t>(reg_tmp.getIdx())}, {static_cast<size_t>(vmm_val_r.getIdx())},
|
load(reg_tmp, vmm_val_r, elt_num);
|
||||||
std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, elt_num),
|
|
||||||
{}, {load_pool_gpr_idxs});
|
|
||||||
reg_shl(reg_tmp_64, sizeof(int) / jcp_.data_size);
|
reg_shl(reg_tmp_64, sizeof(int) / jcp_.data_size);
|
||||||
mov(reg_tmp, reg_tmp_64);
|
mov(reg_tmp, reg_tmp_64);
|
||||||
add(reg_tmp, reg_dst_idx);
|
add(reg_tmp, reg_dst_idx);
|
||||||
reg_shr(reg_tmp_64, sizeof(int) / jcp_.data_size);
|
reg_shr(reg_tmp_64, sizeof(int) / jcp_.data_size);
|
||||||
load_emitter->emit_code({static_cast<size_t>(reg_tmp.getIdx())}, {static_cast<size_t>(vmm_idx_r.getIdx())},
|
load_i32_f32(reg_tmp, vmm_idx_r, elt_num);
|
||||||
std::make_shared<load_emitter_context>(Precision::I32, Precision::FP32, elt_num),
|
|
||||||
{}, {load_pool_gpr_idxs});
|
|
||||||
sub(reg_tmp_64, reg_block_sort_stride_byte);
|
sub(reg_tmp_64, reg_block_sort_stride_byte);
|
||||||
}
|
}
|
||||||
L(topk_load_jmp_label);
|
L(topk_load_jmp_label);
|
||||||
@ -1576,16 +1563,13 @@ private:
|
|||||||
// store l
|
// store l
|
||||||
mov(reg_tmp, reg_tmp_64);
|
mov(reg_tmp, reg_tmp_64);
|
||||||
add(reg_tmp, reg_dst);
|
add(reg_tmp, reg_dst);
|
||||||
store_emitter->emit_code({static_cast<size_t>(vmm_val_l.getIdx())}, {static_cast<size_t>(reg_tmp.getIdx())},
|
store(vmm_val_l, reg_tmp, elt_num);
|
||||||
std::make_shared<store_emitter_context>(Precision::FP32, jcp_.precision, elt_num),
|
|
||||||
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
|
|
||||||
reg_shl(reg_tmp_64, sizeof(int) / jcp_.data_size);
|
reg_shl(reg_tmp_64, sizeof(int) / jcp_.data_size);
|
||||||
mov(reg_tmp, reg_tmp_64);
|
mov(reg_tmp, reg_tmp_64);
|
||||||
add(reg_tmp, reg_dst_idx);
|
add(reg_tmp, reg_dst_idx);
|
||||||
reg_shr(reg_tmp_64, sizeof(int) / jcp_.data_size);
|
reg_shr(reg_tmp_64, sizeof(int) / jcp_.data_size);
|
||||||
store_emitter->emit_code({static_cast<size_t>(vmm_idx_l.getIdx())}, {static_cast<size_t>(reg_tmp.getIdx())},
|
store_f32_i32(vmm_idx_l, reg_tmp, elt_num);
|
||||||
std::make_shared<store_emitter_context>(Precision::FP32, Precision::I32, elt_num),
|
|
||||||
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
|
|
||||||
|
|
||||||
// store r
|
// store r
|
||||||
Xbyak::Label topk_store_jmp_label;
|
Xbyak::Label topk_store_jmp_label;
|
||||||
@ -1595,16 +1579,13 @@ private:
|
|||||||
add(reg_tmp_64, reg_block_sort_stride_byte);
|
add(reg_tmp_64, reg_block_sort_stride_byte);
|
||||||
mov(reg_tmp, reg_tmp_64);
|
mov(reg_tmp, reg_tmp_64);
|
||||||
add(reg_tmp, reg_dst);
|
add(reg_tmp, reg_dst);
|
||||||
store_emitter->emit_code({static_cast<size_t>(vmm_val_r.getIdx())}, {static_cast<size_t>(reg_tmp.getIdx())},
|
store(vmm_val_r, reg_tmp, elt_num);
|
||||||
std::make_shared<store_emitter_context>(Precision::FP32, jcp_.precision, elt_num),
|
|
||||||
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
|
|
||||||
reg_shl(reg_tmp_64, sizeof(int) / jcp_.data_size);
|
reg_shl(reg_tmp_64, sizeof(int) / jcp_.data_size);
|
||||||
mov(reg_tmp, reg_tmp_64);
|
mov(reg_tmp, reg_tmp_64);
|
||||||
add(reg_tmp, reg_dst_idx);
|
add(reg_tmp, reg_dst_idx);
|
||||||
reg_shr(reg_tmp_64, sizeof(int) / jcp_.data_size);
|
reg_shr(reg_tmp_64, sizeof(int) / jcp_.data_size);
|
||||||
store_emitter->emit_code({static_cast<size_t>(vmm_idx_r.getIdx())}, {static_cast<size_t>(reg_tmp.getIdx())},
|
store_f32_i32(vmm_idx_r, reg_tmp, elt_num);
|
||||||
std::make_shared<store_emitter_context>(Precision::FP32, Precision::I32, elt_num),
|
|
||||||
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
|
|
||||||
}
|
}
|
||||||
L(topk_store_jmp_label);
|
L(topk_store_jmp_label);
|
||||||
}
|
}
|
||||||
|
@ -133,6 +133,11 @@ InferenceEngine::Precision type2precision<uint8_t>() {
|
|||||||
return InferenceEngine::Precision::U8;
|
return InferenceEngine::Precision::U8;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<>
|
||||||
|
InferenceEngine::Precision type2precision<int8_t>() {
|
||||||
|
return InferenceEngine::Precision::I8;
|
||||||
|
}
|
||||||
|
|
||||||
cpu_isa_t get_current_isa() {
|
cpu_isa_t get_current_isa() {
|
||||||
if (mayiuse(cpu_isa_t::avx512_core))
|
if (mayiuse(cpu_isa_t::avx512_core))
|
||||||
return cpu_isa_t::avx512_core;
|
return cpu_isa_t::avx512_core;
|
||||||
@ -212,9 +217,7 @@ const void * consts_table::store(const void *data, size_t size) {
|
|||||||
} // namespace internal
|
} // namespace internal
|
||||||
|
|
||||||
jit_kernel::jit_kernel()
|
jit_kernel::jit_kernel()
|
||||||
: jit_generator()
|
: jit_generator() {
|
||||||
, _load_emitter(this, internal::get_current_isa())
|
|
||||||
, _store_emitter(this, internal::get_current_isa()) {
|
|
||||||
_free_rmmregs.reserve(16);
|
_free_rmmregs.reserve(16);
|
||||||
_free_rmmregs.reserve(16);
|
_free_rmmregs.reserve(16);
|
||||||
|
|
||||||
@ -297,10 +300,10 @@ void jit_kernel::free<Zmm>(const Zmm & reg) {
|
|||||||
|
|
||||||
void jit_kernel::postamble() {
|
void jit_kernel::postamble() {
|
||||||
jit_generator::postamble();
|
jit_generator::postamble();
|
||||||
if (_is_load_emitter_used)
|
for (const auto& emitter : _emitters) {
|
||||||
_load_emitter.emit_data();
|
if (emitter.second)
|
||||||
if (_is_store_emitter_used)
|
emitter.second->emit_data();
|
||||||
_store_emitter.emit_data();
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const AddressFrame & jit_kernel::address_frame(size_t size) const {
|
const AddressFrame & jit_kernel::address_frame(size_t size) const {
|
||||||
|
@ -697,11 +697,8 @@ struct jit_kernel : public dnnl::impl::cpu::x64::jit_generator {
|
|||||||
private:
|
private:
|
||||||
reg_indices _free_x64regs;
|
reg_indices _free_x64regs;
|
||||||
reg_indices _free_rmmregs;
|
reg_indices _free_rmmregs;
|
||||||
bool _is_load_emitter_used = false;
|
|
||||||
bool _is_store_emitter_used = false;
|
|
||||||
jit_load_emitter _load_emitter;
|
|
||||||
jit_store_emitter _store_emitter;
|
|
||||||
internal::consts_table _consts;
|
internal::consts_table _consts;
|
||||||
|
std::unordered_map<size_t, std::unique_ptr<jit_emitter>> _emitters;
|
||||||
};
|
};
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
@ -746,17 +743,18 @@ void jit_kernel::load(const variable<DstT[N]> & dst, const variable<SrcT> & src,
|
|||||||
const std::vector<size_t> pool_vec_idxs(_free_rmmregs.begin(), _free_rmmregs.end());
|
const std::vector<size_t> pool_vec_idxs(_free_rmmregs.begin(), _free_rmmregs.end());
|
||||||
const std::vector<size_t> pool_gpr_idxs(_free_x64regs.begin(), _free_x64regs.end());
|
const std::vector<size_t> pool_gpr_idxs(_free_x64regs.begin(), _free_x64regs.end());
|
||||||
|
|
||||||
_load_emitter.emit_code(
|
const auto src_prc = internal::type2precision<src_type>();
|
||||||
|
const auto dst_prc = internal::type2precision<dst_type>();
|
||||||
|
|
||||||
|
const auto key = load_emitter_params(src_prc, dst_prc, length).hash();
|
||||||
|
if (!_emitters[key]) {
|
||||||
|
_emitters[key].reset(new jit_load_emitter(this, internal::get_current_isa(), src_prc, dst_prc, length));
|
||||||
|
}
|
||||||
|
_emitters[key]->emit_code(
|
||||||
{ static_cast<size_t>(static_cast<const Xbyak::Operand&>(src).getIdx()) },
|
{ static_cast<size_t>(static_cast<const Xbyak::Operand&>(src).getIdx()) },
|
||||||
{ static_cast<size_t>(static_cast<const Xbyak::Operand&>(dst).getIdx()) },
|
{ static_cast<size_t>(static_cast<const Xbyak::Operand&>(dst).getIdx()) },
|
||||||
std::make_shared<load_emitter_context>(
|
|
||||||
internal::type2precision<src_type>(),
|
|
||||||
internal::type2precision<dst_type>(),
|
|
||||||
static_cast<int>(length)),
|
|
||||||
pool_vec_idxs,
|
pool_vec_idxs,
|
||||||
pool_gpr_idxs);
|
pool_gpr_idxs);
|
||||||
|
|
||||||
_is_load_emitter_used = true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename DstT, size_t N, typename SrcT>
|
template<typename DstT, size_t N, typename SrcT>
|
||||||
@ -788,17 +786,18 @@ void jit_kernel::store(const variable<DstT> & dst, const variable<SrcT[N]> & src
|
|||||||
const std::vector<size_t> pool_vec_idxs(_free_rmmregs.begin(), _free_rmmregs.end());
|
const std::vector<size_t> pool_vec_idxs(_free_rmmregs.begin(), _free_rmmregs.end());
|
||||||
const std::vector<size_t> pool_gpr_idxs(_free_x64regs.begin(), _free_x64regs.end());
|
const std::vector<size_t> pool_gpr_idxs(_free_x64regs.begin(), _free_x64regs.end());
|
||||||
|
|
||||||
_store_emitter.emit_code(
|
const auto src_prc = internal::type2precision<src_type>();
|
||||||
|
const auto dst_prc = internal::type2precision<dst_type>();
|
||||||
|
|
||||||
|
const auto key = store_emitter_params(src_prc, dst_prc, length).hash();
|
||||||
|
if (!_emitters[key]) {
|
||||||
|
_emitters[key].reset(new jit_store_emitter(this, internal::get_current_isa(), src_prc, dst_prc, length));
|
||||||
|
}
|
||||||
|
_emitters[key]->emit_code(
|
||||||
{ static_cast<size_t>(static_cast<const Xbyak::Operand&>(src).getIdx()) },
|
{ static_cast<size_t>(static_cast<const Xbyak::Operand&>(src).getIdx()) },
|
||||||
{ static_cast<size_t>(static_cast<const Xbyak::Operand&>(dst).getIdx()) },
|
{ static_cast<size_t>(static_cast<const Xbyak::Operand&>(dst).getIdx()) },
|
||||||
std::make_shared<store_emitter_context>(
|
|
||||||
internal::type2precision<src_type>(),
|
|
||||||
internal::type2precision<dst_type>(),
|
|
||||||
static_cast<int>(length)),
|
|
||||||
pool_vec_idxs,
|
pool_vec_idxs,
|
||||||
pool_gpr_idxs);
|
pool_gpr_idxs);
|
||||||
|
|
||||||
_is_store_emitter_used = true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename DstT, typename SrcT, size_t N>
|
template<typename DstT, typename SrcT, size_t N>
|
||||||
|
@ -318,15 +318,30 @@ private:
|
|||||||
};
|
};
|
||||||
|
|
||||||
TEST(JitKernel, variable_load_and_store) {
|
TEST(JitKernel, variable_load_and_store) {
|
||||||
jit_variable_load_store_test_kernel<uint8_t, float> kernel;
|
{
|
||||||
if (mayiuse(cpu_isa_t::avx512_core)) {
|
jit_variable_load_store_test_kernel<uint8_t, float> kernel;
|
||||||
kernel.test<16>();
|
if (mayiuse(cpu_isa_t::avx512_core)) {
|
||||||
|
kernel.test<16>();
|
||||||
|
}
|
||||||
|
if (mayiuse(cpu_isa_t::avx2)) {
|
||||||
|
kernel.test<8>();
|
||||||
|
}
|
||||||
|
if (mayiuse(cpu_isa_t::sse41)) {
|
||||||
|
kernel.test<4>();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (mayiuse(cpu_isa_t::avx2)) {
|
|
||||||
kernel.test<8>();
|
{
|
||||||
}
|
jit_variable_load_store_test_kernel<int8_t, int8_t> kernel;
|
||||||
if (mayiuse(cpu_isa_t::sse41)) {
|
if (mayiuse(cpu_isa_t::avx512_core)) {
|
||||||
kernel.test<4>();
|
kernel.test<16>();
|
||||||
|
}
|
||||||
|
if (mayiuse(cpu_isa_t::avx2)) {
|
||||||
|
kernel.test<8>();
|
||||||
|
}
|
||||||
|
if (mayiuse(cpu_isa_t::sse41)) {
|
||||||
|
kernel.test<4>();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user