[CPU] Support topk sort for int32 directly (#13448)
This commit is contained in:
@@ -98,6 +98,7 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generato
|
||||
mov(reg_table, l_table);
|
||||
|
||||
data_type = DnnlExtensionUtils::IEPrecisionToDataType(jcp_.precision);
|
||||
precision_in_reg = isFloatCompatible(data_type) ? Precision::FP32 : Precision::I32;
|
||||
if (!shape_agnostic_alg && jcp_.layout == TopKLayoutType::topk_blocked && jcp_.topk_innermost)
|
||||
blk_stride = jcp_.sort_stride * jcp_.blk_size;
|
||||
|
||||
@@ -131,6 +132,7 @@ private:
|
||||
Xbyak::Ymm, Xbyak::Zmm>::type;
|
||||
size_t vlen = cpu_isa_traits<isa>::vlen;
|
||||
dnnl::memory::data_type data_type;
|
||||
Precision precision_in_reg;
|
||||
|
||||
Xbyak::Address table_val(int index) { return ptr[reg_table + index * vlen]; }
|
||||
Xbyak::Address table_bubble_block_idx(int index) { return ptr[reg_bubble_block_idx + index * vlen]; }
|
||||
@@ -225,11 +227,7 @@ private:
|
||||
}
|
||||
|
||||
inline void load(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num, const int offset = 0) {
|
||||
emit_load(reg_src, vmm_src, jcp_.precision, Precision::FP32, elt_num, offset);
|
||||
}
|
||||
|
||||
inline void load_i32_f32(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num, const int offset = 0) {
|
||||
emit_load(reg_src, vmm_src, Precision::I32, Precision::FP32, elt_num, offset);
|
||||
emit_load(reg_src, vmm_src, jcp_.precision, precision_in_reg, elt_num, offset);
|
||||
}
|
||||
|
||||
inline void load_i32(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num, const int offset = 0) {
|
||||
@@ -237,11 +235,7 @@ private:
|
||||
}
|
||||
|
||||
inline void store(Vmm vmm_dst, Xbyak::Reg64 reg_dst, const int elt_num, const int offset = 0) {
|
||||
emit_store(vmm_dst, reg_dst, Precision::FP32, jcp_.precision, elt_num, offset);
|
||||
}
|
||||
|
||||
inline void store_f32_i32(Vmm vmm_dst, Xbyak::Reg64 reg_dst, const int elt_num, const int offset = 0) {
|
||||
emit_store(vmm_dst, reg_dst, Precision::FP32, Precision::I32, elt_num, offset);
|
||||
emit_store(vmm_dst, reg_dst, precision_in_reg, jcp_.precision, elt_num, offset);
|
||||
}
|
||||
|
||||
inline void store_i32(Vmm vmm_dst, Xbyak::Reg64 reg_dst, const int elt_num, const int offset = 0) {
|
||||
@@ -443,7 +437,7 @@ private:
|
||||
store_scalar(ptr[reg_prc + (i * jcp_.sort_stride + j) * jcp_.data_size], xmm_tmp, data_type);
|
||||
|
||||
uni_vmovdqu(xmm_tmp, table_val(i));
|
||||
store_scalar(ptr[reg_prc_idx + (i * jcp_.sort_stride + j) * sizeof(int)], xmm_tmp, memory::data_type::s32, false);
|
||||
store_scalar(ptr[reg_prc_idx + (i * jcp_.sort_stride + j) * sizeof(int)], xmm_tmp, memory::data_type::s32);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -476,10 +470,10 @@ private:
|
||||
load(reg_aux_idx, vmm_val_r, elt_num);
|
||||
|
||||
bitonic_get_addr(reg_prc_idx, sizeof(int), 0);
|
||||
load_i32_f32(reg_aux_idx, vmm_idx_l, elt_num);
|
||||
load_i32(reg_aux_idx, vmm_idx_l, elt_num);
|
||||
|
||||
bitonic_get_addr(reg_prc_idx, sizeof(int), sizeof(int));
|
||||
load_i32_f32(reg_aux_idx, vmm_idx_r, elt_num);
|
||||
load_i32(reg_aux_idx, vmm_idx_r, elt_num);
|
||||
|
||||
swap_vector(vmm_val_l, vmm_idx_l, vmm_val_r, vmm_idx_r, cmp_val);
|
||||
|
||||
@@ -490,10 +484,10 @@ private:
|
||||
store(vmm_val_r, reg_aux_idx, elt_num);
|
||||
|
||||
bitonic_get_addr(reg_prc_idx, sizeof(int), 0);
|
||||
store_f32_i32(vmm_idx_l, reg_aux_idx, elt_num);
|
||||
store_i32(vmm_idx_l, reg_aux_idx, elt_num);
|
||||
|
||||
bitonic_get_addr(reg_prc_idx, sizeof(int), sizeof(int));
|
||||
store_f32_i32(vmm_idx_r, reg_aux_idx, elt_num);
|
||||
store_i32(vmm_idx_r, reg_aux_idx, elt_num);
|
||||
}
|
||||
|
||||
inline void topk_heap_sorting() {
|
||||
@@ -788,20 +782,63 @@ private:
|
||||
add(rsp, sizeof(int));
|
||||
}
|
||||
|
||||
inline void heap_cmp_node(Xmm xmm_val_a, Xmm xmm_idx_a, Xmm xmm_val_b, Xmm xmm_idx_b, bool cmp_val = true) {
|
||||
if (isa == cpu::x64::avx512_core) {
|
||||
if (cmp_val)
|
||||
vcmpps(k_mask, xmm_val_a, xmm_val_b, heap_cmp_flg);
|
||||
else
|
||||
vcmpps(k_mask, xmm_idx_a, xmm_idx_b, _cmp_lt_os);
|
||||
// Tells whether instructions of `cpu_isa` may be emitted: the ISA has to be
// a subset of isa_all and mayiuse() has to accept it. mayiuse() is only
// queried when the subset check passes.
inline bool is_valid_isa(cpu_isa_t cpu_isa) {
    if (!cpu::x64::is_subset(cpu_isa, isa_all)) {
        return false;
    }
    return mayiuse(cpu_isa);
}
|
||||
|
||||
inline void uni_vpcmpgtd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
|
||||
const Xbyak::Operand &op) {
|
||||
if (is_valid_isa(cpu::x64::avx)) {
|
||||
vpcmpgtd(x1, x2, op);
|
||||
} else {
|
||||
if (cmp_val)
|
||||
uni_vcmpps(xmm_mask, xmm_val_a, xmm_val_b, heap_cmp_flg);
|
||||
else
|
||||
uni_vcmpps(xmm_mask, xmm_idx_a, xmm_idx_b, _cmp_lt_os);
|
||||
if (x1.getIdx() != x2.getIdx())
|
||||
uni_vmovups(x1, x2);
|
||||
pcmpgtd(x1, op);
|
||||
}
|
||||
}
|
||||
|
||||
// Ymm overload of the packed 32-bit integer "greater than" compare.
// Always emits the VEX-encoded three-operand form `vpcmpgtd x1, x2, op`
// (x1 receives the per-lane result of x2 > op); unlike the Xmm overload
// there is no legacy-SSE fallback here.
inline void uni_vpcmpgtd(const Xbyak::Ymm &x1,
                         const Xbyak::Ymm &x2,
                         const Xbyak::Operand &op) {
    vpcmpgtd(x1, x2, op);
}
|
||||
|
||||
// Emits the comparison of two (value, index) nodes that drives the swap
// decision of the sorting routines.
//
//   cmp_val == true  : the values are compared with predicate val_cmp_flg,
//                      as floats when data_type is float-compatible and as
//                      signed 32-bit integers otherwise.
//   cmp_val == false : the indices are compared as 32-bit integers with
//                      predicate idx_cmp_flg.
//
// On avx512_core the result is written to k_mask and `mask` is not used.
// On other ISAs the result is written to `mask`. The integer path there
// only has a "greater than" instruction: _cmp_nle_us emits a > b, and every
// other predicate is emitted as b > a (i.e. a < b).
inline void compare_node_xmm(Xmm xmm_val_a, Xmm xmm_idx_a, Xmm xmm_val_b, Xmm xmm_idx_b, Xmm mask,
                             unsigned char val_cmp_flg, unsigned char idx_cmp_flg, bool cmp_val) {
    // A floating-point compare is needed only when values (not indices) are compared.
    const bool use_float_cmp = cmp_val && isFloatCompatible(data_type);

    // Pick the operand pair and the predicate once, then dispatch on the ISA.
    const Xmm &lhs = cmp_val ? xmm_val_a : xmm_idx_a;
    const Xmm &rhs = cmp_val ? xmm_val_b : xmm_idx_b;
    const unsigned char predicate = cmp_val ? val_cmp_flg : idx_cmp_flg;

    if (isa == cpu::x64::avx512_core) {
        if (use_float_cmp) {
            vcmpps(k_mask, lhs, rhs, predicate);
        } else {
            vpcmpd(k_mask, lhs, rhs, predicate);
        }
        return;
    }

    if (use_float_cmp) {
        uni_vcmpps(mask, lhs, rhs, predicate);
    } else if (predicate == _cmp_nle_us) {
        uni_vpcmpgtd(mask, lhs, rhs);
    } else {
        // "less than" is expressed as "greater than" with swapped operands.
        uni_vpcmpgtd(mask, rhs, lhs);
    }
}
|
||||
|
||||
// Comparison step of the heap sort: values are compared with heap_cmp_flg
// when cmp_val is true, otherwise the indices are compared with _cmp_lt_os.
// The result lands in k_mask on avx512_core and in xmm_mask on other ISAs
// (see compare_node_xmm).
inline void heap_cmp_node(Xmm xmm_val_a, Xmm xmm_idx_a, Xmm xmm_val_b, Xmm xmm_idx_b, bool cmp_val = true) {
    compare_node_xmm(xmm_val_a, xmm_idx_a,
                     xmm_val_b, xmm_idx_b,
                     xmm_mask,
                     heap_cmp_flg, _cmp_lt_os,
                     cmp_val);
}
|
||||
|
||||
// n: node, c: child
|
||||
inline void heap_swap_node(Xmm xmm_val_n, Xmm xmm_idx_n, Xmm xmm_val_c, Xmm xmm_idx_c) {
|
||||
// swap store
|
||||
@@ -1070,7 +1107,6 @@ private:
|
||||
load(reg_tmp, vmm_val_r, elt_num);
|
||||
|
||||
table_to_vmm(vmm_idx_r, reg_bubble_block_idx, reg_i, 0, vlen);
|
||||
uni_vcvtdq2ps(vmm_idx_r, vmm_idx_r);
|
||||
|
||||
sub(rsp, sizeof(int64_t));
|
||||
mov(ptr[rsp], reg_bubble_block_idx);
|
||||
@@ -1162,7 +1198,6 @@ private:
|
||||
for (int i = 0; i < jcp_.top_k; i++) {
|
||||
load(reg_src, vmm_val(i), elt_num, i * jcp_.sort_stride * jcp_.data_size);
|
||||
uni_vmovdqu(vmm_idx(i), table_val(i));
|
||||
uni_vcvtdq2ps(vmm_idx(i), vmm_idx(i));
|
||||
}
|
||||
// sort
|
||||
for (int i = 0; i < jcp_.top_k - 1; i++) {
|
||||
@@ -1173,7 +1208,6 @@ private:
|
||||
for (int i = jcp_.top_k; i < jcp_.axis_dim; i++) {
|
||||
load(reg_src, vmm_val(jcp_.top_k), elt_num, i * jcp_.sort_stride * jcp_.data_size);
|
||||
uni_vmovdqu(vmm_idx(jcp_.top_k), table_val(i));
|
||||
uni_vcvtdq2ps(vmm_idx(jcp_.top_k), vmm_idx(jcp_.top_k));
|
||||
for (int j = jcp_.top_k; j > 0; j--) {
|
||||
swap_vector(vmm_val(j - 1), vmm_idx(j - 1), vmm_val(j), vmm_idx(j));
|
||||
}
|
||||
@@ -1188,7 +1222,7 @@ private:
|
||||
// store
|
||||
for (int i = 0; i < jcp_.top_k; i++) {
|
||||
store(vmm_val(i), reg_dst, elt_num, i * jcp_.sort_stride * jcp_.data_size);
|
||||
store_f32_i32(vmm_idx(i), reg_dst_idx, elt_num, i * jcp_.sort_stride * sizeof(int));
|
||||
store_i32(vmm_idx(i), reg_dst_idx, elt_num, i * jcp_.sort_stride * sizeof(int));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1206,18 +1240,15 @@ private:
|
||||
|
||||
load_scalar(xmm_val(0), ptr[reg_src], data_type);
|
||||
uni_vmovss(xmm_idx(0), table_bubble_seq_idx(0));
|
||||
uni_vcvtdq2ps(xmm_idx(0), xmm_idx(0));
|
||||
jmp(topk_load_sort_end_label, T_NEAR);
|
||||
|
||||
L(topk_load_sort_label);
|
||||
{
|
||||
load(reg_src, vmm_val(0), vector_step, 0);
|
||||
uni_vmovdqu(vmm_idx(0), table_bubble_seq_idx(0));
|
||||
uni_vcvtdq2ps(vmm_idx(0), vmm_idx(0));
|
||||
if (isa == cpu::x64::sse41) {
|
||||
load(reg_src, vmm_val(1), vector_step, 4 * jcp_.data_size);
|
||||
uni_vmovdqu(vmm_idx(1), table_bubble_seq_idx(4));
|
||||
uni_vcvtdq2ps(vmm_idx(1), vmm_idx(1));
|
||||
swap_vector(vmm_val(0), vmm_idx(0), vmm_val(1), vmm_idx(1));
|
||||
}
|
||||
|
||||
@@ -1233,13 +1264,11 @@ private:
|
||||
get_addr_by_reg_idx(reg_aux, reg_src, reg_i, jcp_.data_size, reg_seq_sort_stride);
|
||||
load(reg_aux, vmm_val(1), vector_step);
|
||||
table_to_vmm(vmm_idx(1), reg_bubble_seq_idx, reg_i, 0, sizeof(int));
|
||||
uni_vcvtdq2ps(vmm_idx(1), vmm_idx(1));
|
||||
swap_vector(vmm_val(0), vmm_idx(0), vmm_val(1), vmm_idx(1));
|
||||
if (isa == cpu::x64::sse41) {
|
||||
add(reg_aux, 4 * jcp_.data_size);
|
||||
load(reg_aux, vmm_val(1), vector_step);
|
||||
table_to_vmm(vmm_idx(1), reg_bubble_seq_idx, reg_i, 4, sizeof(int));
|
||||
uni_vcvtdq2ps(vmm_idx(1), vmm_idx(1));
|
||||
swap_vector(vmm_val(0), vmm_idx(0), vmm_val(1), vmm_idx(1));
|
||||
}
|
||||
|
||||
@@ -1263,7 +1292,6 @@ private:
|
||||
|
||||
load_scalar(xmm_val(1), ptr[reg_aux], data_type);
|
||||
table_to_xmm(xmm_idx(1), reg_bubble_seq_idx, reg_i, 0, sizeof(int));
|
||||
uni_vcvtdq2ps(xmm_idx(1), xmm_idx(1));
|
||||
bubble_swap_xmm(xmm_val(0), xmm_idx(0), xmm_val(1), xmm_idx(1));
|
||||
|
||||
add(reg_i, 1);
|
||||
@@ -1359,7 +1387,7 @@ private:
|
||||
|
||||
table_to_xmm(xmm_tmp, reg_bubble_seq_idx, reg_i, 0, sizeof(int));
|
||||
get_addr_by_reg_idx(reg_tmp, reg_dst_idx, reg_aux, sizeof(int) / jcp_.data_size);
|
||||
store_scalar(ptr[reg_tmp], xmm_tmp, memory::data_type::s32, false);
|
||||
store_scalar(ptr[reg_tmp], xmm_tmp, memory::data_type::s32);
|
||||
|
||||
add(reg_sub_idx, 1);
|
||||
cmp(reg_sub_idx, jcp_.blk_size);
|
||||
@@ -1396,7 +1424,6 @@ private:
|
||||
|
||||
load_scalar(xmm_val_r, ptr[reg_aux], data_type);
|
||||
table_to_xmm(xmm_idx_r, reg_bubble_seq_idx, reg_i, 0, sizeof(int));
|
||||
uni_vcvtdq2ps(xmm_idx_r, xmm_idx_r);
|
||||
|
||||
sub(rsp, sizeof(int));
|
||||
mov(ptr[rsp], reg_prc.cvt32());
|
||||
@@ -1491,7 +1518,6 @@ private:
|
||||
int offset = i / jcp_.blk_size * blk_stride + i % jcp_.blk_size;
|
||||
load_scalar(xmm_val(i), ptr[reg_src + offset * jcp_.data_size], data_type);
|
||||
uni_vmovdqu(xmm_idx(i), table_val(i));
|
||||
uni_vcvtdq2ps(xmm_idx(i), xmm_idx(i));
|
||||
}
|
||||
// sort
|
||||
for (int i = 0; i < jcp_.top_k - 1; i++) {
|
||||
@@ -1503,7 +1529,6 @@ private:
|
||||
int offset = i / jcp_.blk_size * blk_stride + i % jcp_.blk_size;
|
||||
load_scalar(xmm_val(jcp_.top_k), ptr[reg_src + offset * jcp_.data_size], data_type);
|
||||
uni_vmovdqu(xmm_idx(jcp_.top_k), table_val(i));
|
||||
uni_vcvtdq2ps(xmm_idx(jcp_.top_k), xmm_idx(jcp_.top_k));
|
||||
for (int j = jcp_.top_k; j > 0; j--) {
|
||||
bubble_swap_xmm(xmm_val(j - 1), xmm_idx(j - 1), xmm_val(j), xmm_idx(j));
|
||||
}
|
||||
@@ -1536,7 +1561,7 @@ private:
|
||||
mov(reg_tmp, reg_tmp_64);
|
||||
add(reg_tmp, reg_dst_idx);
|
||||
reg_shr(reg_tmp_64, sizeof(int) / jcp_.data_size);
|
||||
load_i32_f32(reg_tmp, vmm_idx_l, elt_num);
|
||||
load_i32(reg_tmp, vmm_idx_l, elt_num);
|
||||
|
||||
// load r
|
||||
Xbyak::Label topk_load_jmp_label;
|
||||
@@ -1552,7 +1577,7 @@ private:
|
||||
mov(reg_tmp, reg_tmp_64);
|
||||
add(reg_tmp, reg_dst_idx);
|
||||
reg_shr(reg_tmp_64, sizeof(int) / jcp_.data_size);
|
||||
load_i32_f32(reg_tmp, vmm_idx_r, elt_num);
|
||||
load_i32(reg_tmp, vmm_idx_r, elt_num);
|
||||
|
||||
sub(reg_tmp_64, reg_block_sort_stride_byte);
|
||||
}
|
||||
@@ -1569,7 +1594,7 @@ private:
|
||||
mov(reg_tmp, reg_tmp_64);
|
||||
add(reg_tmp, reg_dst_idx);
|
||||
reg_shr(reg_tmp_64, sizeof(int) / jcp_.data_size);
|
||||
store_f32_i32(vmm_idx_l, reg_tmp, elt_num);
|
||||
store_i32(vmm_idx_l, reg_tmp, elt_num);
|
||||
|
||||
// store r
|
||||
Xbyak::Label topk_store_jmp_label;
|
||||
@@ -1585,18 +1610,15 @@ private:
|
||||
mov(reg_tmp, reg_tmp_64);
|
||||
add(reg_tmp, reg_dst_idx);
|
||||
reg_shr(reg_tmp_64, sizeof(int) / jcp_.data_size);
|
||||
store_f32_i32(vmm_idx_r, reg_tmp, elt_num);
|
||||
store_i32(vmm_idx_r, reg_tmp, elt_num);
|
||||
}
|
||||
L(topk_store_jmp_label);
|
||||
}
|
||||
|
||||
inline void swap_vector(Vmm vmm_val_a, Vmm vmm_idx_a, Vmm vmm_val_b, Vmm vmm_idx_b, bool cmp_val = true) {
|
||||
if (isa == cpu::x64::avx512_core) {
|
||||
if (cmp_val)
|
||||
vcmpps(k_mask, vmm_val_a, vmm_val_b, cmp_flg);
|
||||
else
|
||||
vcmpps(k_mask, vmm_idx_a, vmm_idx_b, _cmp_nle_us);
|
||||
compare_node_xmm(vmm_val_a, vmm_idx_a, vmm_val_b, vmm_idx_b, vmm_mask, cmp_flg, _cmp_nle_us, cmp_val);
|
||||
|
||||
if (isa == cpu::x64::avx512_core) {
|
||||
uni_vmovups(vmm_tmp, vmm_val_a);
|
||||
vblendmps(vmm_val_a | k_mask, vmm_val_a, vmm_val_b);
|
||||
vblendmps(vmm_val_b | k_mask, vmm_val_b, vmm_tmp);
|
||||
@@ -1605,11 +1627,6 @@ private:
|
||||
vblendmps(vmm_idx_a | k_mask, vmm_idx_a, vmm_idx_b);
|
||||
vblendmps(vmm_idx_b | k_mask, vmm_idx_b, vmm_tmp);
|
||||
} else {
|
||||
if (cmp_val)
|
||||
uni_vcmpps(vmm_mask, vmm_val_a, vmm_val_b, cmp_flg);
|
||||
else
|
||||
uni_vcmpps(vmm_mask, vmm_idx_a, vmm_idx_b, _cmp_nle_us);
|
||||
|
||||
uni_vmovups(vmm_tmp, vmm_val_a);
|
||||
uni_vblendvps(vmm_val_a, vmm_val_a, vmm_val_b, vmm_mask);
|
||||
uni_vblendvps(vmm_val_b, vmm_val_b, vmm_tmp, vmm_mask);
|
||||
@@ -1675,12 +1692,9 @@ private:
|
||||
}
|
||||
|
||||
inline void bubble_swap_xmm(Xmm xmm_val_a, Xmm xmm_idx_a, Xmm xmm_val_b, Xmm xmm_idx_b, bool cmp_val = true) {
|
||||
if (isa == cpu::x64::avx512_core) {
|
||||
if (cmp_val)
|
||||
vcmpps(k_mask, xmm_val_a, xmm_val_b, cmp_flg);
|
||||
else
|
||||
vcmpps(k_mask, xmm_idx_a, xmm_idx_b, _cmp_nle_us);
|
||||
compare_node_xmm(xmm_val_a, xmm_idx_a, xmm_val_b, xmm_idx_b, xmm_mask, cmp_flg, _cmp_nle_us, cmp_val);
|
||||
|
||||
if (isa == cpu::x64::avx512_core) {
|
||||
uni_vmovups(xmm_tmp, xmm_val_a);
|
||||
vblendmps(xmm_val_a | k_mask, xmm_val_a, xmm_val_b);
|
||||
vblendmps(xmm_val_b | k_mask, xmm_val_b, xmm_tmp);
|
||||
@@ -1689,11 +1703,6 @@ private:
|
||||
vblendmps(xmm_idx_a | k_mask, xmm_idx_a, xmm_idx_b);
|
||||
vblendmps(xmm_idx_b | k_mask, xmm_idx_b, xmm_tmp);
|
||||
} else {
|
||||
if (cmp_val)
|
||||
uni_vcmpps(xmm_mask, xmm_val_a, xmm_val_b, cmp_flg);
|
||||
else
|
||||
uni_vcmpps(xmm_mask, xmm_idx_a, xmm_idx_b, _cmp_nle_us);
|
||||
|
||||
uni_vmovups(xmm_tmp, xmm_val_a);
|
||||
uni_vblendvps(xmm_val_a, xmm_val_a, xmm_val_b, xmm_mask);
|
||||
uni_vblendvps(xmm_val_b, xmm_val_b, xmm_tmp, xmm_mask);
|
||||
@@ -1704,7 +1713,7 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
inline void load_scalar(Xmm xmm_src, const Xbyak::Address &op, memory::data_type src_dt, bool cvt_dt = true) {
|
||||
inline void load_scalar(Xmm xmm_src, const Xbyak::Address &op, memory::data_type src_dt, bool cvt_dt = false) {
|
||||
switch (src_dt) {
|
||||
case memory::data_type::f32:
|
||||
case memory::data_type::s32:
|
||||
@@ -1731,7 +1740,7 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
inline void store_scalar(const Xbyak::Address &op, Xmm xmm_dst, memory::data_type dst_dt, bool cvt_dt = true) {
|
||||
inline void store_scalar(const Xbyak::Address &op, Xmm xmm_dst, memory::data_type dst_dt, bool cvt_dt = false) {
|
||||
if (cvt_dt && !isFloatCompatible(dst_dt)) {
|
||||
uni_vcvtps2dq(xmm_dst, xmm_dst);
|
||||
}
|
||||
|
||||
@@ -148,16 +148,25 @@ protected:
|
||||
tensor = ov::test::utils::create_and_fill_tensor(funcInputs[0].get_element_type(), shape);
|
||||
size_t size = tensor.get_size();
|
||||
|
||||
if (netPrecision == ElementType::f32) {
|
||||
if (netPrecision == ElementType::f32 || netPrecision == ElementType::i32) {
|
||||
std::vector<int> data(size);
|
||||
int start = - static_cast<int>(size / 2);
|
||||
|
||||
// For int32, deliberately set big numbers which are not accurately representable in fp32
|
||||
int start = netPrecision == ElementType::i32 ? pow(2, 30) + 1 : - static_cast<int>(size / 2);
|
||||
std::iota(data.begin(), data.end(), start);
|
||||
std::mt19937 gen(0);
|
||||
std::shuffle(data.begin(), data.end(), gen);
|
||||
|
||||
auto *rawBlobDataPtr = static_cast<float *>(tensor.data());
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
rawBlobDataPtr[i] = static_cast<float>(data[i]);
|
||||
if (netPrecision == ElementType::f32) {
|
||||
auto *rawBlobDataPtr = static_cast<float *>(tensor.data());
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
rawBlobDataPtr[i] = static_cast<float>(data[i]);
|
||||
}
|
||||
} else {
|
||||
auto *rawBlobDataPtr = static_cast<int32_t *>(tensor.data());
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
rawBlobDataPtr[i] = static_cast<int32_t>(data[i]);
|
||||
}
|
||||
}
|
||||
} else if (netPrecision == ElementType::bf16) {
|
||||
size_t O = 1, A = 1, I = 1;
|
||||
@@ -283,6 +292,46 @@ INSTANTIATE_TEST_CASE_P(smoke_TopK_dynamic, TopKLayerCPUTest,
|
||||
::testing::ValuesIn(additionalConfig)),
|
||||
TopKLayerCPUTest::getTestCaseName);
|
||||
|
||||
// K values used by the static-shape int32 instantiation (smoke_TopK_int32).
// With the 9-element axes of inputShapes_int32 this covers k == 1, two
// intermediate values and k == axis size.
const std::vector<int64_t> k_int32 = {1, 5, 7, 9};
|
||||
|
||||
// Single 9x9x9x9 shape for the int32 tests; the first member is left empty.
// NOTE(review): presumably an empty first member means "static shape" for
// ov::test::InputShape — confirm against its definition.
std::vector<ov::test::InputShape> inputShapes_int32 = {
|
||||
{{}, {{9, 9, 9, 9}}},
|
||||
};
|
||||
|
||||
// Shape set for the dynamic int32 tests: dimensions 1 and 3 are given as the
// range {5, 10}, dimensions 0 and 2 are fixed to 9, and two concrete shapes
// (9x9x9x9 and 9x10x9x10) are listed to run with.
// NOTE(review): assumes {5, 10} denotes an inclusive [min, max] dimension
// interval — confirm against ov::test::InputShape.
std::vector<ov::test::InputShape> inputShapesDynamic_int32 = {
|
||||
{{9, {5, 10}, 9, {5, 10}}, {{9, 9, 9, 9}, {9, 10, 9, 10}}}
|
||||
};
|
||||
|
||||
// Static-shape TopK tests with ElementType::i32 data. Every k in k_int32 is
// combined with all entries of axes, modes and sortTypes on inputShapes_int32.
// Only the first entry of additionalConfig is used.
// NOTE(review): the two ElementType::undefined slots are presumably the
// input/output precision overrides of the test parameter tuple — confirm.
INSTANTIATE_TEST_CASE_P(smoke_TopK_int32, TopKLayerCPUTest,
|
||||
::testing::Combine(
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(k_int32),
|
||||
::testing::ValuesIn(axes),
|
||||
::testing::ValuesIn(modes),
|
||||
::testing::ValuesIn(sortTypes),
|
||||
::testing::Values(ElementType::i32),
|
||||
::testing::Values(ElementType::undefined),
|
||||
::testing::Values(ElementType::undefined),
|
||||
::testing::ValuesIn(inputShapes_int32)),
|
||||
::testing::ValuesIn(filterCPUSpecificParams(cpuParams)),
|
||||
::testing::Values(additionalConfig[0])),
|
||||
TopKLayerCPUTest::getTestCaseName);
|
||||
|
||||
// Dynamic-shape TopK tests with ElementType::i32 data. k is fixed to 1 and
// combined with all entries of axes, modes and sortTypes on
// inputShapesDynamic_int32. Only the first entry of additionalConfig is used.
// NOTE(review): the two ElementType::undefined slots are presumably the
// input/output precision overrides of the test parameter tuple — confirm.
INSTANTIATE_TEST_CASE_P(smoke_TopK_int32_dynamic, TopKLayerCPUTest,
|
||||
::testing::Combine(
|
||||
::testing::Combine(
|
||||
::testing::Values(1),
|
||||
::testing::ValuesIn(axes),
|
||||
::testing::ValuesIn(modes),
|
||||
::testing::ValuesIn(sortTypes),
|
||||
::testing::Values(ElementType::i32),
|
||||
::testing::Values(ElementType::undefined),
|
||||
::testing::Values(ElementType::undefined),
|
||||
::testing::ValuesIn(inputShapesDynamic_int32)),
|
||||
::testing::ValuesIn(filterCPUSpecificParams(cpuParams)),
|
||||
::testing::Values(additionalConfig[0])),
|
||||
TopKLayerCPUTest::getTestCaseName);
|
||||
|
||||
// Single 2x2x2x2 shape with an empty first member.
// NOTE(review): judging by the name, this targets the bubble-sort kernel
// path on a blocked layout with TopK over the channel axis — confirm against
// the instantiation that consumes it (not visible here).
std::vector<ov::test::InputShape> inputShapes_bubble_BLK_on_channel_horiz = {
|
||||
{{}, {{2, 2, 2, 2}}},
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user