[CPU] Removed Contexts from load and store emitters (#12446)

Alexandra Sidorova 2022-08-17 22:22:22 +04:00 committed by GitHub
parent 97f3d84cf5
commit 6d6f52806b
12 changed files with 1101 additions and 1023 deletions

View File

@ -23,6 +23,11 @@ enum emitter_in_out_map {
gpr_to_gpr,
};
// structure for storing emitter parameters that are hashed into a map key
struct emitter_params {
virtual size_t hash() const = 0;
};
struct emitter_context {
virtual ~emitter_context() = default;
};
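hash() on an emitter_params subclass is what lets a kernel cache constructed emitters by their construction parameters instead of passing an emitter_context at emit time. A minimal sketch of that pattern, using the same map type the interpolate kernel adopts later in this commit (get_or_create and create_emitter are hypothetical names):

// Sketch: cache emitters keyed by the hash of their construction parameters.
// create_emitter is a hypothetical factory; the map type mirrors the
// std::unordered_map<size_t, std::unique_ptr<jit_emitter>> used by interpolate below.
std::unordered_map<size_t, std::unique_ptr<jit_emitter>> emitters;

template <typename Params>
jit_emitter* get_or_create(const Params& params) {
    const size_t seed = params.hash();            // equal params yield an equal seed
    if (!emitters[seed])
        emitters[seed] = create_emitter(params);  // constructed once per unique params
    return emitters[seed].get();                  // reused on every later call
}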

View File

@ -18,95 +18,125 @@ using namespace Xbyak::util;
namespace ov {
namespace intel_cpu {
size_t load_emitter_params::hash() const {
size_t seed = 0;
seed = hash_combine(seed, std::string("jit_load_emitter"));
seed = hash_combine(seed, src_prc_.getPrecVal());
seed = hash_combine(seed, dst_prc_.getPrecVal());
seed = hash_combine(seed, load_num_);
seed = hash_combine(seed, is_fill_);
seed = hash_combine(seed, fill_value_);
return seed;
}
size_t store_emitter_params::hash() const {
size_t seed = 0;
seed = hash_combine(seed, std::string("jit_store_emitter"));
seed = hash_combine(seed, src_prc_.getPrecVal());
seed = hash_combine(seed, dst_prc_.getPrecVal());
seed = hash_combine(seed, store_num_);
return seed;
}
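Both hash() implementations fold each field into a running seed with hash_combine, which is not shown in this diff. A boost-style definition consistent with the call pattern above, stated as an assumption rather than the repository's exact code:

// Assumed boost-style combiner: mixes the hash of v into the running seed.
template <typename T>
size_t hash_combine(size_t seed, const T& v) {
    return seed ^ (std::hash<T>()(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2));
}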
static int get_aux_regs_for_avx512_mask(const size_t byte_size, const bool is_fill = false) {
if (mayiuse(cpu::x64::avx512_core)) {
if (!one_of(byte_size, 64, 32, 16) || is_fill) {
return 1;
}
}
return 0;
}
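On AVX-512, an access whose byte size is not a full 16/32/64 bytes goes through an opmask register, and the mask bits are first materialized in a GPR; that GPR is the aux register this helper reserves. A sketch of the tail-mask setup, assuming the kmovq path that fill_with_default uses later in this file (num_elems and the register choices are illustrative):

// Hypothetical tail of num_elems dword elements (num_elems < 16 for a Zmm):
const uint64_t tail_mask = (1ULL << num_elems) - 1;  // e.g. num_elems == 5 -> 0b11111
h->mov(Reg64(aux_gpr_idxs[0]), tail_mask);           // materialize the mask in the aux GPR
h->kmovq(k_mask, Reg64(aux_gpr_idxs[0]));            // move it into an opmask register
h->vmovups(vmm | k_mask | T_z, ptr[reg + offset]);   // masked load that zeroes masked-off lanes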
/// LOAD ///
jit_load_emitter::jit_load_emitter(jit_generator *host, cpu_isa_t host_isa,
Precision exec_prc, emitter_in_out_map in_out_type)
: jit_emitter(host, host_isa, exec_prc, in_out_type), name("unknown") {
jit_load_emitter::jit_load_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
Precision src_prc, Precision dst_prc, int load_num, Precision exec_prc,
bool is_fill, std::string fill_value, emitter_in_out_map in_out_type)
: jit_emitter(host, host_isa, exec_prc, in_out_type), load_num_(load_num), src_prc_(src_prc), dst_prc_(dst_prc),
is_fill_(is_fill), fill_value_(fill_value), name_("unknown") {
prepare_table();
v_len_elt = get_vec_length() / exec_prc.size();
load_size_ = load_num * src_prc.size();
v_len_elt_ = get_vec_length() / exec_prc.size();
}
size_t jit_load_emitter::get_inputs_num() const { return 1; }
// 0 for temp reg for mask load, 1 for table address
size_t jit_load_emitter::aux_gprs_count() const {
return 2;
// 0 for temp reg for mask load in avx512 if needed
int count = get_aux_regs_for_avx512_mask(load_num_ * dst_prc_.size(), is_fill_);
// 1 for table address
if (is_fill_)
count++;
return count;
}
void jit_load_emitter::emit_impl(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs,
const emitter_context *emit_context) const {
const auto* load_emitter_context = dynamic_cast<const ov::intel_cpu::load_emitter_context*>(emit_context);
if (load_emitter_context == nullptr) {
IE_THROW() << "Load emitter in " << name << " does not get load emmiter context.";
}
const int offset = in_idxs.size() == 2 ? in_idxs[1] : 0;
if (host_isa_ == cpu::x64::sse41) {
emit_isa<cpu::x64::sse41>(Reg64(in_idxs[0]), load_emitter_context->offset_byte_, load_emitter_context->src_prc_, static_cast<int>(out_idxs[0]),
load_emitter_context->dst_prc_, load_emitter_context->load_num_, load_emitter_context->is_fill_, load_emitter_context->fill_value_);
emit_isa<cpu::x64::sse41>(Reg64(in_idxs[0]), static_cast<int>(out_idxs[0]), offset);
} else if (host_isa_ == cpu::x64::avx2) {
emit_isa<cpu::x64::avx2>(Reg64(in_idxs[0]), load_emitter_context->offset_byte_, load_emitter_context->src_prc_, static_cast<int>(out_idxs[0]),
load_emitter_context->dst_prc_, load_emitter_context->load_num_, load_emitter_context->is_fill_, load_emitter_context->fill_value_);
emit_isa<cpu::x64::avx2>(Reg64(in_idxs[0]), static_cast<int>(out_idxs[0]), offset);
} else if (host_isa_ == cpu::x64::avx512_core) {
emit_isa<cpu::x64::avx512_core>(Reg64(in_idxs[0]), load_emitter_context->offset_byte_, load_emitter_context->src_prc_, static_cast<int>(out_idxs[0]),
load_emitter_context->dst_prc_, load_emitter_context->load_num_, load_emitter_context->is_fill_, load_emitter_context->fill_value_);
emit_isa<cpu::x64::avx512_core>(Reg64(in_idxs[0]), static_cast<int>(out_idxs[0]), offset);
} else {
IE_THROW() << "Load emitter in " << name << " is performed on unsupported isa(at least x64::sse41).";
IE_THROW() << "Load emitter in " << name_ << " is performed on unsupported isa(at least x64::sse41).";
}
}
template <dnnl::impl::cpu::x64::cpu_isa_t isa>
void jit_load_emitter::emit_isa(const Xbyak::Reg64 &reg_src, int offset_byte, InferenceEngine::Precision src_prc,
const int out_vec_idx, InferenceEngine::Precision dst_prc, int load_num, bool is_fill, std::string fill_value) const {
bool matched_prc = (dst_prc == src_prc) || (dst_prc == Precision::FP32) || (dst_prc == Precision::I32);
void jit_load_emitter::emit_isa(const Xbyak::Reg64 &reg_src, const int out_vec_idx, const int offset) const {
bool matched_prc = (dst_prc_ == src_prc_) || (dst_prc_ == Precision::FP32) || (dst_prc_ == Precision::I32);
if (!matched_prc) {
IE_THROW() << "Load emitter in " << name << " only support output precision of FP32 or I32 or the same precision as input.";
IE_THROW() << "Load emitter in " << name_ << " only support output precision of FP32 or I32 or the same precision as input.";
}
if (load_num > (get_vec_length() / dst_prc.size())) {
IE_THROW() << "Load emitter in " << name << " have unexpected number of elements to load.";
if (load_num_ > (get_vec_length() / dst_prc_.size())) {
IE_THROW() << "Load emitter in " << name_ << " have unexpected number of elements to load.";
}
using Vmm = typename conditional3<isa == cpu::x64::sse41, Xmm, isa == cpu::x64::avx2, Ymm, Zmm>::type;
// pure load
if (src_prc == dst_prc) {
load_bytes<Vmm>(Vmm(out_vec_idx), reg_src, offset_byte, load_num * src_prc.size(), is_fill, fill_value);
if (src_prc_ == dst_prc_) {
load_bytes<Vmm>(Vmm(out_vec_idx), reg_src, offset, load_size_);
} else {
// "pure load" + convert. dst_prc must be FP32 or I32.
switch (src_prc) {
switch (src_prc_) {
case Precision::FP32:
case Precision::I32:
load_bytes<Vmm>(Vmm(out_vec_idx), reg_src, offset_byte, load_num * src_prc.size(), is_fill, fill_value);
load_bytes<Vmm>(Vmm(out_vec_idx), reg_src, offset, load_size_);
break;
case Precision::I8:
load_bytes_to_dword_extension<Vmm>(Vmm(out_vec_idx), reg_src, offset_byte, true, load_num * src_prc.size(), is_fill, fill_value);
load_bytes_to_dword_extension<Vmm>(Vmm(out_vec_idx), reg_src, offset, true, load_size_);
break;
case Precision::U8:
load_bytes_to_dword_extension<Vmm>(Vmm(out_vec_idx), reg_src, offset_byte, false, load_num * src_prc.size(), is_fill, fill_value);
load_bytes_to_dword_extension<Vmm>(Vmm(out_vec_idx), reg_src, offset, false, load_size_);
break;
case Precision::I16:
load_words_to_dword_extension<Vmm>(Vmm(out_vec_idx), reg_src, offset_byte, false, true, load_num * src_prc.size(), is_fill, fill_value);
load_words_to_dword_extension<Vmm>(Vmm(out_vec_idx), reg_src, offset, false, true, load_size_);
break;
case Precision::U16:
load_words_to_dword_extension<Vmm>(Vmm(out_vec_idx), reg_src, offset_byte, false, false, load_num * src_prc.size(), is_fill, fill_value);
load_words_to_dword_extension<Vmm>(Vmm(out_vec_idx), reg_src, offset, false, false, load_size_);
break;
case Precision::BF16:
load_words_to_dword_extension<Vmm>(Vmm(out_vec_idx), reg_src, offset_byte, true, false, load_num * src_prc.size(), is_fill, fill_value);
load_words_to_dword_extension<Vmm>(Vmm(out_vec_idx), reg_src, offset, true, false, load_size_);
break;
default:
IE_THROW() << "Load emitter in " << name << " has unsupported src precision to load.";
IE_THROW() << "Load emitter in " << name_ << " has unsupported src precision to load.";
}
}
// post convert between I32 and FP32
if (src_prc != dst_prc) {
switch (dst_prc) {
if (src_prc_ != dst_prc_) {
switch (dst_prc_) {
case Precision::FP32:
if ((src_prc != Precision::FP32) && (src_prc != Precision::BF16))
if ((src_prc_ != Precision::FP32) && (src_prc_ != Precision::BF16))
h->uni_vcvtdq2ps(Vmm(out_vec_idx), Vmm(out_vec_idx));
break;
case Precision::I32:
if ((src_prc == Precision::FP32) || (src_prc == Precision::BF16)) {
if ((src_prc_ == Precision::FP32) || (src_prc_ == Precision::BF16)) {
h->uni_vcvtps2dq(Vmm(out_vec_idx), Vmm(out_vec_idx));
}
break;
@ -129,8 +159,7 @@ void jit_load_emitter::emit_isa(const Xbyak::Reg64 &reg_src, int offset_byte, In
*
*/
template <typename Vmm>
void jit_load_emitter::load_bytes(const Vmm &vmm, const Xbyak::Reg64 &reg, int offset, int load_size,
bool is_fill, std::string fill_value) const {
void jit_load_emitter::load_bytes(const Vmm &vmm, const Xbyak::Reg64 &reg, int offset, int load_size) const {
constexpr bool is_xmm = std::is_same<Vmm, Xbyak::Xmm>::value;
constexpr bool is_ymm = std::is_same<Vmm, Xbyak::Ymm>::value;
constexpr bool is_zmm = std::is_same<Vmm, Xbyak::Zmm>::value;
@ -141,12 +170,12 @@ void jit_load_emitter::load_bytes(const Vmm &vmm, const Xbyak::Reg64 &reg, int o
// Ensure data fits completely inside the Xmm/Ymm/Zmm register
if (load_size < 0 || load_size > 64)
IE_THROW() << "Load emitter in " << name << " has unexpected number of values to load in load_byte.";
IE_THROW() << "Load emitter in " << name_ << " has unexpected number of values to load in load_byte.";
// check that the proper number of bytes fits inside the Xmm/Ymm register
if (is_ymm && load_size > 32)
IE_THROW() << "Load emitter in " << name << " has unexpected number of values to load to ymm in load_byte.";
IE_THROW() << "Load emitter in " << name_ << " has unexpected number of values to load to ymm in load_byte.";
if (is_xmm && load_size > 16)
IE_THROW() << "Load emitter in " << name << " has unexpected number of values to load to xmm in load_byte.";
IE_THROW() << "Load emitter in " << name_ << " has unexpected number of values to load to xmm in load_byte.";
auto xmm = Xbyak::Xmm(vmm.getIdx());
auto ymm = Xbyak::Ymm(vmm.getIdx());
@ -229,7 +258,7 @@ void jit_load_emitter::load_bytes(const Vmm &vmm, const Xbyak::Reg64 &reg, int o
break;
case 16: break;
default:
IE_THROW() << "Load emitter in " << name<< " has unexpected number of values to load in load_byte.";
IE_THROW() << "Load emitter in " << name_ << " has unexpected number of values to load in load_byte.";
}
if (has_xmm_block) {
@ -270,8 +299,8 @@ void jit_load_emitter::load_bytes(const Vmm &vmm, const Xbyak::Reg64 &reg, int o
}
}
if (is_fill)
fill_with_default(vmm, fill_value, load_size / 4);
if (is_fill_)
fill_with_default(vmm, fill_value_, load_size / 4);
}
/**
@ -294,8 +323,7 @@ void jit_load_emitter::load_bytes(const Vmm &vmm, const Xbyak::Reg64 &reg, int o
*/
template <typename Vmm>
void jit_load_emitter::load_bytes_to_dword_extension(const Vmm &vmm, const Xbyak::Reg64 &reg,
int offset, bool is_signed, int load_size, bool is_fill, std::string fill_value) const {
void jit_load_emitter::load_bytes_to_dword_extension(const Vmm &vmm, const Xbyak::Reg64 &reg, int offset, bool is_signed, int load_size) const {
constexpr bool is_xmm = std::is_same<Vmm, Xbyak::Xmm>::value;
constexpr bool is_ymm = std::is_same<Vmm, Xbyak::Ymm>::value;
constexpr bool is_zmm = std::is_same<Vmm, Xbyak::Zmm>::value;
@ -308,11 +336,11 @@ void jit_load_emitter::load_bytes_to_dword_extension(const Vmm &vmm, const Xbyak
// For Ymm register, load capacity is halved (32 * load_size <= 256)
// For Xmm register, load capacity is halved further (32 * load_size <= 128)
if (load_size < 0 || load_size > 16)
IE_THROW() << "Load emitter in " << name << " has unexpected number of values to load in load_bytes_to_dword_extension.";
IE_THROW() << "Load emitter in " << name_ << " has unexpected number of values to load in load_bytes_to_dword_extension.";
if (is_ymm && load_size > 8)
IE_THROW() << "Load emitter in " << name << " has unexpected number of values to load to ymm in load_bytes_to_dword_extension.";
IE_THROW() << "Load emitter in " << name_ << " has unexpected number of values to load to ymm in load_bytes_to_dword_extension.";
if (is_xmm && load_size > 4)
IE_THROW() << "Load emitter in " << name << " has unexpected number of values to load to xmm in load_bytes_to_dword_extension.";
IE_THROW() << "Load emitter in " << name_ << " has unexpected number of values to load to xmm in load_bytes_to_dword_extension.";
// For load_size == 4/8/16, do load/extension in one go
switch (load_size) {
@ -365,8 +393,8 @@ void jit_load_emitter::load_bytes_to_dword_extension(const Vmm &vmm, const Xbyak
}
}
if (is_fill)
fill_with_default(vmm, fill_value, load_size);
if (is_fill_)
fill_with_default(vmm, fill_value_, load_size);
}
/**
@ -388,8 +416,7 @@ void jit_load_emitter::load_bytes_to_dword_extension(const Vmm &vmm, const Xbyak
* [0.. 32] for ZMM version of the function. i.e. 16 words -> 16 * 32 bit == 512 bit
*/
template <typename Vmm>
void jit_load_emitter::load_words_to_dword_extension(const Vmm &vmm, const Xbyak::Reg64 &reg,
int offset, bool is_bf16, bool is_signed, int load_size, bool is_fill, std::string fill_value) const {
void jit_load_emitter::load_words_to_dword_extension(const Vmm &vmm, const Xbyak::Reg64 &reg, int offset, bool is_bf16, bool is_signed, int load_size) const {
constexpr bool is_xmm = std::is_same<Vmm, Xbyak::Xmm>::value;
constexpr bool is_ymm = std::is_same<Vmm, Xbyak::Ymm>::value;
constexpr bool is_zmm = std::is_same<Vmm, Xbyak::Zmm>::value;
@ -402,11 +429,11 @@ void jit_load_emitter::load_words_to_dword_extension(const Vmm &vmm, const Xbyak
// For Ymm register, load capacity is halved (16/2(num) * 32 <= 256)
// For Xmm register, load capacity is halved again (8/2(num) * 32 <= 128)
if (load_size < 0 || load_size > 32)
IE_THROW() << "Load emitter in " << name << " has unexpected number of values to load in load_words_to_dword_extension.";
IE_THROW() << "Load emitter in " << name_ << " has unexpected number of values to load in load_words_to_dword_extension.";
if (is_ymm && load_size > 16)
IE_THROW() << "Load emitter in " << name << " has unexpected number of values to load to ymm in load_words_to_dword_extension.";
IE_THROW() << "Load emitter in " << name_ << " has unexpected number of values to load to ymm in load_words_to_dword_extension.";
if (is_xmm && load_size > 8)
IE_THROW() << "Load emitter in " << name << " has unexpected number of values to load to xmm in load_words_to_dword_extension.";
IE_THROW() << "Load emitter in " << name_ << " has unexpected number of values to load to xmm in load_words_to_dword_extension.";
auto xmm = Xbyak::Xmm(vmm.getIdx());
auto ymm = Xbyak::Ymm(vmm.getIdx());
@ -483,12 +510,12 @@ void jit_load_emitter::load_words_to_dword_extension(const Vmm &vmm, const Xbyak
}
}
if (is_fill)
fill_with_default(vmm, fill_value, load_size / 2);
if (is_fill_)
fill_with_default(vmm, fill_value_, load_size / 2);
}
template <typename Vmm>
void jit_load_emitter::fill_with_default(const Vmm &vmm, std::string fill_value, const int &load_num) const {
void jit_load_emitter::fill_with_default(const Vmm &vmm, std::string fill_value, const int &load_num) const {
constexpr bool is_xmm = std::is_same<Vmm, Xbyak::Xmm>::value;
constexpr bool is_ymm = std::is_same<Vmm, Xbyak::Ymm>::value;
constexpr bool is_zmm = std::is_same<Vmm, Xbyak::Zmm>::value;
@ -504,9 +531,10 @@ template <typename Vmm>
h->kmovq(k_mask, Reg64(aux_gpr_idxs[0]));
h->vblendmps(vmm | k_mask, vmm, table_val(fill_value));
}
}
}
void jit_load_emitter::register_table_entries() {
if (is_fill_) {
push_arg_entry_of("zero", 0x00000000, true);
push_arg_entry_of("int_one", 0x00000001, true);
push_arg_entry_of("float_one", 0x3f800000, true);
@ -514,117 +542,125 @@ void jit_load_emitter::register_table_entries() {
push_arg_entry_of("float_min", 0xff7fffff, true);
push_arg_entry_of("int32_max", 0x4effffff, true);
push_arg_entry_of("float_max", 0x7f7fffff, true);
}
/// STORE ///
jit_store_emitter::jit_store_emitter(jit_generator *host, cpu_isa_t host_isa,
Precision exec_prc, emitter_in_out_map in_out_type)
: jit_emitter(host, host_isa, exec_prc, in_out_type), name("unknown") {
v_len_elt = get_vec_length() / exec_prc.size();
if (!mayiuse(cpu::x64::avx512_core_bf16) && mayiuse(cpu::x64::avx512_core)) {
emu_vcvtneps2bf16.reset(new jit_emu_vcvtneps2bf16(host, host_isa));
}
}
// 0 for temp reg for mask store
size_t jit_store_emitter::aux_gprs_count() const {
return 1;
/// STORE ///
jit_store_emitter::jit_store_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
Precision src_prc, Precision dst_prc, int store_num, Precision exec_prc, emitter_in_out_map in_out_type)
: jit_emitter(host, host_isa, exec_prc, in_out_type), store_num_(store_num), src_prc_(src_prc), dst_prc_(dst_prc), name_("unknown") {
v_len_elt_ = get_vec_length() / exec_prc.size();
store_size_ = store_num * dst_prc.size();
if (!mayiuse(cpu::x64::avx512_core_bf16) && mayiuse(cpu::x64::avx512_core)) {
emu_vcvtneps2bf16_.reset(new jit_emu_vcvtneps2bf16(host, host_isa));
}
}
// 0 for temp reg for mask store for avx512
size_t jit_store_emitter::aux_gprs_count() const {
return get_aux_regs_for_avx512_mask(store_num_ * src_prc_.size());
}
// zero value, zeroed and passed from the caller for performance reasons (zeroed once, so it does not need to be preserved and restored)
size_t jit_store_emitter::aux_vecs_count() const {
return 1;
int count = 0;
// to avoid src vmm pollution after data type conversion
if ((src_prc_.is_float() && !dst_prc_.is_float()) ||
(!src_prc_.is_float() && dst_prc_.is_float()) ||
(src_prc_ == Precision::FP32 && dst_prc_ == Precision::BF16))
count++;
// zero value, zeroed and passed from the caller for performance reasons (zeroed once, so it does not need to be preserved and restored)
if (mayiuse(cpu::x64::avx512_core) && one_of(dst_prc_, Precision::U8, Precision::U16))
count++;
return count;
}
size_t jit_store_emitter::get_inputs_num() const { return 1; }
void jit_store_emitter::emit_data() const {
if (emu_vcvtneps2bf16)
emu_vcvtneps2bf16->emit_data();
if (emu_vcvtneps2bf16_)
emu_vcvtneps2bf16_->emit_data();
}
void jit_store_emitter::emit_impl(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs,
const emitter_context *emit_context) const {
const auto* store_emitter_context = dynamic_cast<const ov::intel_cpu::store_emitter_context*>(emit_context);
if (store_emitter_context == nullptr) {
IE_THROW() << "Store emitter in " << name << " does not get store emmiter context.";
}
const int offset = in_idxs.size() == 2 ? in_idxs[1] : 0;
if (host_isa_ == cpu::x64::sse41) {
emit_isa<cpu::x64::sse41>(static_cast<int>(in_idxs[0]), store_emitter_context->src_prc_, Reg64(out_idxs[0]),
store_emitter_context->offset_byte_, store_emitter_context->dst_prc_, store_emitter_context->store_num_);
emit_isa<cpu::x64::sse41>(static_cast<int>(in_idxs[0]), Reg64(out_idxs[0]), offset);
} else if (host_isa_ == cpu::x64::avx2) {
emit_isa<cpu::x64::avx2>(static_cast<int>(in_idxs[0]), store_emitter_context->src_prc_, Reg64(out_idxs[0]),
store_emitter_context->offset_byte_, store_emitter_context->dst_prc_, store_emitter_context->store_num_);
emit_isa<cpu::x64::avx2>(static_cast<int>(in_idxs[0]), Reg64(out_idxs[0]), offset);
} else if (host_isa_ == cpu::x64::avx512_core) {
emit_isa<cpu::x64::avx512_core>(static_cast<int>(in_idxs[0]), store_emitter_context->src_prc_, Reg64(out_idxs[0]),
store_emitter_context->offset_byte_, store_emitter_context->dst_prc_, store_emitter_context->store_num_);
emit_isa<cpu::x64::avx512_core>(static_cast<int>(in_idxs[0]), Reg64(out_idxs[0]), offset);
} else {
IE_THROW() << "Store emitter in " << name << " is performed on unsupported isa(at least x64::sse41).";
IE_THROW() << "Store emitter in " << name_ << " is performed on unsupported isa(at least x64::sse41).";
}
}
template <dnnl::impl::cpu::x64::cpu_isa_t isa>
void jit_store_emitter::emit_isa(const int in_vec_idx, InferenceEngine::Precision src_prc,
const Xbyak::Reg64 &reg_dst, int offset_byte, InferenceEngine::Precision dst_prc, int store_num) const {
bool matched_prc = (src_prc == dst_prc) || (src_prc == Precision::FP32) || (src_prc == Precision::I32);
void jit_store_emitter::emit_isa(const int in_vec_idx, const Xbyak::Reg64 &reg_dst, const int offset) const {
bool matched_prc = (src_prc_ == dst_prc_) || (src_prc_ == Precision::FP32) || (src_prc_ == Precision::I32);
if (!matched_prc) {
IE_THROW() << "Store emitter in " << name << " only support input precision of FP32 or I32 or the same precision as output.";
IE_THROW() << "Store emitter in " << name_ << " only support input precision of FP32 or I32 or the same precision as output.";
}
if ((src_prc == Precision::FP32) || (src_prc == Precision::I32)) {
if ((isa == cpu::x64::sse41 && store_num > 4) || (isa == cpu::x64::avx2 && store_num > 8) ||
(isa == cpu::x64::avx512_core && store_num > 16) || store_num < 0) {
IE_THROW() << "Store emitter in " << name << " has unexpected number of values to store.";
if ((src_prc_ == Precision::FP32) || (src_prc_ == Precision::I32)) {
if ((isa == cpu::x64::sse41 && store_num_ > 4) || (isa == cpu::x64::avx2 && store_num_ > 8) ||
(isa == cpu::x64::avx512_core && store_num_ > 16) || store_num_ < 0) {
IE_THROW() << "Store emitter in " << name_ << " has unexpected number of values to store.";
}
}
using Vmm = typename conditional3<isa == cpu::x64::sse41, Xmm, isa == cpu::x64::avx2, Ymm, Zmm>::type;
if (src_prc != dst_prc) {
switch (src_prc) {
int data_idx = in_vec_idx;
if (src_prc_ != dst_prc_) {
switch (src_prc_) {
case Precision::FP32:
if ((dst_prc != Precision::FP32) && (dst_prc != Precision::BF16)) {
h->uni_vcvtps2dq(Vmm(in_vec_idx), Vmm(in_vec_idx));
if ((dst_prc_ != Precision::FP32) && (dst_prc_ != Precision::BF16)) {
h->uni_vcvtps2dq(Vmm(aux_vec_idxs.back()), Vmm(data_idx));
data_idx = aux_vec_idxs.back();
}
break;
case Precision::I32:
if ((dst_prc == Precision::FP32) || (dst_prc == Precision::BF16))
h->uni_vcvtdq2ps(Vmm(in_vec_idx), Vmm(in_vec_idx));
if ((dst_prc_ == Precision::FP32) || (dst_prc_ == Precision::BF16)) {
h->uni_vcvtdq2ps(Vmm(aux_vec_idxs.back()), Vmm(data_idx));
data_idx = aux_vec_idxs.back();
}
break;
default:
break;
}
}
if (src_prc == dst_prc) {
store_bytes<Vmm>(Vmm(in_vec_idx), reg_dst, offset_byte, store_num * dst_prc.size());
if (src_prc_ == dst_prc_) {
store_bytes<Vmm>(Vmm(data_idx), reg_dst, offset, store_size_);
} else {
switch (dst_prc) {
switch (dst_prc_) {
case Precision::FP32:
case Precision::I32:
store_bytes<Vmm>(Vmm(in_vec_idx), reg_dst, offset_byte, store_num * dst_prc.size());
store_bytes<Vmm>(Vmm(data_idx), reg_dst, offset, store_size_);
break;
case Precision::I8:
store_dword_to_byte_extension<Vmm>(Vmm(in_vec_idx), reg_dst, offset_byte, true, store_num);
store_dword_to_byte_extension<Vmm>(Vmm(data_idx), reg_dst, offset, true, store_num_);
break;
case Precision::U8:
store_dword_to_byte_extension<Vmm>(Vmm(in_vec_idx), reg_dst, offset_byte, false, store_num);
store_dword_to_byte_extension<Vmm>(Vmm(data_idx), reg_dst, offset, false, store_num_);
break;
case Precision::I16:
store_dword_to_word_extension<Vmm>(Vmm(in_vec_idx), reg_dst, offset_byte, false, true, store_num);
store_dword_to_word_extension<Vmm>(Vmm(data_idx), reg_dst, offset, false, true, store_num_);
break;
case Precision::U16:
store_dword_to_word_extension<Vmm>(Vmm(in_vec_idx), reg_dst, offset_byte, false, false, store_num);
store_dword_to_word_extension<Vmm>(Vmm(data_idx), reg_dst, offset, false, false, store_num_);
break;
case Precision::BF16:
store_dword_to_word_extension<Vmm>(Vmm(in_vec_idx), reg_dst, offset_byte, true, false, store_num);
store_dword_to_word_extension<Vmm>(Vmm(data_idx), reg_dst, offset, true, false, store_num_);
break;
default:
IE_THROW() << "Store emitter in " << name << " has unsupported dst precision to store.";
IE_THROW() << "Store emitter in " << name_ << " has unsupported dst precision to store.";
}
}
}
}
/**
* store_bytes is the utility function to facilitate storing of
* store_size (0 <= store_size <= 64) many contiguous bytes from the Xmm/Ymm/Zmm
@ -641,7 +677,7 @@ template <dnnl::impl::cpu::x64::cpu_isa_t isa>
*
*/
template <typename Vmm>
void jit_store_emitter::store_bytes(const Vmm &vmm, const Xbyak::Reg64 &reg, int offset, int store_size) const {
void jit_store_emitter::store_bytes(const Vmm &vmm, const Xbyak::Reg64 &reg, int offset, int store_size) const {
constexpr bool is_xmm = std::is_same<Vmm, Xbyak::Xmm>::value;
constexpr bool is_ymm = std::is_same<Vmm, Xbyak::Ymm>::value;
constexpr bool is_zmm = std::is_same<Vmm, Xbyak::Zmm>::value;
@ -652,11 +688,11 @@ template <typename Vmm>
// Ensure data fits completely inside the Xmm/Ymm/Zmm register
if (store_size < 0 || store_size > 64)
IE_THROW() << "Store emitter in " << name << " has unexpected number of values to store in store_bytes.";
IE_THROW() << "Store emitter in " << name_ << " has unexpected number of values to store in store_bytes.";
if (is_ymm && store_size > 32)
IE_THROW() << "Store emitter in " << name << " has unexpected number of values to store to ymm in store_bytes.";
IE_THROW() << "Store emitter in " << name_ << " has unexpected number of values to store to ymm in store_bytes.";
if (is_xmm && store_size > 16)
IE_THROW() << "Store emitter in " << name << " has unexpected number of values to store to xmm in store_bytes.";
IE_THROW() << "Store emitter in " << name_ << " has unexpected number of values to store to xmm in store_bytes.";
auto xmm = Xbyak::Xmm(vmm.getIdx());
auto ymm = Xbyak::Ymm(vmm.getIdx());
@ -738,7 +774,7 @@ template <typename Vmm>
break;
case 16: break;
default:
IE_THROW() << "Store emitter in " << name << " has unexpected number of values to store in store_bytes.";
IE_THROW() << "Store emitter in " << name_ << " has unexpected number of values to store in store_bytes.";
}
};
@ -764,7 +800,7 @@ template <typename Vmm>
}
break;
}
}
}
/**
* store_dword_to_byte_extension is the utility function to
@ -772,7 +808,7 @@ template <typename Vmm>
* 2. store the packed bytes into the memory referenced by the ptr[reg + offset] address.
*/
template <typename Vmm>
void jit_store_emitter::store_dword_to_byte_extension(const Vmm &vmm, const Xbyak::Reg64 &reg, int offset, bool is_signed, int store_num) const {
void jit_store_emitter::store_dword_to_byte_extension(const Vmm &vmm, const Xbyak::Reg64 &reg, int offset, bool is_signed, int store_num) const {
constexpr bool is_xmm = std::is_same<Vmm, Xbyak::Xmm>::value;
constexpr bool is_ymm = std::is_same<Vmm, Xbyak::Ymm>::value;
constexpr bool is_zmm = std::is_same<Vmm, Xbyak::Zmm>::value;
@ -785,11 +821,11 @@ template <typename Vmm>
// At most 8 dwords can fit inside the Ymm register
// At most 4 dwords can fit inside the Xmm register
if (store_num < 0 || store_num > 16)
IE_THROW() << "Store emitter in " << name << " has unexpected number of values to store in store_dword_to_byte_extension.";
IE_THROW() << "Store emitter in " << name_ << " has unexpected number of values to store in store_dword_to_byte_extension.";
if (is_ymm && store_num > 8)
IE_THROW() << "Store emitter in " << name << " has unexpected number of values to store to ymm in store_dword_to_byte_extension.";
IE_THROW() << "Store emitter in " << name_ << " has unexpected number of values to store to ymm in store_dword_to_byte_extension.";
if (is_xmm && store_num > 4)
IE_THROW() << "Store emitter in " << name << " has unexpected number of values to store to xmm in store_dword_to_byte_extension.";
IE_THROW() << "Store emitter in " << name_ << " has unexpected number of values to store to xmm in store_dword_to_byte_extension.";
auto ymm = Xbyak::Ymm(vmm.getIdx());
auto xmm = Xbyak::Xmm(vmm.getIdx());
@ -877,7 +913,7 @@ template <typename Vmm>
}
break;
}
}
}
/**
* store_dword_to_word_extension is the utility function to
@ -885,8 +921,8 @@ template <typename Vmm>
* 2. store the packed words into the memory referenced by ptr[reg + offset] address.
*/
template <typename Vmm>
void jit_store_emitter::store_dword_to_word_extension(const Vmm &vmm, const Xbyak::Reg64 &reg, int offset,
bool is_bf16, bool is_signed, int store_num) const {
void jit_store_emitter::store_dword_to_word_extension(const Vmm &vmm, const Xbyak::Reg64 &reg,
int offset, bool is_bf16, bool is_signed, int store_num) const {
constexpr bool is_xmm = std::is_same<Vmm, Xbyak::Xmm>::value;
constexpr bool is_ymm = std::is_same<Vmm, Xbyak::Ymm>::value;
constexpr bool is_zmm = std::is_same<Vmm, Xbyak::Zmm>::value;
@ -899,11 +935,11 @@ template <typename Vmm>
// At most 4 dwords can fit inside the Xmm register
// At most 8 dwords can fit inside the Ymm register
if (store_num < 0 || store_num > 16)
IE_THROW() << "Store emitter in " << name << " has unexpected number of values to store in store_dword_to_word_extension.";
IE_THROW() << "Store emitter in " << name_ << " has unexpected number of values to store in store_dword_to_word_extension.";
if (is_ymm && store_num > 8)
IE_THROW() << "Store emitter in " << name << " has unexpected number of values to store to ymm in store_dword_to_word_extension.";
IE_THROW() << "Store emitter in " << name_ << " has unexpected number of values to store to ymm in store_dword_to_word_extension.";
if (is_xmm && store_num > 4)
IE_THROW() << "Store emitter in " << name << " has unexpected number of values to store to xmm in store_dword_to_word_extension.";
IE_THROW() << "Store emitter in " << name_ << " has unexpected number of values to store to xmm in store_dword_to_word_extension.";
auto xmm = Xbyak::Xmm(vmm.getIdx());
auto ymm = Xbyak::Ymm(vmm.getIdx());
@ -926,10 +962,14 @@ template <typename Vmm>
};
if (is_bf16) {
// to avoid src vmm pollution
if (src_prc_ == Precision::FP32) {
ymm = Ymm(aux_vec_idxs[0]);
}
if (mayiuse(cpu::x64::avx512_core_bf16)) {
h->vcvtneps2bf16(ymm, zmm);
} else {
emu_vcvtneps2bf16->emit_code({static_cast<size_t>(vmm.getIdx())}, {static_cast<size_t>(ymm.getIdx())});
emu_vcvtneps2bf16_->emit_code({static_cast<size_t>(vmm.getIdx())}, {static_cast<size_t>(ymm.getIdx())});
}
if (store_num == 16) {
h->vmovdqu16(ptr[reg + offset], ymm);
@ -996,7 +1036,7 @@ template <typename Vmm>
break;
}
}
}
}
} // namespace intel_cpu
} // namespace ov

View File

@ -15,40 +15,37 @@ using namespace InferenceEngine;
namespace ov {
namespace intel_cpu {
struct load_emitter_context : public emitter_context {
load_emitter_context() : src_prc_(Precision::FP32), dst_prc_(Precision::FP32), load_num_(8),
offset_byte_(0), is_fill_(false), fill_value_("zero") {}
struct load_emitter_params : public emitter_params {
load_emitter_params(Precision src_prc, Precision dst_prc, int load_num, bool is_fill = false, std::string fill_value = "zero"):
src_prc_(src_prc), dst_prc_(dst_prc), load_num_(load_num), is_fill_(is_fill), fill_value_(fill_value) {}
load_emitter_context(Precision src_prc, Precision dst_prc, int load_num, int offset_byte = 0, bool is_fill = false, std::string fill_value = "zero"):
src_prc_(src_prc), dst_prc_(dst_prc), load_num_(load_num), offset_byte_(offset_byte), is_fill_(is_fill), fill_value_(fill_value) {}
size_t hash() const override;
int offset_byte_;
int load_num_;
Precision src_prc_;
Precision dst_prc_;
int load_num_;
bool is_fill_;
std::string fill_value_;
};
struct store_emitter_context : public emitter_context {
store_emitter_context() : src_prc_(Precision::FP32), dst_prc_(Precision::FP32),
store_num_(8), offset_byte_(0) {}
struct store_emitter_params : public emitter_params {
store_emitter_params(Precision src_prc, Precision dst_prc, int store_num):
src_prc_(src_prc), dst_prc_(dst_prc), store_num_(store_num) {}
store_emitter_context(Precision src_prc, Precision dst_prc, int store_num, int offset_byte = 0)
: src_prc_(src_prc), dst_prc_(dst_prc), store_num_(store_num), offset_byte_(offset_byte) {}
size_t hash() const override;
int offset_byte_;
int store_num_;
Precision src_prc_;
Precision dst_prc_;
int store_num_;
};
class jit_load_emitter : public jit_emitter {
public:
jit_load_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32, emitter_in_out_map in_out_type = emitter_in_out_map::gpr_to_vec);
jit_load_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, Precision src_prc, Precision dst_prc, int load_num,
Precision exec_prc = Precision::FP32, bool is_fill = false, std::string fill_value = "zero",
emitter_in_out_map in_out_type = emitter_in_out_map::gpr_to_vec);
/**
* load_num values with src_prc precision are loaded from ptr[Reg64(in_idxs[0]) + offset_byte] address to Vmm[out_idxs[0]] as dst_prc.
* load_num values with src_prc precision are loaded from ptr[Reg64(in_idxs[0]) + offset_byte] address to Vmm[out_idxs[0]] as dst_prc, where offset_byte is taken from in_idxs[1].
* is_fill: when load_num cannot fully fill the vector register, whether the remaining elements should be filled with default values.
* fill_value: which default values to fill in that case.
* currently supported values are "zero", "int_one", "float_one", "int32_min", "float_min", "int32_max" and "float_max".
@ -73,20 +70,16 @@ public:
private:
template <dnnl::impl::cpu::x64::cpu_isa_t isa>
void emit_isa(const Xbyak::Reg64 &reg_src, int offset_byte, InferenceEngine::Precision src_prc,
const int out_vec_idx, InferenceEngine::Precision dst_prc, int load_num, bool is_fill = false, std::string fill_value = "zero") const;
void emit_isa(const Xbyak::Reg64 &reg_src, const int out_vec_idx, const int offset) const;
template <typename Vmm>
void load_bytes(const Vmm &vmm, const Xbyak::Reg64 &reg, int offset, int load_size,
bool is_fill = false, std::string fill_value = "zero") const;
void load_bytes(const Vmm &vmm, const Xbyak::Reg64 &reg, int offset, int load_size) const;
template <typename Vmm>
void load_bytes_to_dword_extension(const Vmm &vmm, const Xbyak::Reg64 &reg, int offset, bool is_signed, int load_size,
bool is_fill = false, std::string fill_value = "zero") const;
void load_bytes_to_dword_extension(const Vmm &vmm, const Xbyak::Reg64 &reg, int offset, bool is_signed, int load_size) const;
template <typename Vmm>
void load_words_to_dword_extension(const Vmm &vmm, const Xbyak::Reg64 &reg, int offset, bool is_bf16, bool is_signed, int load_size,
bool is_fill = false, std::string fill_value = "zero") const;
void load_words_to_dword_extension(const Vmm &vmm, const Xbyak::Reg64 &reg, int offset, bool is_bf16, bool is_signed, int load_size) const;
template <typename Vmm>
void fill_with_default(const Vmm &vmm, std::string fill_value, const int &load_num) const;
@ -95,17 +88,23 @@ private:
size_t aux_gprs_count() const override;
std::string name;
int v_len_elt; // 4/8/16
std::string name_;
int v_len_elt_; // 4/8/16
int load_num_;
int load_size_;
Precision src_prc_;
Precision dst_prc_;
bool is_fill_;
std::string fill_value_;
};
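With the refactored interface the precisions and element count are fixed at construction, and only the byte offset varies per call, carried in in_idxs; a hedged usage sketch (the precisions, element count, and register indices are illustrative):

// Sketch: one-time construction, per-call offset passed via in_idxs[1].
jit_load_emitter load_emitter(this, isa, Precision::U8, Precision::FP32, 8);
load_emitter.emit_code({static_cast<size_t>(reg_src.getIdx()), 64 /* byte offset */},
                       {static_cast<size_t>(vmm_dst.getIdx())},
                       {}, {load_pool_gpr_idxs});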
class jit_store_emitter : public jit_emitter {
public:
jit_store_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32, emitter_in_out_map in_out_type = emitter_in_out_map::vec_to_gpr);
jit_store_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, Precision src_prc, Precision dst_prc, int store_num,
Precision exec_prc = Precision::FP32, emitter_in_out_map in_out_type = emitter_in_out_map::vec_to_gpr);
/**
* store_num values with src_prc in Vmm[in_vec_idx] are stored to ptr[reg_dst + offset_byte] address as dst_prc data.
* store_num values with src_prc in Vmm[in_vec_idx] are stored to ptr[reg_dst + offset_byte] address as dst_prc data, where offset_byte is taken from in_idxs[1].
* supported src_prc and dst_prc pairs are as below (x indicates support):
* FP32 I32 I16 U16 I8 U8 BF16 --> src_prc
* FP32 x x
@ -128,13 +127,12 @@ public:
void emit_data() const override;
std::shared_ptr<jit_emu_vcvtneps2bf16> get_emu_vcvtneps2bf16() const {
return emu_vcvtneps2bf16;
return emu_vcvtneps2bf16_;
}
private:
template <dnnl::impl::cpu::x64::cpu_isa_t isa>
void emit_isa(const int in_vec_idx, InferenceEngine::Precision src_prc,
const Xbyak::Reg64 &reg_dst, int offset_byte, InferenceEngine::Precision dst_prc, int store_num) const;
void emit_isa(const int in_vec_idx, const Xbyak::Reg64 &reg_dst, const int offset) const;
template <typename Vmm>
void store_bytes(const Vmm &vmm, const Xbyak::Reg64 &reg, int offset, int store_size) const;
@ -148,9 +146,13 @@ private:
size_t aux_gprs_count() const override;
size_t aux_vecs_count() const override;
std::string name;
int v_len_elt; // 4/8/16
std::shared_ptr<jit_emu_vcvtneps2bf16> emu_vcvtneps2bf16;
std::string name_;
int v_len_elt_; // 4/8/16
int store_num_;
int store_size_;
Precision src_prc_;
Precision dst_prc_;
std::shared_ptr<jit_emu_vcvtneps2bf16> emu_vcvtneps2bf16_;
};
} // namespace intel_cpu

View File

@ -58,9 +58,6 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi
}
void generate() override {
load_emitter.reset(new jit_load_emitter(this, isa));
store_emitter.reset(new jit_store_emitter(this, isa));
// dummy second reg_tmp_64, as no fill is needed
load_pool_gpr_idxs = {static_cast<size_t>(reg_tmp_64.getIdx()), static_cast<size_t>(reg_tmp_64.getIdx())};
store_pool_gpr_idxs = {static_cast<size_t>(reg_tmp_64.getIdx())};
@ -162,8 +159,7 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi
this->postamble();
load_emitter->emit_data();
store_emitter->emit_data();
emit_emitters_data();
for (auto& inj : eltwise_injectors)
inj->prepare_table();
if ((jcp_.mode == InterpolateMode::cubic) && (jcp_.layout == InterpolateLayoutType::planar)) {
@ -176,6 +172,9 @@ private:
Xbyak::Ymm, Xbyak::Zmm>::type;
const int vlen = cpu_isa_traits<isa>::vlen;
const int vector_step = vlen / sizeof(float);
const int tail_step = jcp_.C % vector_step;
const int scalar_step = 1;
Xbyak::Reg64 reg_src = r8;
Xbyak::Reg64 reg_src_aux = r15;
@ -246,8 +245,8 @@ private:
Xbyak::Label l_table_constant;
Opmask k_mask = Xbyak::Opmask(1);
std::unique_ptr<jit_load_emitter> load_emitter = nullptr;
std::unique_ptr<jit_store_emitter> store_emitter = nullptr;
std::unordered_map<size_t, std::unique_ptr<jit_emitter>> emitters;
std::vector<size_t> store_pool_gpr_idxs;
std::vector<size_t> store_pool_vec_idxs;
std::vector<size_t> load_pool_gpr_idxs;
@ -256,20 +255,44 @@ private:
std::vector<std::shared_ptr<jit_uni_depthwise_injector_f32<isa>>> depthwise_injectors;
std::vector<std::shared_ptr<jit_uni_quantization_injector_f32<isa>>> quantization_injectors;
inline void load(const Xbyak::Reg64& reg_src, Vmm& vmm, const int& elt_num, const int& offset = 0) {
load_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm.getIdx())},
std::make_shared<load_emitter_context>(jcp_.src_prc, Precision::FP32, elt_num, offset),
{}, {load_pool_gpr_idxs});
void emit_emitters_data() {
for (const auto& emitter : emitters) {
if (emitter.second)
emitter.second->emit_data();
}
inline void store(const Vmm& vmm, const Xbyak::Reg64& reg_dst, const int& elt_num, const int& offset = 0) {
store_emitter->emit_code({static_cast<size_t>(vmm.getIdx())}, {static_cast<size_t>(reg_dst.getIdx())},
std::make_shared<store_emitter_context>(Precision::FP32, jcp_.dst_prc, elt_num, offset),
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
}
inline void load_weights(const Xbyak::Reg64& reg_weights, Vmm& vmm, const int& elt_num, const int& offset = 0) {
load_emitter->emit_code({static_cast<size_t>(reg_weights.getIdx())}, {static_cast<size_t>(vmm.getIdx())},
std::make_shared<load_emitter_context>(Precision::FP32, Precision::FP32, elt_num, offset),
{}, {load_pool_gpr_idxs});
inline void load(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num, const int offset = 0) {
emit_load(reg_src, vmm_src, jcp_.src_prc, Precision::FP32, elt_num, offset);
}
inline void load_weights(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num, const int offset = 0) {
emit_load(reg_src, vmm_src, Precision::FP32, Precision::FP32, elt_num, offset);
}
inline void emit_load(Xbyak::Reg64 reg_src, Vmm vmm_src, Precision src_prc, Precision dst_prc, const int elt_num, const int offset = 0) {
const auto seed = load_emitter_params(src_prc, dst_prc, elt_num).hash();
if (!emitters[seed]) {
emitters[seed].reset(new jit_load_emitter(this, isa, src_prc, dst_prc, elt_num));
}
emitters[seed]->emit_code({static_cast<size_t>(reg_src.getIdx()), static_cast<size_t>(offset)},
{static_cast<size_t>(vmm_src.getIdx())}, {}, {load_pool_gpr_idxs});
}
inline void store(Vmm vmm_dst, Xbyak::Reg64 reg_dst, const int elt_num, const int offset = 0) {
const auto seed = store_emitter_params(Precision::FP32, jcp_.dst_prc, elt_num).hash();
if (!emitters[seed]) {
emitters[seed].reset(new jit_store_emitter(this, isa, Precision::FP32, jcp_.dst_prc, elt_num));
}
// for cases when the store emitter needs 2 aux vmms, vmm_dst can be used as the second aux vmm
std::vector<size_t> local_store_pool_vec_idxs = { static_cast<size_t>(vmm_dst.getIdx()) };
local_store_pool_vec_idxs.insert(local_store_pool_vec_idxs.begin(), store_pool_vec_idxs.begin(), store_pool_vec_idxs.end());
emitters[seed]->emit_code({static_cast<size_t>(vmm_dst.getIdx()), static_cast<size_t>(offset)},
{static_cast<size_t>(reg_dst.getIdx())},
{local_store_pool_vec_idxs}, {store_pool_gpr_idxs});
}
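The shape of the change at a call site, with both forms taken from this diff: the per-call heap-allocated context disappears, and the byte offset rides in in_idxs instead.

// Before this commit: a context was allocated on every emit_code call.
//   load_emitter->emit_code({reg_src.getIdx()}, {vmm.getIdx()},
//       std::make_shared<load_emitter_context>(jcp_.src_prc, Precision::FP32, elt_num, offset),
//       {}, {load_pool_gpr_idxs});
// After: parameters are baked into a cached emitter; the offset is in_idxs[1].
//   emitters[seed]->emit_code({reg_src.getIdx(), offset}, {vmm.getIdx()},
//       {}, {load_pool_gpr_idxs});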
void nn_planar() {
@ -303,7 +326,6 @@ private:
// reset index_w; index_w * dataSize is done at build time to avoid redundant computation
mov(reg_index, reg_index_w);
int step = vlen / sizeof(float);
Xbyak::Label nn_loop_label;
Xbyak::Label nn_loop_end_label;
@ -312,7 +334,7 @@ private:
L(nn_loop_label); // inner loop
{
cmp(reg_work_amount, step);
cmp(reg_work_amount, vector_step);
jl(nn_loop_end_label, T_NEAR);
uni_vmovdqu(vmm_index, ptr[reg_index]);
@ -320,17 +342,16 @@ private:
vgatherdps(vmm_val, ptr[reg_src_h + vmm_index], vmm_mask);
if (attr_.post_ops_.len() != 0)
apply_post_ops(jcp_.dst_prc, 1);
store(vmm_val, reg_dst, step);
store(vmm_val, reg_dst, vector_step);
add(reg_dst, step * jcp_.dst_data_size);
add(reg_index, step * jcp_.indices_size);
sub(reg_work_amount, step);
add(reg_dst, vector_step * jcp_.dst_data_size);
add(reg_index, vector_step * jcp_.indices_size);
sub(reg_work_amount, vector_step);
jmp(nn_loop_label, T_NEAR);
}
L(nn_loop_end_label);
step = 1;
L(nn_tail_loop_label);
{
cmp(reg_work_amount, 1);
@ -340,14 +361,14 @@ private:
mov(reg_index_offset, dword[reg_index]);
add(reg_src_aux, reg_index_offset);
load(reg_src_aux, vmm_val, step);
load(reg_src_aux, vmm_val, scalar_step);
if (attr_.post_ops_.len() != 0)
apply_post_ops(jcp_.dst_prc, 1);
store(vmm_val, reg_dst, step);
store(vmm_val, reg_dst, scalar_step);
add(reg_dst, step * jcp_.dst_data_size);
add(reg_index, step * jcp_.indices_size);
sub(reg_work_amount, step);
add(reg_dst, scalar_step * jcp_.dst_data_size);
add(reg_index, scalar_step * jcp_.indices_size);
sub(reg_work_amount, scalar_step);
jmp(nn_tail_loop_label, T_NEAR);
}
@ -363,8 +384,6 @@ private:
}
void nn_blk() {
int step = vlen / sizeof(float);
Xbyak::Label nn_loop_label;
Xbyak::Label nn_loop_end_label;
L(nn_loop_label);
@ -376,22 +395,22 @@ private:
mov(reg_index_offset, dword[reg_index]);
add(reg_src_aux, reg_index_offset);
load(reg_src_aux, vmm_val, step);
load(reg_src_aux, vmm_val, vector_step);
if (attr_.post_ops_.len() != 0)
apply_post_ops(jcp_.dst_prc, 0);
store(vmm_val, reg_dst, step);
add(reg_dst, step * jcp_.dst_data_size);
store(vmm_val, reg_dst, vector_step);
add(reg_dst, vector_step * jcp_.dst_data_size);
if (isa == cpu::x64::sse41) {
add(reg_src_aux, step * jcp_.src_data_size);
load(reg_src_aux, vmm_val, step);
add(reg_src_aux, vector_step * jcp_.src_data_size);
load(reg_src_aux, vmm_val, vector_step);
if (attr_.post_ops_.len() != 0) {
add(reg_oc_off, step * sizeof(float));
add(reg_oc_off, vector_step * sizeof(float));
apply_post_ops(jcp_.dst_prc, 0);
sub(reg_oc_off, step * sizeof(float));
sub(reg_oc_off, vector_step * sizeof(float));
}
store(vmm_val, reg_dst, step);
add(reg_dst, step * jcp_.dst_data_size);
store(vmm_val, reg_dst, vector_step);
add(reg_dst, vector_step * jcp_.dst_data_size);
}
add(reg_index, jcp_.indices_size);
@ -421,8 +440,6 @@ private:
cmp(reg_work_amount_out, 1);
jl(out_loop_end, T_NEAR);
int step = vlen / sizeof(float);
//inner loop for C
Xbyak::Label nn_loop_label;
Xbyak::Label nn_loop_end_label;
@ -444,35 +461,34 @@ private:
L(nn_loop_label);
{
cmp(reg_work_amount, step);
cmp(reg_work_amount, vector_step);
jl(nn_loop_end_label, T_NEAR);
load(reg_src_aux, vmm_val, step);
load(reg_src_aux, vmm_val, vector_step);
if (attr_.post_ops_.len() != 0)
apply_post_ops(jcp_.dst_prc, 0);
store(vmm_val, reg_dst, step);
store(vmm_val, reg_dst, vector_step);
add(reg_dst, step * jcp_.dst_data_size);
add(reg_src_aux, step * jcp_.src_data_size);
add(reg_oc_off, step * sizeof(float));
sub(reg_work_amount, step);
add(reg_dst, vector_step * jcp_.dst_data_size);
add(reg_src_aux, vector_step * jcp_.src_data_size);
add(reg_oc_off, vector_step * sizeof(float));
sub(reg_work_amount, vector_step);
jmp(nn_loop_label, T_NEAR);
}
L(nn_loop_end_label);
int tail_num = jcp_.C % step;
if (tail_num != 0) {
load(reg_src_aux, vmm_val, tail_num);
if (tail_step != 0) {
load(reg_src_aux, vmm_val, tail_step);
if (attr_.post_ops_.len() != 0)
apply_post_ops(jcp_.dst_prc, 0);
store(vmm_val, reg_dst, tail_num);
store(vmm_val, reg_dst, tail_step);
// check to remove below
add(reg_dst, tail_num * jcp_.dst_data_size);
add(reg_src_aux, tail_num * jcp_.src_data_size);
add(reg_oc_off, tail_num * sizeof(float));
sub(reg_work_amount, tail_num);
add(reg_dst, tail_step * jcp_.dst_data_size);
add(reg_src_aux, tail_step * jcp_.src_data_size);
add(reg_oc_off, tail_step * sizeof(float));
sub(reg_work_amount, tail_step);
}
add(reg_index, jcp_.indices_size);
sub(reg_work_amount_out, 1);
@ -519,11 +535,10 @@ private:
}
mov(reg_work_amount, ptr[reg_params + GET_OFF(work_amount)]);
int step = vlen / sizeof(float);
int blk = (isa == cpu::x64::sse41) ? (2 * step) : step;
int dst_stride = (jcp_.layout == InterpolateLayoutType::by_channel) ? (step * jcp_.dst_data_size) :
int blk = (isa == cpu::x64::sse41) ? (2 * vector_step) : vector_step;
int dst_stride = (jcp_.layout == InterpolateLayoutType::by_channel) ? (vector_step * jcp_.dst_data_size) :
(blk * jcp_.OW * jcp_.OH * jcp_.OD * jcp_.dst_data_size);
int src_stride = (jcp_.layout == InterpolateLayoutType::by_channel) ? (step * jcp_.src_data_size) :
int src_stride = (jcp_.layout == InterpolateLayoutType::by_channel) ? (vector_step * jcp_.src_data_size) :
(blk * jcp_.IW * jcp_.IH * jcp_.ID * jcp_.src_data_size);
Xbyak::Label main_loop_label;
@ -535,29 +550,29 @@ private:
L(main_loop_label);
{
if (jcp_.layout == InterpolateLayoutType::by_channel) {
cmp(reg_work_amount, step);
cmp(reg_work_amount, vector_step);
jl(main_loop_end_label, T_NEAR);
} else {
cmp(reg_work_amount, 1);
jl(main_loop_end_label, T_NEAR);
}
// progressive manner
load(reg_src, vmm_valTL, step);
load(reg_src_aux, vmm_valTR, step);
load(reg_src, vmm_valTL, vector_step);
load(reg_src_aux, vmm_valTR, vector_step);
if (jcp_.spatial_dim_size == 1) {
linear_onnx_worker_1d();
}
if (jcp_.spatial_dim_size > 1) {
load(reg_src_aux1, vmm_valBL, step);
load(reg_src_aux2, vmm_valBR, step);
load(reg_src_aux1, vmm_valBL, vector_step);
load(reg_src_aux2, vmm_valBR, vector_step);
linear_onnx_worker_2d();
}
if (jcp_.spatial_dim_size > 2) {
uni_vmovups(vmm_d_bias, vmm_valTR); // temporarily save front result to temp_vmm
load(reg_src_aux4, vmm_valTL, step);
load(reg_src_aux5, vmm_valTR, step);
load(reg_src_aux6, vmm_valBL, step);
load(reg_src_aux7, vmm_valBR, step);
load(reg_src_aux4, vmm_valTL, vector_step);
load(reg_src_aux5, vmm_valTR, vector_step);
load(reg_src_aux6, vmm_valBL, vector_step);
load(reg_src_aux7, vmm_valBR, vector_step);
// 2d for end depth
linear_onnx_worker_2d();
@ -568,28 +583,28 @@ private:
if (attr_.post_ops_.len() != 0) {
apply_post_ops(jcp_.dst_prc, false); // vmm_val is vmm_valTR
add(reg_oc_off, step * sizeof(float));
add(reg_oc_off, vector_step * sizeof(float));
}
store(vmm_valTR, reg_dst, step);
store(vmm_valTR, reg_dst, vector_step);
if ((isa == cpu::x64::sse41) && (jcp_.layout == InterpolateLayoutType::block)) {
int offset_src = step * jcp_.src_data_size;
load(reg_src, vmm_valTL, step, offset_src);
load(reg_src_aux, vmm_valTR, step, offset_src);
int offset_src = vector_step * jcp_.src_data_size;
load(reg_src, vmm_valTL, vector_step, offset_src);
load(reg_src_aux, vmm_valTR, vector_step, offset_src);
if (jcp_.spatial_dim_size == 1) {
linear_onnx_worker_1d();
}
if (jcp_.spatial_dim_size > 1) {
load(reg_src_aux1, vmm_valBL, step, offset_src);
load(reg_src_aux2, vmm_valBR, step, offset_src);
load(reg_src_aux1, vmm_valBL, vector_step, offset_src);
load(reg_src_aux2, vmm_valBR, vector_step, offset_src);
linear_onnx_worker_2d();
}
if (jcp_.spatial_dim_size > 2) {
uni_vmovups(vmm_d_bias, vmm_valTR); // temporarily save front result to temp_vmm
load(reg_src_aux4, vmm_valTL, step, offset_src);
load(reg_src_aux5, vmm_valTR, step, offset_src);
load(reg_src_aux6, vmm_valBL, step, offset_src);
load(reg_src_aux7, vmm_valBR, step, offset_src);
load(reg_src_aux4, vmm_valTL, vector_step, offset_src);
load(reg_src_aux5, vmm_valTR, vector_step, offset_src);
load(reg_src_aux6, vmm_valBL, vector_step, offset_src);
load(reg_src_aux7, vmm_valBR, vector_step, offset_src);
// 2d for end depth
linear_onnx_worker_2d();
// 3rd dimension
@ -599,10 +614,10 @@ private:
if (attr_.post_ops_.len() != 0) {
apply_post_ops(jcp_.dst_prc, false);
add(reg_oc_off, step * sizeof(float));
add(reg_oc_off, vector_step * sizeof(float));
}
int offset_dst = step * jcp_.dst_data_size;
store(vmm_valTR, reg_dst, step, offset_dst);
int offset_dst = vector_step * jcp_.dst_data_size;
store(vmm_valTR, reg_dst, vector_step, offset_dst);
}
add(reg_dst, dst_stride);
add(reg_src, src_stride);
@ -618,7 +633,7 @@ private:
add(reg_src_aux7, src_stride);
}
if (jcp_.layout == InterpolateLayoutType::by_channel) {
sub(reg_work_amount, step); // work_amount is c
sub(reg_work_amount, vector_step); // work_amount is c
} else {
sub(reg_work_amount, 1); // work_amount = div_up(c, blk), no tails
}
@ -627,25 +642,24 @@ private:
}
L(main_loop_end_label);
int tail_num = jcp_.C % step;
if ((jcp_.layout == InterpolateLayoutType::by_channel) && (tail_num != 0)) {
load(reg_src, vmm_valTL, tail_num);
load(reg_src_aux, vmm_valTR, tail_num);
if ((jcp_.layout == InterpolateLayoutType::by_channel) && (tail_step != 0)) {
load(reg_src, vmm_valTL, tail_step);
load(reg_src_aux, vmm_valTR, tail_step);
if (jcp_.spatial_dim_size == 1) {
linear_onnx_worker_1d();
}
if (jcp_.spatial_dim_size > 1) {
load(reg_src_aux1, vmm_valBL, tail_num);
load(reg_src_aux2, vmm_valBR, tail_num);
load(reg_src_aux1, vmm_valBL, tail_step);
load(reg_src_aux2, vmm_valBR, tail_step);
linear_onnx_worker_2d();
}
if (jcp_.spatial_dim_size > 2) {
uni_vmovups(vmm_d_bias, vmm_valTR); // temporarily save front result to temp_vmm
load(reg_src_aux4, vmm_valTL, tail_num);
load(reg_src_aux5, vmm_valTR, tail_num);
load(reg_src_aux6, vmm_valBL, tail_num);
load(reg_src_aux7, vmm_valBR, tail_num);
load(reg_src_aux4, vmm_valTL, tail_step);
load(reg_src_aux5, vmm_valTR, tail_step);
load(reg_src_aux6, vmm_valBL, tail_step);
load(reg_src_aux7, vmm_valBR, tail_step);
// 2d for end depth
linear_onnx_worker_2d();
// 3rd dimension
@ -655,10 +669,10 @@ private:
if (attr_.post_ops_.len() != 0) {
apply_post_ops(jcp_.dst_prc, false); // vmm_val is vmm_valTR
add(reg_oc_off, tail_num * sizeof(float));
add(reg_oc_off, tail_step * sizeof(float));
}
store(vmm_valTR, reg_dst, tail_num);
store(vmm_valTR, reg_dst, tail_step);
}
}
@ -669,7 +683,6 @@ private:
mov(reg_src_aux, ptr[reg_params + GET_OFF(weight_ptr[0])]);
mov(reg_work_amount, ptr[reg_params + GET_OFF(work_amount)]);
int step = vlen / sizeof(float);
int index_stride = jcp_.OW * jcp_.OH * jcp_.OD * jcp_.indices_size;
int weight_stride = jcp_.OW * jcp_.OH * jcp_.OD * sizeof(float);
@ -679,7 +692,7 @@ private:
Xbyak::Label tail_loop_end_label;
L(main_loop_label);
{
cmp(reg_work_amount, step);
cmp(reg_work_amount, vector_step);
jl(main_loop_end_label, T_NEAR);
uni_vmovdqu(vmm_index, ptr[reg_index]);
@ -690,8 +703,8 @@ private:
uni_vpcmpeqd(vmm_mask, vmm_mask, vmm_mask);
vgatherdps(vmm_valTR, ptr[reg_src + vmm_index], vmm_mask);
load_weights(reg_src_aux, vmm_weightL, step);
load_weights(reg_src_aux, vmm_weightR, step, weight_stride);
load_weights(reg_src_aux, vmm_weightL, vector_step);
load_weights(reg_src_aux, vmm_weightR, vector_step, weight_stride);
// progressive manner
if (jcp_.spatial_dim_size == 1) {
@ -706,8 +719,8 @@ private:
uni_vpcmpeqd(vmm_mask, vmm_mask, vmm_mask);
vgatherdps(vmm_valBR, ptr[reg_src + vmm_index], vmm_mask);
load_weights(reg_src_aux, vmm_weightT, step, 2 * weight_stride);
load_weights(reg_src_aux, vmm_weightB, step, 3 * weight_stride);
load_weights(reg_src_aux, vmm_weightT, vector_step, 2 * weight_stride);
load_weights(reg_src_aux, vmm_weightB, vector_step, 3 * weight_stride);
linear_onnx_worker_2d();
}
@ -733,8 +746,8 @@ private:
linear_onnx_worker_2d();
load_weights(reg_src_aux, vmm_weightE, step, 5 * weight_stride);
load_weights(reg_src_aux, vmm_weightF, step, 4 * weight_stride);
load_weights(reg_src_aux, vmm_weightE, vector_step, 5 * weight_stride);
load_weights(reg_src_aux, vmm_weightF, vector_step, 4 * weight_stride);
uni_vmulps(vmm_valTR, vmm_valTR, vmm_weightE); // end_value * end_weight
uni_vfmadd231ps(vmm_valTR, vmm_d_bias, vmm_weightF); // start_value * start_weight + end_value * end_weight
@ -743,18 +756,17 @@ private:
if (attr_.post_ops_.len() != 0) {
apply_post_ops(jcp_.dst_prc, true); // vmm_val is vmm_valTR, broadcast is true
}
store(vmm_valTR, reg_dst, step);
store(vmm_valTR, reg_dst, vector_step);
add(reg_dst, step * jcp_.dst_data_size);
add(reg_src_aux, step * sizeof(float));
add(reg_index, step * jcp_.indices_size);
sub(reg_work_amount, step);
add(reg_dst, vector_step * jcp_.dst_data_size);
add(reg_src_aux, vector_step * sizeof(float));
add(reg_index, vector_step * jcp_.indices_size);
sub(reg_work_amount, vector_step);
jmp(main_loop_label, T_NEAR);
}
L(main_loop_end_label);
step = 1;
L(tail_loop_label);
{
cmp(reg_work_amount, 1);
@ -763,15 +775,15 @@ private:
mov(reg_src_aux1, reg_src);
mov(reg_index_offset, dword[reg_index]);
add(reg_src_aux1, reg_index_offset);
load(reg_src_aux1, vmm_valTL, step);
load(reg_src_aux1, vmm_valTL, scalar_step);
mov(reg_src_aux1, reg_src);
mov(reg_index_offset, dword[reg_index + index_stride]);
add(reg_src_aux1, reg_index_offset);
load(reg_src_aux1, vmm_valTR, step);
load(reg_src_aux1, vmm_valTR, scalar_step);
load_weights(reg_src_aux, vmm_weightL, step, 0);
load_weights(reg_src_aux, vmm_weightR, step, weight_stride);
load_weights(reg_src_aux, vmm_weightL, scalar_step, 0);
load_weights(reg_src_aux, vmm_weightR, scalar_step, weight_stride);
if (jcp_.spatial_dim_size == 1) {
linear_onnx_worker_1d();
@ -780,15 +792,15 @@ private:
mov(reg_src_aux1, reg_src);
mov(reg_index_offset, dword[reg_index + 2 * index_stride]);
add(reg_src_aux1, reg_index_offset);
load(reg_src_aux1, vmm_valBL, step);
load(reg_src_aux1, vmm_valBL, scalar_step);
mov(reg_src_aux1, reg_src);
mov(reg_index_offset, dword[reg_index + 3 * index_stride]);
add(reg_src_aux1, reg_index_offset);
load(reg_src_aux1, vmm_valBR, step);
load(reg_src_aux1, vmm_valBR, scalar_step);
load_weights(reg_src_aux, vmm_weightT, step, 2 * weight_stride);
load_weights(reg_src_aux, vmm_weightB, step, 3 * weight_stride);
load_weights(reg_src_aux, vmm_weightT, scalar_step, 2 * weight_stride);
load_weights(reg_src_aux, vmm_weightB, scalar_step, 3 * weight_stride);
linear_onnx_worker_2d();
}
@ -799,27 +811,27 @@ private:
mov(reg_src_aux1, reg_src);
mov(reg_index_offset, dword[reg_index + 4 * index_stride]);
add(reg_src_aux1, reg_index_offset);
load(reg_src_aux1, vmm_valTL, step);
load(reg_src_aux1, vmm_valTL, scalar_step);
mov(reg_src_aux1, reg_src);
mov(reg_index_offset, dword[reg_index + 5 * index_stride]);
add(reg_src_aux1, reg_index_offset);
load(reg_src_aux1, vmm_valTR, step);
load(reg_src_aux1, vmm_valTR, scalar_step);
mov(reg_src_aux1, reg_src);
mov(reg_index_offset, dword[reg_index + 6 * index_stride]);
add(reg_src_aux1, reg_index_offset);
load(reg_src_aux1, vmm_valBL, step);
load(reg_src_aux1, vmm_valBL, scalar_step);
mov(reg_src_aux1, reg_src);
mov(reg_index_offset, dword[reg_index + 7 * index_stride]);
add(reg_src_aux1, reg_index_offset);
load(reg_src_aux1, vmm_valBR, step);
load(reg_src_aux1, vmm_valBR, scalar_step);
linear_onnx_worker_2d();
load_weights(reg_src_aux, vmm_weightE, step, 5 * weight_stride);
load_weights(reg_src_aux, vmm_weightF, step, 4 * weight_stride);
load_weights(reg_src_aux, vmm_weightE, scalar_step, 5 * weight_stride);
load_weights(reg_src_aux, vmm_weightF, scalar_step, 4 * weight_stride);
uni_vmulps(vmm_valTR, vmm_valTR, vmm_weightE); // end_value * end_weight
uni_vfmadd231ps(vmm_valTR, vmm_d_bias, vmm_weightF); // start_value * start_weight + end_value * end_weight
@ -828,12 +840,12 @@ private:
if (attr_.post_ops_.len() != 0) {
apply_post_ops(jcp_.dst_prc, true); // process on vmm_val, vmm_val is vmm_valTR, and broadcast is true
}
store(vmm_valTR, reg_dst, step);
store(vmm_valTR, reg_dst, scalar_step);
add(reg_dst, step * jcp_.dst_data_size);
add(reg_src_aux, step * sizeof(float));
add(reg_index, step * jcp_.indices_size);
sub(reg_work_amount, step);
add(reg_dst, scalar_step * jcp_.dst_data_size);
add(reg_src_aux, scalar_step * sizeof(float));
add(reg_index, scalar_step * jcp_.indices_size);
sub(reg_work_amount, scalar_step);
jmp(tail_loop_label, T_NEAR);
}
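// A minimal host-side analogue (hypothetical sketch, not part of this patch)
// of the loop shape the kernel emits above: a main loop that advances by
// vector_step and a tail loop that advances by scalar_step, each step fixed
// at compile time so the matching pre-built emitter can be chosen.
void process_sketch(const float* src, float* dst, unsigned work_amount) {
    constexpr unsigned vector_step = 8;  // e.g. one AVX2 register of floats
    constexpr unsigned scalar_step = 1;
    unsigned i = 0;
    for (; i + vector_step <= work_amount; i += vector_step)  // main_loop
        for (unsigned j = 0; j < vector_step; ++j)
            dst[i + j] = src[i + j];  // stands in for the vector-step body
    for (; i < work_amount; i += scalar_step)                 // tail_loop
        dst[i] = src[i];              // stands in for the scalar-step body
}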
@ -876,8 +888,7 @@ private:
uni_vbroadcastss(vmm_weightY2, ptr[reg_src_aux1 + 2 * sizeof(float)]);
uni_vbroadcastss(vmm_weightY3, ptr[reg_src_aux1 + 3 * sizeof(float)]);
int step = vlen / sizeof(float);
int blk = (isa == cpu::x64::sse41) ? (2 * step) : step;
int blk = (isa == cpu::x64::sse41) ? (2 * vector_step) : vector_step;
Xbyak::Label main_loop_label;
Xbyak::Label main_loop_end_label;
@ -886,7 +897,7 @@ private:
L(main_loop_label);
{
if (jcp_.layout == InterpolateLayoutType::by_channel) {
cmp(reg_work_amount, step);
cmp(reg_work_amount, vector_step);
jl(main_loop_end_label, T_NEAR);
} else {
cmp(reg_work_amount, 1);
@ -899,14 +910,14 @@ private:
if (attr_.post_ops_.len() != 0) {
apply_post_ops(jcp_.dst_prc, false); // vmm_val is the default dst value for post_ops and store
add(reg_oc_off, step * sizeof(float));
add(reg_oc_off, vector_step * sizeof(float));
}
store(vmm_val, reg_dst, step);
store(vmm_val, reg_dst, vector_step);
if ((isa == cpu::x64::sse41) && (jcp_.layout == InterpolateLayoutType::block)) {
// vmm is xmm here
add(reg_src, step * jcp_.src_data_size);
add(reg_dst, step * jcp_.dst_data_size);
add(reg_src, vector_step * jcp_.src_data_size);
add(reg_dst, vector_step * jcp_.dst_data_size);
uni_vpxor(vmm_val, vmm_val, vmm_val);
@ -914,19 +925,19 @@ private:
if (attr_.post_ops_.len() != 0) {
apply_post_ops(jcp_.dst_prc, false);
add(reg_oc_off, step * sizeof(float)); // second step for one blk
add(reg_oc_off, vector_step * sizeof(float)); // second vector_step for one blk
}
store(vmm_val, reg_dst, step);
store(vmm_val, reg_dst, vector_step);
sub(reg_src, step * jcp_.src_data_size);
sub(reg_dst, step * jcp_.dst_data_size);
sub(reg_src, vector_step * jcp_.src_data_size);
sub(reg_dst, vector_step * jcp_.dst_data_size);
}
if (jcp_.layout == InterpolateLayoutType::by_channel) {
int dst_stride = step * jcp_.dst_data_size;
int src_stride = step * jcp_.src_data_size;
int dst_stride = vector_step * jcp_.dst_data_size;
int src_stride = vector_step * jcp_.src_data_size;
add(reg_dst, dst_stride);
add(reg_src, src_stride);
sub(reg_work_amount, step); // work_amount is c
sub(reg_work_amount, vector_step); // work_amount is c
} else {
int dst_stride = blk * jcp_.OW * jcp_.OH * jcp_.dst_data_size;
int src_stride = blk * jcp_.IW * jcp_.IH * jcp_.src_data_size;
@ -940,7 +951,6 @@ private:
L(main_loop_end_label);
// tail handling: only for by_channel layout.
step = 1;
L(tail_loop_label);
{
cmp(reg_work_amount, 1);
@ -953,15 +963,15 @@ private:
if (attr_.post_ops_.len() != 0) {
apply_post_ops(jcp_.dst_prc, false); // vmm_val is default dst value
add(reg_oc_off, step * sizeof(float));
add(reg_oc_off, scalar_step * sizeof(float));
}
store(vmm_val, reg_dst, step);
store(vmm_val, reg_dst, scalar_step);
int dst_stride = step * jcp_.dst_data_size;
int src_stride = step * jcp_.src_data_size;
int dst_stride = scalar_step * jcp_.dst_data_size;
int src_stride = scalar_step * jcp_.src_data_size;
add(reg_dst, dst_stride);
add(reg_src, src_stride);
sub(reg_work_amount, step); // work_amount is c
sub(reg_work_amount, scalar_step); // work_amount is c
jmp(tail_loop_label, T_NEAR);
}
@ -1020,7 +1030,6 @@ private:
mov(reg_weight_y, ptr[reg_params + GET_OFF(weight_ptr[0]) + sizeof(size_t)]);
mov(reg_work_amount, ptr[reg_params + GET_OFF(work_amount)]);
int step = vlen / sizeof(float);
int grid_len = 4;
// 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
@ -1035,7 +1044,7 @@ private:
Xbyak::Label tail_loop_end_label;
L(main_loop_label);
{
cmp(reg_work_amount, step);
cmp(reg_work_amount, vector_step);
jl(main_loop_end_label, T_NEAR);
// vmm_tbl_y: (0 0 0 0 1 1 1 1 * index_size) --> (0 0 0 0 4 4 4 4)
@ -1111,19 +1120,18 @@ private:
if (attr_.post_ops_.len() != 0) {
apply_post_ops(jcp_.dst_prc, true); // oc_off is broadcast and always the same value for this channel
}
store(vmm_val, reg_dst, step);
store(vmm_val, reg_dst, vector_step);
add(reg_tbl_y, step * sizeof(int)); // sizeof(int): sequence by dd()
add(reg_tbl_x, step * sizeof(int));
add(reg_dst, step * jcp_.dst_data_size);
add(reg_tbl_y, vector_step * sizeof(int)); // sizeof(int): sequence by dd()
add(reg_tbl_x, vector_step * sizeof(int));
add(reg_dst, vector_step * jcp_.dst_data_size);
sub(reg_work_amount, step);
sub(reg_work_amount, vector_step);
jmp(main_loop_label, T_NEAR);
}
L(main_loop_end_label);
step = 1;
L(tail_loop_label);
{
cmp(reg_work_amount, 1);
@ -1182,13 +1190,13 @@ private:
if (attr_.post_ops_.len() != 0) {
apply_post_ops(jcp_.dst_prc, true); // oc_off is broadcast and always the same value for this channel
}
store(vmm_val, reg_dst, step);
store(vmm_val, reg_dst, scalar_step);
add(reg_tbl_y, step * sizeof(int)); // sizeof(int): sequence with dd()
add(reg_tbl_x, step * sizeof(int));
add(reg_dst, step * jcp_.dst_data_size);
add(reg_tbl_y, scalar_step * sizeof(int)); // sizeof(int): sequence with dd()
add(reg_tbl_x, scalar_step * sizeof(int));
add(reg_dst, scalar_step * jcp_.dst_data_size);
sub(reg_work_amount, step);
sub(reg_work_amount, scalar_step);
jmp(tail_loop_label, T_NEAR);
}
@ -1264,7 +1272,7 @@ private:
return ptr[reg_table + index * vlen];
}
// always gather to Vmm, compute with Vmm, store with Xmm if scalar
// always gather to Vmm, compute with Vmm, store with Xmm if scalar_step
inline void gather_i32_indices(Vmm vmm_src, const Xbyak::Reg64 &base, int offset, Vmm vmm_indices, int scale,
Precision src_prc, bool is_scalar) {
Xbyak::Address table_idx = ptr[base + offset + vmm_indices * scale];

View File

@ -110,7 +110,13 @@ struct jit_uni_mvn_mean_variance_kernel_f32 : public jit_uni_mvn_mean_variance_k
}
void generate() override {
load_emitter.reset(new jit_load_emitter(this, isa));
tail_step = jcp_.planar_layout ? (jcp_.D * jcp_.H * jcp_.W) - ((jcp_.D * jcp_.H * jcp_.W) / vector_step) * vector_step :
jcp_.C - (jcp_.C / vector_step) * vector_step;
Precision dst_prc = isFloatCompatible(jcp_.src_prc) ? Precision::FP32 : Precision::I32;
load_vector_emitter.reset(new jit_load_emitter(this, isa, jcp_.src_prc, dst_prc, vector_step));
load_tail_emitter.reset(new jit_load_emitter(this, isa, jcp_.src_prc, dst_prc, tail_step));
load_tail_with_fill_emitter.reset(new jit_load_emitter(this, isa, jcp_.src_prc, dst_prc, tail_step, Precision::FP32, true));
this->preamble();
mov(reg_src, ptr[reg_params + GET_OFF(src)]);
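// The tail_step computed in generate() above is simply the remainder of the
// work amount modulo the vector width; a worked example (tail_len is a
// hypothetical helper, not part of this patch):
constexpr unsigned tail_len(unsigned total, unsigned vector_step) {
    return total - (total / vector_step) * vector_step;  // == total % vector_step
}
static_assert(tail_len(10, 8) == 2, "e.g. 10 floats at AVX2 width leave a 2-element tail");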
@ -134,14 +140,11 @@ struct jit_uni_mvn_mean_variance_kernel_f32 : public jit_uni_mvn_mean_variance_k
}
}
tail_num = jcp_.planar_layout ? (jcp_.D * jcp_.H * jcp_.W) - ((jcp_.D * jcp_.H * jcp_.W) / step) * step :
jcp_.C - (jcp_.C / step) * step;
load_pool_gpr_idxs = {static_cast<size_t>(reg_load_store_mask.getIdx()), static_cast<size_t>(reg_load_table.getIdx())};
if (jcp_.planar_layout) {
worker_unroll();
if (tail_num != 0) {
if (tail_step != 0) {
worker_tail_planar();
}
@ -198,7 +201,7 @@ struct jit_uni_mvn_mean_variance_kernel_f32 : public jit_uni_mvn_mean_variance_k
}
Xbyak::Label label_empty_2half_sse42;
if (tail_num == 0) {
if (tail_step == 0) {
cmp(reg_oc_off, static_cast<int>(jcp_.C * sizeof(float)));
jae(label_empty_2half_sse42, T_NEAR);
@ -210,7 +213,7 @@ struct jit_uni_mvn_mean_variance_kernel_f32 : public jit_uni_mvn_mean_variance_k
Xbyak::Label label_full_size;
Xbyak::Label label_size_end;
cmp(reg_oc_off, static_cast<int>((jcp_.C - step) * sizeof(float)));
cmp(reg_oc_off, static_cast<int>((jcp_.C - vector_step) * sizeof(float)));
jle(label_full_size, T_NEAR);
// no need to care about or fill the rest
@ -251,7 +254,9 @@ struct jit_uni_mvn_mean_variance_kernel_f32 : public jit_uni_mvn_mean_variance_k
this->postamble();
load_emitter->emit_data();
load_vector_emitter->emit_data();
load_tail_emitter->emit_data();
load_tail_with_fill_emitter->emit_data();
}
private:
@ -259,8 +264,8 @@ private:
Xbyak::Ymm, Xbyak::Zmm>::type;
const int vlen = cpu_isa_traits<isa>::vlen;
const int step = vlen / sizeof(float);
int tail_num = 0;
const int vector_step = vlen / sizeof(float);
int tail_step = 0;
Xbyak::Reg64 reg_src = r8;
Xbyak::Reg64 reg_mean = r9;
@ -286,14 +291,14 @@ private:
Xbyak::Opmask k_mask = Xbyak::Opmask(7);
std::unique_ptr<jit_load_emitter> load_emitter = nullptr;
std::unique_ptr<jit_load_emitter> load_vector_emitter = nullptr;
std::unique_ptr<jit_load_emitter> load_tail_emitter = nullptr;
std::unique_ptr<jit_load_emitter> load_tail_with_fill_emitter = nullptr;
std::vector<size_t> load_pool_gpr_idxs;
inline void worker_full_size() {
Precision dst_prc = isFloatCompatible(jcp_.src_prc) ? Precision::FP32 : Precision::I32;
load_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_val.getIdx())},
std::make_shared<load_emitter_context>(jcp_.src_prc, dst_prc, step),
load_vector_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_val.getIdx())},
{}, {load_pool_gpr_idxs});
if (jcp_.normalize_variance) {
@ -313,9 +318,7 @@ private:
}
inline void worker_tail_blk() {
Precision dst_prc = isFloatCompatible(jcp_.src_prc) ? Precision::FP32 : Precision::I32;
load_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_val.getIdx())},
std::make_shared<load_emitter_context>(jcp_.src_prc, dst_prc, tail_num),
load_tail_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_val.getIdx())},
{}, {load_pool_gpr_idxs});
if (jcp_.normalize_variance) {
@ -357,9 +360,7 @@ private:
}
inline void worker_tail_planar() {
Precision dst_prc = isFloatCompatible(jcp_.src_prc) ? Precision::FP32 : Precision::I32;
load_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_val.getIdx())},
std::make_shared<load_emitter_context>(jcp_.src_prc, dst_prc, tail_num, 0, true),
load_tail_with_fill_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_val.getIdx())},
{}, {load_pool_gpr_idxs});
if (jcp_.normalize_variance) {
@ -371,15 +372,15 @@ private:
uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
if (isa == cpu::x64::sse41) {
uint8 imm = 1;
imm = ~((imm << tail_num) - imm);
imm = ~((imm << tail_step) - imm);
blendps(vmm_val, vmm_zero, imm);
} else if (isa == cpu::x64::avx2) {
uint8 imm = 1;
imm = ~((imm << tail_num) - imm);
imm = ~((imm << tail_step) - imm);
vblendps(vmm_val, vmm_val, vmm_zero, imm);
} else if (isa == cpu::x64::avx512_core) {
uint64_t tail_mask = 1;
tail_mask = ~((tail_mask << tail_num) - tail_mask);
tail_mask = ~((tail_mask << tail_step) - tail_mask);
mov(reg_aux, tail_mask);
kmovq(k_mask, reg_aux);
vblendmps(vmm_val | k_mask, vmm_val, vmm_zero);
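// Worked example of the tail masks built above (tail_blend_mask is a
// hypothetical helper, not part of this patch): bits at positions >= tail_step
// are set, so those lanes blend to vmm_zero, while lanes below tail_step keep
// the loaded tail values.
constexpr unsigned char tail_blend_mask(int tail_step) {
    return static_cast<unsigned char>(~((1u << tail_step) - 1u));
}
static_assert(tail_blend_mask(3) == 0xF8, "lanes 0..2 kept, lanes 3..7 zeroed");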
@ -435,8 +436,13 @@ struct jit_uni_mvn_kernel_f32 : public jit_uni_mvn_kernel, public jit_generator
}
}
load_emitter.reset(new jit_load_emitter(this, isa));
store_emitter.reset(new jit_store_emitter(this, isa));
tail_step = jcp_.planar_layout ? (jcp_.D * jcp_.H * jcp_.W) - ((jcp_.D * jcp_.H * jcp_.W) / vector_step) * vector_step :
jcp_.C - (jcp_.C / vector_step) * vector_step;
load_vector_emitter.reset(new jit_load_emitter(this, isa, jcp_.src_prc, Precision::FP32, vector_step));
load_tail_emitter.reset(new jit_load_emitter(this, isa, jcp_.src_prc, Precision::FP32, tail_step));
store_vector_emitter.reset(new jit_store_emitter(this, isa, Precision::FP32, jcp_.dst_prc, vector_step));
store_tail_emitter.reset(new jit_store_emitter(this, isa, Precision::FP32, jcp_.dst_prc, tail_step));
this->preamble();
@ -463,16 +469,13 @@ struct jit_uni_mvn_kernel_f32 : public jit_uni_mvn_kernel, public jit_generator
uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
tail_num = jcp_.planar_layout ? (jcp_.D * jcp_.H * jcp_.W) - ((jcp_.D * jcp_.H * jcp_.W) / step) * step :
jcp_.C - (jcp_.C / step) * step;
load_pool_gpr_idxs = {static_cast<size_t>(reg_load_store_mask.getIdx()), static_cast<size_t>(reg_load_table.getIdx())};
store_pool_gpr_idxs = {static_cast<size_t>(reg_load_store_mask.getIdx())};
store_pool_vec_idxs = {static_cast<size_t>(vmm_zero.getIdx())};
store_pool_vec_idxs = {static_cast<size_t>(vmm_zero.getIdx()), static_cast<size_t>(vmm_val.getIdx())};
if (jcp_.planar_layout) {
worker_mvn_unroll();
if (tail_num != 0) {
if (tail_step != 0) {
worker_mvn(true);
}
} else {
@ -501,7 +504,7 @@ struct jit_uni_mvn_kernel_f32 : public jit_uni_mvn_kernel, public jit_generator
}
Xbyak::Label label_empty_2half_sse42;
if (tail_num == 0) {
if (tail_step == 0) {
cmp(reg_oc_off, static_cast<int>(jcp_.C * sizeof(float)));
jae(label_empty_2half_sse42, T_NEAR);
worker_mvn_unroll();
@ -512,7 +515,7 @@ struct jit_uni_mvn_kernel_f32 : public jit_uni_mvn_kernel, public jit_generator
Xbyak::Label label_full_size_block;
Xbyak::Label label_size_end;
cmp(reg_oc_off, static_cast<int>((jcp_.C - step) * sizeof(float)));
cmp(reg_oc_off, static_cast<int>((jcp_.C - vector_step) * sizeof(float)));
jle(label_full_size_block, T_NEAR);
worker_mvn_unroll(true);
@ -530,8 +533,10 @@ struct jit_uni_mvn_kernel_f32 : public jit_uni_mvn_kernel, public jit_generator
this->postamble();
load_emitter->emit_data();
store_emitter->emit_data();
load_vector_emitter->emit_data();
load_tail_emitter->emit_data();
store_vector_emitter->emit_data();
store_tail_emitter->emit_data();
for (auto& inj : eltwise_injectors)
inj->prepare_table();
@ -542,8 +547,8 @@ private:
Xbyak::Ymm, Xbyak::Zmm>::type;
const int vlen = cpu_isa_traits<isa>::vlen;
const int step = vlen / sizeof(float);
int tail_num = 0;
const int vector_step = vlen / sizeof(float);
int tail_step = 0;
Xbyak::Reg64 reg_src = r8;
Xbyak::Reg64 reg_mean = r9;
@ -570,8 +575,10 @@ private:
Vmm vmm_d_weights = Vmm(5);
Vmm vmm_d_bias = Vmm(6);
std::unique_ptr<jit_load_emitter> load_emitter = nullptr;
std::unique_ptr<jit_store_emitter> store_emitter = nullptr;
std::unique_ptr<jit_load_emitter> load_vector_emitter = nullptr;
std::unique_ptr<jit_load_emitter> load_tail_emitter = nullptr;
std::unique_ptr<jit_store_emitter> store_vector_emitter = nullptr;
std::unique_ptr<jit_store_emitter> store_tail_emitter = nullptr;
std::vector<std::shared_ptr<jit_uni_eltwise_injector_f32<isa>>> eltwise_injectors;
std::vector<std::shared_ptr<jit_uni_depthwise_injector_f32<isa>>> depthwise_injectors;
@ -582,9 +589,10 @@ private:
std::vector<size_t> load_pool_gpr_idxs;
inline void worker_mvn(bool is_tail) {
int elt_num = is_tail ? tail_num : step;
const auto& load_emitter = is_tail ? load_tail_emitter : load_vector_emitter;
const auto& store_emitter = is_tail ? store_tail_emitter : store_vector_emitter;
load_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_val.getIdx())},
std::make_shared<load_emitter_context>(jcp_.src_prc, Precision::FP32, elt_num),
{}, {load_pool_gpr_idxs});
uni_vsubps(vmm_val, vmm_val, vmm_mean);
@ -594,7 +602,6 @@ private:
apply_post_ops(jcp_.dst_prc, jcp_.planar_layout);
store_emitter->emit_code({static_cast<size_t>(vmm_val.getIdx())}, {static_cast<size_t>(reg_dst.getIdx())},
std::make_shared<store_emitter_context>(Precision::FP32, jcp_.dst_prc, elt_num),
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
}
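// Design note: emitters are configured with their element count at
// construction (see generate() above), so tail handling becomes a choice
// between two ready emitters rather than a per-call context allocation —
// one fewer heap allocation per emitted load/store.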

View File

@ -44,8 +44,9 @@ struct jit_uni_nms_kernel_f32 : public jit_uni_nms_kernel, public jit_generator
}
void generate() override {
load_emitter.reset(new jit_load_emitter(this, isa));
store_emitter.reset(new jit_store_emitter(this, isa));
load_vector_emitter.reset(new jit_load_emitter(this, isa, Precision::FP32, Precision::FP32, vector_step));
load_scalar_emitter.reset(new jit_load_emitter(this, isa, Precision::FP32, Precision::FP32, scalar_step));
exp_injector.reset(new jit_uni_eltwise_injector_f32<isa>(this, dnnl::impl::alg_kind::eltwise_exp, 0.f, 0.f, 1.0f));
this->preamble();
@ -137,8 +138,8 @@ struct jit_uni_nms_kernel_f32 : public jit_uni_nms_kernel, public jit_generator
this->postamble();
load_emitter->emit_data();
store_emitter->emit_data();
load_vector_emitter->emit_data();
load_scalar_emitter->emit_data();
prepare_table();
exp_injector->prepare_table();
@ -147,6 +148,8 @@ struct jit_uni_nms_kernel_f32 : public jit_uni_nms_kernel, public jit_generator
private:
using Vmm = typename conditional3<isa == cpu::x64::sse41, Xbyak::Xmm, isa == cpu::x64::avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
uint32_t vlen = cpu_isa_traits<isa>::vlen;
const int vector_step = vlen / sizeof(float);
const int scalar_step = 1;
Xbyak::Reg64 reg_boxes_coord0 = r8;
Xbyak::Reg64 reg_boxes_coord1 = r9;
@ -172,8 +175,8 @@ private:
Xbyak::Reg64 reg_params = abi_param1;
std::unique_ptr<jit_load_emitter> load_emitter = nullptr;
std::unique_ptr<jit_store_emitter> store_emitter = nullptr;
std::unique_ptr<jit_load_emitter> load_vector_emitter = nullptr;
std::unique_ptr<jit_load_emitter> load_scalar_emitter = nullptr;
std::vector<size_t> store_pool_gpr_idxs;
std::vector<size_t> store_pool_vec_idxs;
@ -205,25 +208,24 @@ private:
std::shared_ptr<jit_uni_eltwise_injector_f32<isa>> exp_injector;
inline void hard_nms() {
int step = vlen / sizeof(float);
Xbyak::Label main_loop_label_hard;
Xbyak::Label main_loop_end_label_hard;
Xbyak::Label tail_loop_label_hard;
Xbyak::Label terminate_label_hard;
L(main_loop_label_hard);
{
cmp(reg_boxes_num, step);
cmp(reg_boxes_num, vector_step);
jl(main_loop_end_label_hard, T_NEAR);
sub(reg_boxes_coord0, step * sizeof(float));
sub(reg_boxes_coord1, step * sizeof(float));
sub(reg_boxes_coord2, step * sizeof(float));
sub(reg_boxes_coord3, step * sizeof(float));
sub(reg_boxes_coord0, vector_step * sizeof(float));
sub(reg_boxes_coord1, vector_step * sizeof(float));
sub(reg_boxes_coord2, vector_step * sizeof(float));
sub(reg_boxes_coord3, vector_step * sizeof(float));
// iou result is in vmm_temp3
iou(step);
iou(vector_step);
sub(reg_boxes_num, step);
sub(reg_boxes_num, vector_step);
suppressed_by_iou(false);
@ -236,21 +238,20 @@ private:
}
L(main_loop_end_label_hard);
step = 1;
L(tail_loop_label_hard);
{
cmp(reg_boxes_num, 1);
jl(terminate_label_hard, T_NEAR);
sub(reg_boxes_coord0, step * sizeof(float));
sub(reg_boxes_coord1, step * sizeof(float));
sub(reg_boxes_coord2, step * sizeof(float));
sub(reg_boxes_coord3, step * sizeof(float));
sub(reg_boxes_coord0, scalar_step * sizeof(float));
sub(reg_boxes_coord1, scalar_step * sizeof(float));
sub(reg_boxes_coord2, scalar_step * sizeof(float));
sub(reg_boxes_coord3, scalar_step * sizeof(float));
// iou result is in vmm_temp3
iou(step);
iou(scalar_step);
sub(reg_boxes_num, step);
sub(reg_boxes_num, scalar_step);
suppressed_by_iou(true);
@ -267,7 +268,6 @@ private:
inline void soft_nms() {
uni_vbroadcastss(vmm_scale, ptr[reg_scale]);
int step = vlen / sizeof(float);
Xbyak::Label main_loop_label;
Xbyak::Label main_loop_end_label;
Xbyak::Label tail_loop_label;
@ -277,17 +277,17 @@ private:
Xbyak::Label tail_loop_label_soft;
L(main_loop_label);
{
cmp(reg_boxes_num, step);
cmp(reg_boxes_num, vector_step);
jl(main_loop_end_label, T_NEAR);
sub(reg_boxes_coord0, step * sizeof(float));
sub(reg_boxes_coord1, step * sizeof(float));
sub(reg_boxes_coord2, step * sizeof(float));
sub(reg_boxes_coord3, step * sizeof(float));
sub(reg_boxes_coord0, vector_step * sizeof(float));
sub(reg_boxes_coord1, vector_step * sizeof(float));
sub(reg_boxes_coord2, vector_step * sizeof(float));
sub(reg_boxes_coord3, vector_step * sizeof(float));
// result (iou and weight) is in vmm_temp3
iou(step);
sub(reg_boxes_num, step);
iou(vector_step);
sub(reg_boxes_num, vector_step);
// soft suppressed by iou_threshold
if (jcp.is_soft_suppressed_by_iou) {
@ -327,19 +327,18 @@ private:
}
L(main_loop_end_label);
step = 1;
L(tail_loop_label);
{
cmp(reg_boxes_num, 1);
jl(terminate_label, T_NEAR);
sub(reg_boxes_coord0, step * sizeof(float));
sub(reg_boxes_coord1, step * sizeof(float));
sub(reg_boxes_coord2, step * sizeof(float));
sub(reg_boxes_coord3, step * sizeof(float));
sub(reg_boxes_coord0, scalar_step * sizeof(float));
sub(reg_boxes_coord1, scalar_step * sizeof(float));
sub(reg_boxes_coord2, scalar_step * sizeof(float));
sub(reg_boxes_coord3, scalar_step * sizeof(float));
iou(step);
sub(reg_boxes_num, step);
iou(scalar_step);
sub(reg_boxes_num, scalar_step);
// soft suppressed by iou_threshold
if (jcp.is_soft_suppressed_by_iou) {
@ -427,8 +426,11 @@ private:
inline void iou(int ele_num) {
auto load = [&](Xbyak::Reg64 reg_src, Vmm vmm_dst) {
if (ele_num != scalar_step && ele_num != vector_step)
IE_THROW() << "NMS JIT implementation only supports a load emitter with element count scalar_step or vector_step! Got: " << ele_num;
const auto& load_emitter = ele_num == 1 ? load_scalar_emitter : load_vector_emitter;
load_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_dst.getIdx())},
std::make_shared<load_emitter_context>(Precision::FP32, Precision::FP32, ele_num),
{}, {load_pool_gpr_idxs});
};
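// Only scalar_step and vector_step are valid here because generate() above
// built exactly those two emitters; any other element count has no
// pre-constructed emitter to dispatch to, hence the guard above.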
load(reg_boxes_coord0, vmm_boxes_coord0);

View File

@ -46,9 +46,6 @@ struct jit_uni_roi_align_kernel_f32 : public jit_uni_roi_align_kernel, public ji
};
void generate() override {
load_emitter.reset(new jit_load_emitter(this, isa));
store_emitter.reset(new jit_store_emitter(this, isa));
this->preamble();
uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
@ -65,8 +62,7 @@ struct jit_uni_roi_align_kernel_f32 : public jit_uni_roi_align_kernel, public ji
this->postamble();
load_emitter->emit_data();
store_emitter->emit_data();
emit_emitters_data();
}
private:
@ -107,10 +103,9 @@ private:
// [1] for reg_dst
Xmm xmm_args_pool = Xmm(15);
std::unique_ptr<jit_load_emitter> load_emitter = nullptr;
std::vector<size_t> load_pool_gpr_idxs;
std::unordered_map<size_t, std::unique_ptr<jit_emitter>> emitters;
std::unique_ptr<jit_store_emitter> store_emitter = nullptr;
std::vector<size_t> load_pool_gpr_idxs;
std::vector<size_t> store_pool_gpr_idxs;
std::vector<size_t> store_pool_vec_idxs;
@ -157,6 +152,57 @@ private:
reg64_t reg_params = abi_param1;
void emit_emitters_data() {
for (const auto& emitter : emitters) {
emitter.second->emit_data();
}
}
inline void load(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num, const int offset = 0) {
emit_load(reg_src, vmm_src, jcp_.data_prc, Precision::FP32, elt_num, offset);
}
inline void load_buffer(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num, const int offset = 0) {
emit_load(reg_src, vmm_src, Precision::FP32, Precision::FP32, elt_num, offset);
}
inline void load_idx(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num, const int offset = 0) {
emit_load(reg_src, vmm_src, Precision::I32, Precision::I32, elt_num, offset);
}
inline void store(Vmm vmm_dst, Xbyak::Reg64 reg_dst, const int elt_num, const int offset = 0) {
emit_store(vmm_dst, reg_dst, Precision::FP32, jcp_.data_prc, elt_num, offset);
}
inline void store_buffer(Vmm vmm_dst, Xbyak::Reg64 reg_dst, const int elt_num, const int offset = 0) {
emit_store(vmm_dst, reg_dst, Precision::FP32, Precision::FP32, elt_num, offset);
}
inline void emit_load(Xbyak::Reg64 reg_src, Vmm vmm_src, Precision src_prc, Precision dst_prc, const int elt_num, const int offset = 0) {
const auto seed = load_emitter_params(src_prc, dst_prc, elt_num).hash();
if (!emitters[seed]) {
emitters[seed].reset(new jit_load_emitter(this, isa, src_prc, dst_prc, elt_num));
}
emitters[seed]->emit_code({static_cast<size_t>(reg_src.getIdx()), static_cast<size_t>(offset)},
{static_cast<size_t>(vmm_src.getIdx())}, {}, {load_pool_gpr_idxs});
}
inline void emit_store(Vmm vmm_dst, Xbyak::Reg64 reg_dst, Precision src_prc, Precision dst_prc, const int elt_num, const int offset = 0) {
const auto seed = store_emitter_params(src_prc, dst_prc, elt_num).hash();
if (!emitters[seed]) {
emitters[seed].reset(new jit_store_emitter(this, isa, src_prc, dst_prc, elt_num));
}
// for cases when the Store emitter needs 2 aux vmms, we can use vmm_dst as the second aux vmm
std::vector<size_t> local_store_pool_vec_idxs = { static_cast<size_t>(vmm_dst.getIdx()) };
local_store_pool_vec_idxs.insert(local_store_pool_vec_idxs.begin(), store_pool_vec_idxs.begin(), store_pool_vec_idxs.end());
emitters[seed]->emit_code({static_cast<size_t>(vmm_dst.getIdx()), static_cast<size_t>(offset)},
{static_cast<size_t>(reg_dst.getIdx())},
{local_store_pool_vec_idxs}, {store_pool_gpr_idxs});
}
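// Usage consequence of the hash-keyed cache above (a sketch: the first call
// is from generate_samples below, the second is a hypothetical repeat with
// the same parameters): identical parameters share one generated emitter,
// only the emit_code call sites differ.
//     load(reg_src0, vmm_src0, num);  // first call creates the (data_prc -> FP32, num) emitter
//     load(reg_src1, vmm_src1, num);  // same parameters: cache hit, emitter reused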
void roi_align_cgather() {
mov(reg_src_address, ptr[reg_params + GET_OFF(src)]);
mov(reg_weights, ptr[reg_params + GET_OFF(weights)]);
@ -180,23 +226,6 @@ private:
imul(reg_src_stride, reg_src_stride, jcp_.data_size);
}
auto store = [&](Vmm vmm_dst, Xbyak::Reg64 reg_dst, int elt_num) {
store_emitter->emit_code({static_cast<size_t>(vmm_dst.getIdx())}, {static_cast<size_t>(reg_dst.getIdx())},
std::make_shared<store_emitter_context>(Precision::FP32, jcp_.data_prc, elt_num),
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
};
auto load_buf = [&](Xbyak::Reg64 reg_src, Vmm vmm_src, int elt_num) {
load_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_src.getIdx())},
std::make_shared<load_emitter_context>(Precision::FP32, Precision::FP32, elt_num),
{}, {load_pool_gpr_idxs});
};
auto store_buf = [&](Vmm vmm_dst, Xbyak::Reg64 reg_dst, int elt_num) {
store_emitter->emit_code({static_cast<size_t>(vmm_dst.getIdx())}, {static_cast<size_t>(reg_dst.getIdx())},
std::make_shared<store_emitter_context>(Precision::FP32, Precision::FP32, elt_num),
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
};
// out loop for samples in bin
Xbyak::Label out_loop_label;
Xbyak::Label out_loop_end_label;
@ -228,13 +257,13 @@ private:
generate_samples(v_step);
// now this sample's value across channels resides in vmm_sample
// compute with other samples in vmm_buf
load_buf(reg_buf, vmm_buf, v_step);
load_buffer(reg_buf, vmm_buf, v_step);
if (jcp_.alg == Algorithm::ROIAlignAvg) {
uni_vaddps(vmm_buf, vmm_buf, vmm_sample);
} else {
uni_vmaxps(vmm_buf, vmm_buf, vmm_sample);
}
store_buf(vmm_buf, reg_buf, v_step);
store_buffer(vmm_buf, reg_buf, v_step);
if ((isa == cpu::x64::sse41) && (jcp_.layout == ROIAlignLayoutType::blk)) {
add(reg_src0, x_step * jcp_.data_size);
@ -244,13 +273,13 @@ private:
add(reg_buf, x_step * sizeof(float));
generate_samples(x_step);
load_buf(reg_buf, vmm_buf, x_step);
load_buffer(reg_buf, vmm_buf, x_step);
if (jcp_.alg == Algorithm::ROIAlignAvg) {
uni_vaddps(vmm_buf, vmm_buf, vmm_sample);
} else {
uni_vmaxps(vmm_buf, vmm_buf, vmm_sample);
}
store_buf(vmm_buf, reg_buf, x_step);
store_buffer(vmm_buf, reg_buf, x_step);
sub(reg_src0, x_step * jcp_.data_size);
sub(reg_src1, x_step * jcp_.data_size);
@ -280,13 +309,13 @@ private:
jl(in_loop_tail_end_label, T_NEAR);
generate_samples(tail_step);
load_buf(reg_buf, vmm_buf, tail_step);
load_buffer(reg_buf, vmm_buf, tail_step);
if (jcp_.alg == Algorithm::ROIAlignAvg) {
uni_vaddps(vmm_buf, vmm_buf, vmm_sample);
} else {
uni_vmaxps(vmm_buf, vmm_buf, vmm_sample);
}
store_buf(vmm_buf, reg_buf, tail_step);
store_buffer(vmm_buf, reg_buf, tail_step);
int tail_src_stride = tail_step * jcp_.data_size;
add(reg_src0, tail_src_stride);
@ -333,7 +362,7 @@ private:
cmp(reg_work_amount, v_step);
jl(store_loop_main_end_label, T_NEAR);
load_buf(reg_buf, vmm_buf, v_step);
load_buffer(reg_buf, vmm_buf, v_step);
if (jcp_.alg == Algorithm::ROIAlignAvg) {
uni_vmulps(vmm_buf, vmm_buf, vmm_scale);
}
@ -343,7 +372,7 @@ private:
add(reg_buf, x_step * sizeof(float));
add(reg_dst, x_step * jcp_.data_size);
load_buf(reg_buf, vmm_buf, x_step);
load_buffer(reg_buf, vmm_buf, x_step);
if (jcp_.alg == Algorithm::ROIAlignAvg) {
uni_vmulps(vmm_buf, vmm_buf, vmm_scale);
}
@ -369,7 +398,7 @@ private:
cmp(reg_work_amount, tail_step);
jl(store_loop_tail_end_label, T_NEAR);
load_buf(reg_buf, vmm_buf, tail_step);
load_buffer(reg_buf, vmm_buf, tail_step);
if (jcp_.alg == Algorithm::ROIAlignAvg) {
uni_vmulps(vmm_buf, vmm_buf, vmm_scale);
}
@ -402,12 +431,6 @@ private:
}
void generate_samples(int num) {
auto load = [&](Xbyak::Reg64 reg_src, Vmm vmm_src, int elt_num) {
load_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_src.getIdx())},
std::make_shared<load_emitter_context>(jcp_.data_prc, Precision::FP32, elt_num),
{}, {load_pool_gpr_idxs});
};
uni_vpxor(vmm_sample, vmm_sample, vmm_sample);
load(reg_src0, vmm_src0, num);
uni_vfmadd231ps(vmm_sample, vmm_src0, vmm_weights0);
@ -432,12 +455,6 @@ private:
uni_vbroadcastss(vmm_scale, ptr[reg_tmp_64]);
}
auto load_idx = [&](Xbyak::Reg64 reg_idx, Vmm vmm_idx, int elt_num) {
load_emitter->emit_code({static_cast<size_t>(reg_idx.getIdx())}, {static_cast<size_t>(vmm_idx.getIdx())},
std::make_shared<load_emitter_context>(Precision::I32, Precision::I32, elt_num),
{}, {load_pool_gpr_idxs});
};
Xbyak::Label main_loop_label;
Xbyak::Label main_loop_end_label;
Xbyak::Label tail_loop_label;

View File

@ -48,8 +48,9 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi
};
void generate() override {
load_emitter.reset(new jit_load_emitter(this, isa));
store_emitter.reset(new jit_store_emitter(this, isa));
load_emitter.reset(new jit_load_emitter(this, isa, jpp_.src_prc, Precision::FP32, step));
store_emitter.reset(new jit_store_emitter(this, isa, Precision::FP32, jpp_.dst_prc, step));
store_empty_roi_emitter.reset(new jit_store_emitter(this, isa, jpp_.src_prc, jpp_.dst_prc, step));
this->preamble();
@ -93,6 +94,7 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi
load_emitter->emit_data();
store_emitter->emit_data();
store_empty_roi_emitter->emit_data();
}
private:
@ -114,6 +116,7 @@ private:
std::vector<size_t> load_pool_gpr_idxs;
std::unique_ptr<jit_store_emitter> store_emitter = nullptr;
std::unique_ptr<jit_store_emitter> store_empty_roi_emitter = nullptr;
std::vector<size_t> store_pool_gpr_idxs;
std::vector<size_t> store_pool_vec_idxs;
@ -147,6 +150,12 @@ private:
Xbyak::Reg64 reg_load_table = r15;
Xbyak::Reg64 reg_load_store_mask = abi_param1;
std::vector<size_t> get_local_store_pool_vec_idxs(Vmm vmm) const {
std::vector<size_t> local_store_pool_vec_idxs = { static_cast<size_t>(vmm.getIdx()) };
local_store_pool_vec_idxs.insert(local_store_pool_vec_idxs.begin(), store_pool_vec_idxs.begin(), store_pool_vec_idxs.end());
return local_store_pool_vec_idxs;
}
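// Rationale (mirrors the aux-vmm note at the emitters' other call sites):
// once emit_code has consumed the value in vmm, that register is dead for
// the caller, so it can double as an extra scratch vector for the store
// emitter instead of reserving a dedicated one.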
void roi_pool_max(int c_blocks) {
Label h_loop_label;
Label w_loop_label;
@ -157,8 +166,7 @@ private:
for (int i = 0; i < c_blocks; i++) {
Vmm vmm_max = get_acc_reg(i);
load_emitter->emit_code({static_cast<size_t>(reg_input.getIdx())}, {static_cast<size_t>(vmm_max.getIdx())},
std::make_shared<load_emitter_context>(jpp_.src_prc, Precision::FP32, step, i * src_c_off),
load_emitter->emit_code({static_cast<size_t>(reg_input.getIdx()), static_cast<size_t>(i * src_c_off)}, {static_cast<size_t>(vmm_max.getIdx())},
{}, load_pool_gpr_idxs);
}
@ -171,9 +179,8 @@ private:
Vmm vmm_max = get_acc_reg(i);
Vmm vmm_src = get_src_reg(i);
load_emitter->emit_code({static_cast<size_t>(aux_reg_input1.getIdx())}, {static_cast<size_t>(vmm_src.getIdx())},
std::make_shared<load_emitter_context>(jpp_.src_prc, Precision::FP32, step, i * src_c_off),
{}, load_pool_gpr_idxs);
load_emitter->emit_code({static_cast<size_t>(aux_reg_input1.getIdx()), static_cast<size_t>(i * src_c_off)},
{static_cast<size_t>(vmm_src.getIdx())}, {}, load_pool_gpr_idxs);
if (isa == cpu::x64::sse41) {
movups(vmm_mask, vmm_max);
@ -206,9 +213,8 @@ private:
for (int i = 0; i < c_blocks; i++) {
Vmm vmm_dst = get_acc_reg(i);
store_emitter->emit_code({static_cast<size_t>(vmm_dst.getIdx())}, {static_cast<size_t>(reg_output.getIdx())},
std::make_shared<store_emitter_context>(Precision::FP32, jpp_.dst_prc, step, i * dst_c_off),
store_pool_vec_idxs, store_pool_gpr_idxs);
store_emitter->emit_code({static_cast<size_t>(vmm_dst.getIdx()), static_cast<size_t>(i * dst_c_off)}, {static_cast<size_t>(reg_output.getIdx())},
get_local_store_pool_vec_idxs(vmm_dst), store_pool_gpr_idxs);
}
}
@ -225,27 +231,22 @@ private:
for (int i = 0; i < c_blocks; i++) {
const int src_c_off = i * jpp_.ih * jpp_.iw * jpp_.c_block * jpp_.src_prc.size();
const auto load_context = std::make_shared<load_emitter_context>(jpp_.src_prc, Precision::FP32, step, src_c_off);
mov(aux_reg_input, reg_input);
load_emitter->emit_code({static_cast<size_t>(aux_reg_input.getIdx())}, {static_cast<size_t>(vmm_src00.getIdx())},
load_context,
load_emitter->emit_code({static_cast<size_t>(aux_reg_input.getIdx()), static_cast<size_t>(src_c_off)}, {static_cast<size_t>(vmm_src00.getIdx())},
{}, load_pool_gpr_idxs);
add(aux_reg_input, reg_xoff);
load_emitter->emit_code({static_cast<size_t>(aux_reg_input.getIdx())}, {static_cast<size_t>(vmm_src01.getIdx())},
load_context,
load_emitter->emit_code({static_cast<size_t>(aux_reg_input.getIdx()), static_cast<size_t>(src_c_off)}, {static_cast<size_t>(vmm_src01.getIdx())},
{}, load_pool_gpr_idxs);
add(aux_reg_input, reg_yoff);
load_emitter->emit_code({static_cast<size_t>(aux_reg_input.getIdx())}, {static_cast<size_t>(vmm_src11.getIdx())},
load_context,
load_emitter->emit_code({static_cast<size_t>(aux_reg_input.getIdx()), static_cast<size_t>(src_c_off)}, {static_cast<size_t>(vmm_src11.getIdx())},
{}, load_pool_gpr_idxs);
sub(aux_reg_input, reg_xoff);
load_emitter->emit_code({static_cast<size_t>(aux_reg_input.getIdx())}, {static_cast<size_t>(vmm_src10.getIdx())},
load_context,
load_emitter->emit_code({static_cast<size_t>(aux_reg_input.getIdx()), static_cast<size_t>(src_c_off)}, {static_cast<size_t>(vmm_src10.getIdx())},
{}, load_pool_gpr_idxs);
uni_vsubps(vmm_src01, vmm_src01, vmm_src00);
@ -259,9 +260,8 @@ private:
const int dst_c_off = i * jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_prc.size();
store_emitter->emit_code({static_cast<size_t>(vmm_src11.getIdx())}, {static_cast<size_t>(reg_output.getIdx())},
std::make_shared<store_emitter_context>(Precision::FP32, jpp_.dst_prc, step, dst_c_off),
store_pool_vec_idxs, store_pool_gpr_idxs);
store_emitter->emit_code({static_cast<size_t>(vmm_src11.getIdx()), static_cast<size_t>(dst_c_off)}, {static_cast<size_t>(reg_output.getIdx())},
get_local_store_pool_vec_idxs(vmm_src11), store_pool_gpr_idxs);
}
}
@ -270,9 +270,8 @@ private:
const int dst_c_off = jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_prc.size();
for (int i = 0; i < c_blocks; i++) {
store_emitter->emit_code({static_cast<size_t>(vmm_zero.getIdx())}, {static_cast<size_t>(reg_output.getIdx())},
std::make_shared<store_emitter_context>(jpp_.src_prc, jpp_.dst_prc, step, i * dst_c_off),
store_pool_vec_idxs, store_pool_gpr_idxs);
store_empty_roi_emitter->emit_code({static_cast<size_t>(vmm_zero.getIdx()), static_cast<size_t>(i * dst_c_off)},
{static_cast<size_t>(reg_output.getIdx())}, store_pool_vec_idxs, store_pool_gpr_idxs);
}
}
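// Calling convention used throughout this patch: the byte offset is passed
// as an optional second element of the input-index vector, e.g. (values from
// roi_pool_max above)
//     store_emitter->emit_code({vmm_dst.getIdx(), i * dst_c_off}, {reg_output.getIdx()}, ...);
// so a single cached emitter can serve every offset, where previously each
// call allocated a context carrying the offset.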

View File

@ -82,9 +82,6 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generato
}
void generate() override {
load_emitter.reset(new jit_load_emitter(this, isa));
store_emitter.reset(new jit_store_emitter(this, isa));
this->preamble();
mov(reg_src, ptr[reg_params + GET_OFF(src)]);
@ -123,8 +120,7 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generato
this->postamble();
load_emitter->emit_data();
store_emitter->emit_data();
emit_emitters_data();
if (!shape_agnostic_alg)
prepare_idx_table();
@ -207,9 +203,8 @@ private:
Vmm vmm_zero = Vmm(0); // vmm_zero represents Vmm(0) when isa is avx512_core, otherwise vmm_mask represents Vmm(0)
const Xbyak::Opmask k_mask = Xbyak::Opmask(1);
const int step = vlen / sizeof(float);
const int tail = jcp_.work_amount % step;
const int topk_tail = jcp_.top_k % step;
const int vector_step = vlen / sizeof(float);
const int tail_step = jcp_.work_amount % vector_step;
int blk_stride = 0; // stride of channel blocks at the same space coordinate, only used in blocked layout with topk on channel
unsigned char cmp_flg;
@ -217,13 +212,67 @@ private:
Xbyak::Label l_table;
std::unique_ptr<jit_load_emitter> load_emitter = nullptr;
std::unique_ptr<jit_store_emitter> store_emitter = nullptr;
std::unordered_map<size_t, std::unique_ptr<jit_emitter>> emitters;
std::vector<size_t> store_pool_gpr_idxs;
std::vector<size_t> load_pool_gpr_idxs;
std::vector<size_t> store_pool_vec_idxs;
void emit_emitters_data() {
for (const auto& emitter : emitters) {
emitter.second->emit_data();
}
}
inline void load(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num, const int offset = 0) {
emit_load(reg_src, vmm_src, jcp_.precision, Precision::FP32, elt_num, offset);
}
inline void load_i32_f32(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num, const int offset = 0) {
emit_load(reg_src, vmm_src, Precision::I32, Precision::FP32, elt_num, offset);
}
inline void load_i32(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num, const int offset = 0) {
emit_load(reg_src, vmm_src, Precision::I32, Precision::I32, elt_num, offset);
}
inline void store(Vmm vmm_dst, Xbyak::Reg64 reg_dst, const int elt_num, const int offset = 0) {
emit_store(vmm_dst, reg_dst, Precision::FP32, jcp_.precision, elt_num, offset);
}
inline void store_f32_i32(Vmm vmm_dst, Xbyak::Reg64 reg_dst, const int elt_num, const int offset = 0) {
emit_store(vmm_dst, reg_dst, Precision::FP32, Precision::I32, elt_num, offset);
}
inline void store_i32(Vmm vmm_dst, Xbyak::Reg64 reg_dst, const int elt_num, const int offset = 0) {
emit_store(vmm_dst, reg_dst, Precision::I32, Precision::I32, elt_num, offset);
}
inline void emit_load(Xbyak::Reg64 reg_src, Vmm vmm_src, Precision src_prc, Precision dst_prc, const int elt_num, const int offset = 0) {
const auto seed = load_emitter_params(src_prc, dst_prc, elt_num).hash();
if (!emitters[seed]) {
emitters[seed].reset(new jit_load_emitter(this, isa, src_prc, dst_prc, elt_num));
}
emitters[seed]->emit_code({static_cast<size_t>(reg_src.getIdx()), static_cast<size_t>(offset)},
{static_cast<size_t>(vmm_src.getIdx())}, {}, {load_pool_gpr_idxs});
}
inline void emit_store(Vmm vmm_dst, Xbyak::Reg64 reg_dst, Precision src_prc, Precision dst_prc, const int elt_num, const int offset = 0) {
const auto seed = store_emitter_params(src_prc, dst_prc, elt_num).hash();
if (!emitters[seed]) {
emitters[seed].reset(new jit_store_emitter(this, isa, src_prc, dst_prc, elt_num));
}
// for cases when the Store emitter needs 2 aux vmms, we can use vmm_dst as the second aux vmm
std::vector<size_t> local_store_pool_vec_idxs = { static_cast<size_t>(vmm_dst.getIdx()) };
local_store_pool_vec_idxs.insert(local_store_pool_vec_idxs.begin(), store_pool_vec_idxs.begin(), store_pool_vec_idxs.end());
emitters[seed]->emit_code({static_cast<size_t>(vmm_dst.getIdx()), static_cast<size_t>(offset)},
{static_cast<size_t>(reg_dst.getIdx())},
{local_store_pool_vec_idxs}, {store_pool_gpr_idxs});
}
inline void topk_loop() {
if (jcp_.algorithm == TopKAlgorithm::topk_bubble_sort) {
if (jcp_.layout == TopKLayoutType::topk_blocked && jcp_.topk_innermost) {
@ -253,27 +302,27 @@ private:
Xbyak::Label topk_main_loop_end_label;
L(topk_main_loop_label);
{
cmp(reg_work_amount, step);
cmp(reg_work_amount, vector_step);
jl(topk_main_loop_end_label, T_NEAR);
topk_bitonic(step);
topk_bitonic(vector_step);
add(reg_src, step * jcp_.data_size);
add(reg_dst, step * jcp_.data_size);
add(reg_dst_idx, step * sizeof(int));
sub(reg_work_amount, step);
add(reg_src, vector_step * jcp_.data_size);
add(reg_dst, vector_step * jcp_.data_size);
add(reg_dst_idx, vector_step * sizeof(int));
sub(reg_work_amount, vector_step);
jmp(topk_main_loop_label, T_NEAR);
}
L(topk_main_loop_end_label);
// tail
if (tail) {
if (tail_step) {
Xbyak::Label topk_tail_loop_end_label;
cmp(reg_work_amount, tail);
cmp(reg_work_amount, tail_step);
jl(topk_tail_loop_end_label, T_NEAR);
topk_bitonic(tail);
topk_bitonic(tail_step);
L(topk_tail_loop_end_label);
}
@ -282,19 +331,11 @@ private:
inline void topk_bitonic(int elt_num) {
// src => prc
for (int i = 0; i < jcp_.axis_dim; i++) {
load_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_tmp.getIdx())},
std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, elt_num, i * jcp_.sort_stride * jcp_.data_size),
{}, {load_pool_gpr_idxs});
store_emitter->emit_code({static_cast<size_t>(vmm_tmp.getIdx())}, {static_cast<size_t>(reg_prc.getIdx())},
std::make_shared<store_emitter_context>(Precision::FP32, jcp_.precision, elt_num, i * jcp_.sort_stride * jcp_.data_size),
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
load(reg_src, vmm_tmp, elt_num, i * jcp_.sort_stride * jcp_.data_size);
store(vmm_tmp, reg_prc, elt_num, i * jcp_.sort_stride * jcp_.data_size);
load_emitter->emit_code({static_cast<size_t>(reg_table.getIdx())}, {static_cast<size_t>(vmm_tmp.getIdx())},
std::make_shared<load_emitter_context>(Precision::I32, Precision::I32, elt_num, i * vlen),
{}, {load_pool_gpr_idxs});
store_emitter->emit_code({static_cast<size_t>(vmm_tmp.getIdx())}, {static_cast<size_t>(reg_prc_idx.getIdx())},
std::make_shared<store_emitter_context>(Precision::I32, Precision::I32, elt_num, i * jcp_.sort_stride * sizeof(int)),
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
load_i32(reg_table, vmm_tmp, elt_num, i * vlen);
store_i32(vmm_tmp, reg_prc_idx, elt_num, i * jcp_.sort_stride * sizeof(int));
}
// sort
@ -305,19 +346,11 @@ private:
// prc => dst
for (int i = 0; i < jcp_.top_k; i++) {
load_emitter->emit_code({static_cast<size_t>(reg_prc.getIdx())}, {static_cast<size_t>(vmm_tmp.getIdx())},
std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, elt_num, i * jcp_.sort_stride * jcp_.data_size),
{}, {load_pool_gpr_idxs});
store_emitter->emit_code({static_cast<size_t>(vmm_tmp.getIdx())}, {static_cast<size_t>(reg_dst.getIdx())},
std::make_shared<store_emitter_context>(Precision::FP32, jcp_.precision, elt_num, i * jcp_.sort_stride * jcp_.data_size),
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
load(reg_prc, vmm_tmp, elt_num, i * jcp_.sort_stride * jcp_.data_size);
store(vmm_tmp, reg_dst, elt_num, i * jcp_.sort_stride * jcp_.data_size);
load_emitter->emit_code({static_cast<size_t>(reg_prc_idx.getIdx())}, {static_cast<size_t>(vmm_tmp.getIdx())},
std::make_shared<load_emitter_context>(Precision::I32, Precision::I32, elt_num, i * jcp_.sort_stride * sizeof(int)),
{}, {load_pool_gpr_idxs});
store_emitter->emit_code({static_cast<size_t>(vmm_tmp.getIdx())}, {static_cast<size_t>(reg_dst_idx.getIdx())},
std::make_shared<store_emitter_context>(Precision::I32, Precision::I32, elt_num, i * jcp_.sort_stride * sizeof(int)),
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
load_i32(reg_prc_idx, vmm_tmp, elt_num, i * jcp_.sort_stride * sizeof(int));
store_i32(vmm_tmp, reg_dst_idx, elt_num, i * jcp_.sort_stride * sizeof(int));
}
}
@ -330,46 +363,46 @@ private:
Xbyak::Label topk_main_loop_end_label;
L(topk_main_loop_label);
{
cmp(reg_work_amount, step);
cmp(reg_work_amount, vector_step);
jl(topk_main_loop_end_label, T_NEAR);
// src => prc
bitonic_BLK_on_channel_load(step);
bitonic_BLK_on_channel_load(vector_step);
// sort
bitonic_sort_vector(step);
bitonic_sort_vector(vector_step);
if (jcp_.sort_index) {
bitonic_sort_vector(step, false);
bitonic_sort_vector(vector_step, false);
}
// prc => dst
bitonic_BLK_on_channel_store(step);
bitonic_BLK_on_channel_store(vector_step);
add(reg_src, step * jcp_.blk_size * jcp_.data_size);
add(reg_dst, step * jcp_.blk_size * jcp_.data_size);
add(reg_dst_idx, step * jcp_.blk_size * sizeof(int));
sub(reg_work_amount, step);
add(reg_src, vector_step * jcp_.blk_size * jcp_.data_size);
add(reg_dst, vector_step * jcp_.blk_size * jcp_.data_size);
add(reg_dst_idx, vector_step * jcp_.blk_size * sizeof(int));
sub(reg_work_amount, vector_step);
jmp(topk_main_loop_label, T_NEAR);
}
L(topk_main_loop_end_label);
// a tail exists because the working buffer has planar layout, though the source buffer has blocked layout
if (tail) {
if (tail_step) {
Xbyak::Label topk_tail_loop_end_label;
cmp(reg_work_amount, tail);
cmp(reg_work_amount, tail_step);
jl(topk_tail_loop_end_label, T_NEAR);
// src => prc
bitonic_BLK_on_channel_load(tail);
bitonic_BLK_on_channel_load(tail_step);
bitonic_sort_vector(tail);
bitonic_sort_vector(tail_step);
if (jcp_.sort_index) {
bitonic_sort_vector(tail, false);
bitonic_sort_vector(tail_step, false);
}
// prc => dst
bitonic_BLK_on_channel_store(tail);
bitonic_BLK_on_channel_store(tail_step);
L(topk_tail_loop_end_label);
}
@ -437,40 +470,30 @@ private:
inline void bitonic_swap_vector(int elt_num, bool cmp_val = true) {
bitonic_get_addr(reg_prc, jcp_.data_size, 0);
load_emitter->emit_code({static_cast<size_t>(reg_aux_idx.getIdx())}, {static_cast<size_t>(vmm_val_l.getIdx())},
std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, elt_num),
{}, {load_pool_gpr_idxs});
load(reg_aux_idx, vmm_val_l, elt_num);
bitonic_get_addr(reg_prc, jcp_.data_size, sizeof(int));
load_emitter->emit_code({static_cast<size_t>(reg_aux_idx.getIdx())}, {static_cast<size_t>(vmm_val_r.getIdx())},
std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, elt_num),
{}, {load_pool_gpr_idxs});
load(reg_aux_idx, vmm_val_r, elt_num);
bitonic_get_addr(reg_prc_idx, sizeof(int), 0);
load_emitter->emit_code({static_cast<size_t>(reg_aux_idx.getIdx())}, {static_cast<size_t>(vmm_idx_l.getIdx())},
std::make_shared<load_emitter_context>(Precision::I32, Precision::FP32, elt_num),
{}, {load_pool_gpr_idxs});
load_i32_f32(reg_aux_idx, vmm_idx_l, elt_num);
bitonic_get_addr(reg_prc_idx, sizeof(int), sizeof(int));
load_emitter->emit_code({static_cast<size_t>(reg_aux_idx.getIdx())}, {static_cast<size_t>(vmm_idx_r.getIdx())},
std::make_shared<load_emitter_context>(Precision::I32, Precision::FP32, elt_num),
{}, {load_pool_gpr_idxs});
load_i32_f32(reg_aux_idx, vmm_idx_r, elt_num);
swap_vector(vmm_val_l, vmm_idx_l, vmm_val_r, vmm_idx_r, cmp_val);
bitonic_get_addr(reg_prc, jcp_.data_size, 0);
store_emitter->emit_code({static_cast<size_t>(vmm_val_l.getIdx())}, {static_cast<size_t>(reg_aux_idx.getIdx())},
std::make_shared<store_emitter_context>(Precision::FP32, jcp_.precision, elt_num),
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
store(vmm_val_l, reg_aux_idx, elt_num);
bitonic_get_addr(reg_prc, jcp_.data_size, sizeof(int));
store_emitter->emit_code({static_cast<size_t>(vmm_val_r.getIdx())}, {static_cast<size_t>(reg_aux_idx.getIdx())},
std::make_shared<store_emitter_context>(Precision::FP32, jcp_.precision, elt_num),
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
store(vmm_val_r, reg_aux_idx, elt_num);
bitonic_get_addr(reg_prc_idx, sizeof(int), 0);
store_emitter->emit_code({static_cast<size_t>(vmm_idx_l.getIdx())}, {static_cast<size_t>(reg_aux_idx.getIdx())},
std::make_shared<store_emitter_context>(Precision::FP32, Precision::I32, elt_num),
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
store_f32_i32(vmm_idx_l, reg_aux_idx, elt_num);
bitonic_get_addr(reg_prc_idx, sizeof(int), sizeof(int));
store_emitter->emit_code({static_cast<size_t>(vmm_idx_r.getIdx())}, {static_cast<size_t>(reg_aux_idx.getIdx())},
std::make_shared<store_emitter_context>(Precision::FP32, Precision::I32, elt_num),
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
store_f32_i32(vmm_idx_r, reg_aux_idx, elt_num);
}
inline void topk_heap_sorting() {
@ -480,9 +503,9 @@ private:
// init dst
mov(reg_i, 0);
sub(reg_heap_top_k, step);
topk_heap_load(reg_heap_k_sub_step, step);
add(reg_heap_top_k, step);
sub(reg_heap_top_k, vector_step);
topk_heap_load(reg_heap_k_sub_step, vector_step);
add(reg_heap_top_k, vector_step);
topk_heap_load(reg_heap_top_k, 1);
mov(reg_zero, 0);
@ -579,7 +602,7 @@ private:
Xbyak::Label topk_init_loop_end_label;
L(topk_init_loop_label);
{
if (s == step) {
if (s == vector_step) {
cmp(reg_i, reg_end);
jg(topk_init_loop_end_label, T_NEAR);
} else {
@ -588,25 +611,18 @@ private:
}
get_addr_by_reg_idx(reg_heap_outer_aux, reg_src, reg_i, jcp_.data_size);
load_emitter->emit_code({static_cast<size_t>(reg_heap_outer_aux.getIdx())}, {static_cast<size_t>(vmm_tmp.getIdx())},
std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, s),
{}, {load_pool_gpr_idxs});
load(reg_heap_outer_aux, vmm_tmp, s);
get_addr_by_reg_idx(reg_heap_outer_aux, reg_dst, reg_i, jcp_.data_size);
store_emitter->emit_code({static_cast<size_t>(vmm_tmp.getIdx())}, {static_cast<size_t>(reg_heap_outer_aux.getIdx())},
std::make_shared<store_emitter_context>(Precision::FP32, jcp_.precision, s),
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
if (s == step) {
store(vmm_tmp, reg_heap_outer_aux, s);
if (s == vector_step) {
table_to_vmm(vmm_tmp, reg_heap_seq_idx, reg_i, 0, sizeof(int));
} else {
get_addr_by_reg_idx(reg_heap_outer_aux, reg_heap_seq_idx, reg_i, sizeof(int));
load_emitter->emit_code({static_cast<size_t>(reg_heap_outer_aux.getIdx())}, {static_cast<size_t>(vmm_tmp.getIdx())},
std::make_shared<load_emitter_context>(Precision::I32, Precision::I32, 1),
{}, {load_pool_gpr_idxs});
load_i32(reg_heap_outer_aux, vmm_tmp, 1);
}
get_addr_by_reg_idx(reg_heap_outer_aux, reg_dst_idx, reg_i, sizeof(int));
store_emitter->emit_code({static_cast<size_t>(vmm_tmp.getIdx())}, {static_cast<size_t>(reg_heap_outer_aux.getIdx())},
std::make_shared<store_emitter_context>(Precision::I32, Precision::I32, s),
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
store_i32(vmm_tmp, reg_heap_outer_aux, s);
add(reg_i, s);
jmp(topk_init_loop_label, T_NEAR);
@ -822,19 +838,19 @@ private:
Xbyak::Label topk_main_loop_end_label;
L(topk_main_loop_label);
{
cmp(reg_work_amount, step);
cmp(reg_work_amount, vector_step);
jl(topk_main_loop_end_label, T_NEAR);
if (jcp_.bubble_inplace) {
topk_bubble_inplace(step);
topk_bubble_inplace(vector_step);
} else {
topk_bubble(step);
topk_bubble(vector_step);
}
add(reg_src, step * jcp_.data_size);
add(reg_dst, step * jcp_.data_size);
add(reg_dst_idx, step * sizeof(int));
sub(reg_work_amount, step);
add(reg_src, vector_step * jcp_.data_size);
add(reg_dst, vector_step * jcp_.data_size);
add(reg_dst_idx, vector_step * sizeof(int));
sub(reg_work_amount, vector_step);
jmp(topk_main_loop_label, T_NEAR);
}
@ -842,12 +858,12 @@ private:
// tail
if (jcp_.bubble_inplace) {
if (tail) {
if (tail_step) {
Xbyak::Label topk_tail_loop_end_label;
cmp(reg_work_amount, tail);
cmp(reg_work_amount, tail_step);
jl(topk_tail_loop_end_label, T_NEAR);
topk_bubble_inplace(tail);
topk_bubble_inplace(tail_step);
L(topk_tail_loop_end_label);
}
@ -1025,19 +1041,13 @@ private:
je(topk_init_loop_end_label, T_NEAR);
get_addr_by_reg_idx(reg_tmp, reg_src, reg_block_sort_stride_byte, reg_i);
load_emitter->emit_code({static_cast<size_t>(reg_tmp.getIdx())}, {static_cast<size_t>(vmm_tmp.getIdx())},
std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, elt_num),
{}, {load_pool_gpr_idxs});
load(reg_tmp, vmm_tmp, elt_num);
get_addr_by_reg_idx(reg_tmp, reg_dst, reg_block_sort_stride_byte, reg_i);
store_emitter->emit_code({static_cast<size_t>(vmm_tmp.getIdx())}, {static_cast<size_t>(reg_tmp.getIdx())},
std::make_shared<store_emitter_context>(Precision::FP32, jcp_.precision, elt_num),
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
store(vmm_tmp, reg_tmp, elt_num);
table_to_vmm(vmm_tmp, reg_bubble_block_idx, reg_i, 0, vlen);
get_addr_by_reg_idx(reg_tmp, reg_dst_idx, reg_block_sort_stride_byte, sizeof(int) / jcp_.data_size, reg_i);
store_emitter->emit_code({static_cast<size_t>(vmm_tmp.getIdx())}, {static_cast<size_t>(reg_tmp.getIdx())},
std::make_shared<store_emitter_context>(Precision::I32, Precision::I32, elt_num),
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
store_i32(vmm_tmp, reg_tmp, elt_num);
add(reg_i, 1);
jmp(topk_init_loop_label, T_NEAR);
@ -1057,9 +1067,7 @@ private:
je(topk_update_loop_end_label, T_NEAR);
get_addr_by_reg_idx(reg_tmp, reg_src, reg_block_sort_stride_byte, reg_i);
load_emitter->emit_code({static_cast<size_t>(reg_tmp.getIdx())}, {static_cast<size_t>(vmm_val_r.getIdx())},
std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, elt_num),
{}, {load_pool_gpr_idxs});
load(reg_tmp, vmm_val_r, elt_num);
table_to_vmm(vmm_idx_r, reg_bubble_block_idx, reg_i, 0, vlen);
uni_vcvtdq2ps(vmm_idx_r, vmm_idx_r);
@ -1152,9 +1160,7 @@ private:
inline void topk_bubble_inplace(int elt_num) {
// load
for (int i = 0; i < jcp_.top_k; i++) {
load_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_val(i).getIdx())},
std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, elt_num, i * jcp_.sort_stride * jcp_.data_size),
{}, {load_pool_gpr_idxs});
load(reg_src, vmm_val(i), elt_num, i * jcp_.sort_stride * jcp_.data_size);
uni_vmovdqu(vmm_idx(i), table_val(i));
uni_vcvtdq2ps(vmm_idx(i), vmm_idx(i));
}
@ -1165,9 +1171,7 @@ private:
}
}
for (int i = jcp_.top_k; i < jcp_.axis_dim; i++) {
load_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_val(jcp_.top_k).getIdx())},
std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, elt_num, i * jcp_.sort_stride * jcp_.data_size),
{}, {load_pool_gpr_idxs});
load(reg_src, vmm_val(jcp_.top_k), elt_num, i * jcp_.sort_stride * jcp_.data_size);
uni_vmovdqu(vmm_idx(jcp_.top_k), table_val(i));
uni_vcvtdq2ps(vmm_idx(jcp_.top_k), vmm_idx(jcp_.top_k));
for (int j = jcp_.top_k; j > 0; j--) {
@ -1183,12 +1187,8 @@ private:
}
// store
for (int i = 0; i < jcp_.top_k; i++) {
store_emitter->emit_code({static_cast<size_t>(vmm_val(i).getIdx())}, {static_cast<size_t>(reg_dst.getIdx())},
std::make_shared<store_emitter_context>(Precision::FP32, jcp_.precision, elt_num, i * jcp_.sort_stride * jcp_.data_size),
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
store_emitter->emit_code({static_cast<size_t>(vmm_idx(i).getIdx())}, {static_cast<size_t>(reg_dst_idx.getIdx())},
std::make_shared<store_emitter_context>(Precision::FP32, Precision::I32, elt_num, i * jcp_.sort_stride * sizeof(int)),
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
store(vmm_val(i), reg_dst, elt_num, i * jcp_.sort_stride * jcp_.data_size);
store_f32_i32(vmm_idx(i), reg_dst_idx, elt_num, i * jcp_.sort_stride * sizeof(int));
}
}
@ -1211,15 +1211,11 @@ private:
L(topk_load_sort_label);
{
load_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_val(0).getIdx())},
std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, step, 0),
{}, {load_pool_gpr_idxs});
load(reg_src, vmm_val(0), vector_step, 0);
uni_vmovdqu(vmm_idx(0), table_bubble_seq_idx(0));
uni_vcvtdq2ps(vmm_idx(0), vmm_idx(0));
if (isa == cpu::x64::sse41) {
load_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_val(1).getIdx())},
std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, step, 4 * jcp_.data_size),
{}, {load_pool_gpr_idxs});
load(reg_src, vmm_val(1), vector_step, 4 * jcp_.data_size);
uni_vmovdqu(vmm_idx(1), table_bubble_seq_idx(4));
uni_vcvtdq2ps(vmm_idx(1), vmm_idx(1));
swap_vector(vmm_val(0), vmm_idx(0), vmm_val(1), vmm_idx(1));
@ -1235,17 +1231,13 @@ private:
jg(topk_iter_end_label, T_NEAR);
get_addr_by_reg_idx(reg_aux, reg_src, reg_i, jcp_.data_size, reg_seq_sort_stride);
load_emitter->emit_code({static_cast<size_t>(reg_aux.getIdx())}, {static_cast<size_t>(vmm_val(1).getIdx())},
std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, step),
{}, {load_pool_gpr_idxs});
load(reg_aux, vmm_val(1), vector_step);
table_to_vmm(vmm_idx(1), reg_bubble_seq_idx, reg_i, 0, sizeof(int));
uni_vcvtdq2ps(vmm_idx(1), vmm_idx(1));
swap_vector(vmm_val(0), vmm_idx(0), vmm_val(1), vmm_idx(1));
if (isa == cpu::x64::sse41) {
add(reg_aux, 4 * jcp_.data_size);
load_emitter->emit_code({static_cast<size_t>(reg_aux.getIdx())}, {static_cast<size_t>(vmm_val(1).getIdx())},
std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, step),
{}, {load_pool_gpr_idxs});
load(reg_aux, vmm_val(1), vector_step);
table_to_vmm(vmm_idx(1), reg_bubble_seq_idx, reg_i, 4, sizeof(int));
uni_vcvtdq2ps(vmm_idx(1), vmm_idx(1));
swap_vector(vmm_val(0), vmm_idx(0), vmm_val(1), vmm_idx(1));
@ -1538,16 +1530,13 @@ private:
// load l
mov(reg_tmp, reg_tmp_64);
add(reg_tmp, reg_dst);
load_emitter->emit_code({static_cast<size_t>(reg_tmp.getIdx())}, {static_cast<size_t>(vmm_val_l.getIdx())},
std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, elt_num),
{}, {load_pool_gpr_idxs});
load(reg_tmp, vmm_val_l, elt_num);
reg_shl(reg_tmp_64, sizeof(int) / jcp_.data_size);
mov(reg_tmp, reg_tmp_64);
add(reg_tmp, reg_dst_idx);
reg_shr(reg_tmp_64, sizeof(int) / jcp_.data_size);
load_emitter->emit_code({static_cast<size_t>(reg_tmp.getIdx())}, {static_cast<size_t>(vmm_idx_l.getIdx())},
std::make_shared<load_emitter_context>(Precision::I32, Precision::FP32, elt_num),
{}, {load_pool_gpr_idxs});
load_i32_f32(reg_tmp, vmm_idx_l, elt_num);
// load r
Xbyak::Label topk_load_jmp_label;
@ -1557,16 +1546,14 @@ private:
add(reg_tmp_64, reg_block_sort_stride_byte);
mov(reg_tmp, reg_tmp_64);
add(reg_tmp, reg_dst);
load_emitter->emit_code({static_cast<size_t>(reg_tmp.getIdx())}, {static_cast<size_t>(vmm_val_r.getIdx())},
std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, elt_num),
{}, {load_pool_gpr_idxs});
load(reg_tmp, vmm_val_r, elt_num);
reg_shl(reg_tmp_64, sizeof(int) / jcp_.data_size);
mov(reg_tmp, reg_tmp_64);
add(reg_tmp, reg_dst_idx);
reg_shr(reg_tmp_64, sizeof(int) / jcp_.data_size);
load_emitter->emit_code({static_cast<size_t>(reg_tmp.getIdx())}, {static_cast<size_t>(vmm_idx_r.getIdx())},
std::make_shared<load_emitter_context>(Precision::I32, Precision::FP32, elt_num),
{}, {load_pool_gpr_idxs});
load_i32_f32(reg_tmp, vmm_idx_r, elt_num);
sub(reg_tmp_64, reg_block_sort_stride_byte);
}
L(topk_load_jmp_label);
@ -1576,16 +1563,13 @@ private:
// store l
mov(reg_tmp, reg_tmp_64);
add(reg_tmp, reg_dst);
store_emitter->emit_code({static_cast<size_t>(vmm_val_l.getIdx())}, {static_cast<size_t>(reg_tmp.getIdx())},
std::make_shared<store_emitter_context>(Precision::FP32, jcp_.precision, elt_num),
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
store(vmm_val_l, reg_tmp, elt_num);
reg_shl(reg_tmp_64, sizeof(int) / jcp_.data_size);
mov(reg_tmp, reg_tmp_64);
add(reg_tmp, reg_dst_idx);
reg_shr(reg_tmp_64, sizeof(int) / jcp_.data_size);
store_emitter->emit_code({static_cast<size_t>(vmm_idx_l.getIdx())}, {static_cast<size_t>(reg_tmp.getIdx())},
std::make_shared<store_emitter_context>(Precision::FP32, Precision::I32, elt_num),
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
store_f32_i32(vmm_idx_l, reg_tmp, elt_num);
// store r
Xbyak::Label topk_store_jmp_label;
@ -1595,16 +1579,13 @@ private:
add(reg_tmp_64, reg_block_sort_stride_byte);
mov(reg_tmp, reg_tmp_64);
add(reg_tmp, reg_dst);
store_emitter->emit_code({static_cast<size_t>(vmm_val_r.getIdx())}, {static_cast<size_t>(reg_tmp.getIdx())},
std::make_shared<store_emitter_context>(Precision::FP32, jcp_.precision, elt_num),
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
store(vmm_val_r, reg_tmp, elt_num);
reg_shl(reg_tmp_64, sizeof(int) / jcp_.data_size);
mov(reg_tmp, reg_tmp_64);
add(reg_tmp, reg_dst_idx);
reg_shr(reg_tmp_64, sizeof(int) / jcp_.data_size);
store_emitter->emit_code({static_cast<size_t>(vmm_idx_r.getIdx())}, {static_cast<size_t>(reg_tmp.getIdx())},
std::make_shared<store_emitter_context>(Precision::FP32, Precision::I32, elt_num),
{store_pool_vec_idxs}, {store_pool_gpr_idxs});
store_f32_i32(vmm_idx_r, reg_tmp, elt_num);
}
L(topk_store_jmp_label);
}
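On the index side of each load/store pair, reg_shl and reg_shr convert the running byte offset between the two element widths: values occupy jcp_.data_size bytes each while indices are stored as int32. Assuming reg_shl scales its register by the given ratio, with bf16 values (data_size == 2) the factor sizeof(int) / jcp_.data_size is 2, so a 64-byte offset into reg_dst addresses the matching entry at a 128-byte offset into reg_dst_idx, and reg_shr restores the value-side offset afterwards.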


@ -133,6 +133,11 @@ InferenceEngine::Precision type2precision<uint8_t>() {
return InferenceEngine::Precision::U8;
}
template<>
InferenceEngine::Precision type2precision<int8_t>() {
return InferenceEngine::Precision::I8;
}
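For context, a sketch of the primary template these specializations complete; the declaration is assumed, and only the explicit specializations such as U8 and the new I8 one are meant to be instantiated:

// Assumed shape of the primary template: a compile-time mapping from a C++
// element type to an InferenceEngine::Precision, defined only through
// explicit specializations.
template<typename T>
InferenceEngine::Precision type2precision();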
cpu_isa_t get_current_isa() {
if (mayiuse(cpu_isa_t::avx512_core))
return cpu_isa_t::avx512_core;
@ -212,9 +217,7 @@ const void * consts_table::store(const void *data, size_t size) {
} // namespace internal
jit_kernel::jit_kernel()
: jit_generator()
, _load_emitter(this, internal::get_current_isa())
, _store_emitter(this, internal::get_current_isa()) {
: jit_generator() {
_free_rmmregs.reserve(16);
_free_x64regs.reserve(16);
@ -297,10 +300,10 @@ void jit_kernel::free<Zmm>(const Zmm & reg) {
void jit_kernel::postamble() {
jit_generator::postamble();
if (_is_load_emitter_used)
_load_emitter.emit_data();
if (_is_store_emitter_used)
_store_emitter.emit_data();
for (const auto& emitter : _emitters) {
if (emitter.second)
emitter.second->emit_data();
}
}
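The loop replaces the two _is_load/_is_store_emitter_used flags and scales to any number of cached emitters; the call order matters because each emitter's constant table must land after the generated code. The same function restated with explanatory comments (the comments are editorial, the behavior is as shown above):

void jit_kernel::postamble() {
    jit_generator::postamble();          // emits the final ret; executable code ends here
    for (const auto& emitter : _emitters) {
        if (emitter.second)
            emitter.second->emit_data(); // append this emitter's constant table after the code
    }
}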
const AddressFrame & jit_kernel::address_frame(size_t size) const {


@ -697,11 +697,8 @@ struct jit_kernel : public dnnl::impl::cpu::x64::jit_generator {
private:
reg_indices _free_x64regs;
reg_indices _free_rmmregs;
bool _is_load_emitter_used = false;
bool _is_store_emitter_used = false;
jit_load_emitter _load_emitter;
jit_store_emitter _store_emitter;
internal::consts_table _consts;
std::unordered_map<size_t, std::unique_ptr<jit_emitter>> _emitters;
};
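The single _emitters map is the commit's central data structure: emitters are created on first use and keyed by the hash of their parameters (see load_emitter_params::hash and store_emitter_params::hash in the emitter sources). A minimal standalone sketch of the pattern with simplified types, everything here being illustrative:

#include <cstddef>
#include <memory>
#include <unordered_map>
#include <utility>

struct emitter {                      // stand-in for jit_emitter
    virtual ~emitter() = default;
    virtual void emit_data() {}
};

struct emitter_cache {
    std::unordered_map<size_t, std::unique_ptr<emitter>> emitters;

    // Create the emitter on the first request for this key, then reuse it.
    // Note the map is keyed on the hash value itself, so two parameter sets
    // that happened to collide would silently share one emitter.
    template <typename T, typename... Args>
    emitter& get(size_t key, Args&&... args) {
        auto& slot = emitters[key];
        if (!slot)
            slot = std::make_unique<T>(std::forward<Args>(args)...);
        return *slot;
    }
};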
template<typename T>
@ -746,17 +743,18 @@ void jit_kernel::load(const variable<DstT[N]> & dst, const variable<SrcT> & src,
const std::vector<size_t> pool_vec_idxs(_free_rmmregs.begin(), _free_rmmregs.end());
const std::vector<size_t> pool_gpr_idxs(_free_x64regs.begin(), _free_x64regs.end());
_load_emitter.emit_code(
const auto src_prc = internal::type2precision<src_type>();
const auto dst_prc = internal::type2precision<dst_type>();
const auto key = load_emitter_params(src_prc, dst_prc, length).hash();
if (!_emitters[key]) {
_emitters[key].reset(new jit_load_emitter(this, internal::get_current_isa(), src_prc, dst_prc, length));
}
_emitters[key]->emit_code(
{ static_cast<size_t>(static_cast<const Xbyak::Operand&>(src).getIdx()) },
{ static_cast<size_t>(static_cast<const Xbyak::Operand&>(dst).getIdx()) },
std::make_shared<load_emitter_context>(
internal::type2precision<src_type>(),
internal::type2precision<dst_type>(),
static_cast<int>(length)),
pool_vec_idxs,
pool_gpr_idxs);
_is_load_emitter_used = true;
}
template<typename DstT, size_t N, typename SrcT>
@ -788,17 +786,18 @@ void jit_kernel::store(const variable<DstT> & dst, const variable<SrcT[N]> & src
const std::vector<size_t> pool_vec_idxs(_free_rmmregs.begin(), _free_rmmregs.end());
const std::vector<size_t> pool_gpr_idxs(_free_x64regs.begin(), _free_x64regs.end());
_store_emitter.emit_code(
const auto src_prc = internal::type2precision<src_type>();
const auto dst_prc = internal::type2precision<dst_type>();
const auto key = store_emitter_params(src_prc, dst_prc, length).hash();
if (!_emitters[key]) {
_emitters[key].reset(new jit_store_emitter(this, internal::get_current_isa(), src_prc, dst_prc, length));
}
_emitters[key]->emit_code(
{ static_cast<size_t>(static_cast<const Xbyak::Operand&>(src).getIdx()) },
{ static_cast<size_t>(static_cast<const Xbyak::Operand&>(dst).getIdx()) },
std::make_shared<store_emitter_context>(
internal::type2precision<src_type>(),
internal::type2precision<dst_type>(),
static_cast<int>(length)),
pool_vec_idxs,
pool_gpr_idxs);
_is_store_emitter_used = true;
}
template<typename DstT, typename SrcT, size_t N>


@ -318,6 +318,7 @@ private:
};
TEST(JitKernel, variable_load_and_store) {
{
jit_variable_load_store_test_kernel<uint8_t, float> kernel;
if (mayiuse(cpu_isa_t::avx512_core)) {
kernel.test<16>();
@ -328,6 +329,20 @@ TEST(JitKernel, variable_load_and_store) {
if (mayiuse(cpu_isa_t::sse41)) {
kernel.test<4>();
}
}
{
jit_variable_load_store_test_kernel<int8_t, int8_t> kernel;
if (mayiuse(cpu_isa_t::avx512_core)) {
kernel.test<16>();
}
if (mayiuse(cpu_isa_t::avx2)) {
kernel.test<8>();
}
if (mayiuse(cpu_isa_t::sse41)) {
kernel.test<4>();
}
}
}
} // namespace
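The new int8-to-int8 block exercises the fresh I8 specialization end to end. Extending coverage to another precision pair follows the same shape; a hypothetical addition, not part of this commit:

{
    jit_variable_load_store_test_kernel<int32_t, float> kernel;
    if (mayiuse(cpu_isa_t::avx512_core)) {
        kernel.test<16>();
    }
    if (mayiuse(cpu_isa_t::avx2)) {
        kernel.test<8>();
    }
    if (mayiuse(cpu_isa_t::sse41)) {
        kernel.test<4>();
    }
}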