[CPU] Support topk sort for int32 directly (#13448)

Chen Xu, 2022-10-19 18:38:33 +08:00 (committed by GitHub)
parent a25c2ba665
commit 98dbb91af6
2 changed files with 128 additions and 70 deletions
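Context for the change, summarized from the diff below: the TopK kernel previously kept both the values and the generated indices as fp32 in registers (the load/store helpers hard-coded Precision::FP32 and the indices went through uni_vcvtdq2ps), and fp32 cannot represent every int32 value beyond 2^24. After this commit the comparisons stay in the integer domain (vpcmpd / pcmpgtd). A minimal standalone C++ snippet, not part of the commit, showing the precision loss the new path avoids:

#include <cstdint>
#include <iostream>

int main() {
    // 2^30 + 1 is exact as int32_t, but fp32 has only a 24-bit significand,
    // so the round trip through float drops the trailing +1.
    const int32_t v = (1 << 30) + 1;
    const int32_t round_tripped = static_cast<int32_t>(static_cast<float>(v));
    std::cout << v << " -> " << round_tripped << std::endl;  // 1073741825 -> 1073741824
    return 0;
}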


@@ -98,6 +98,7 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generato
mov(reg_table, l_table);
data_type = DnnlExtensionUtils::IEPrecisionToDataType(jcp_.precision);
precision_in_reg = isFloatCompatible(data_type) ? Precision::FP32 : Precision::I32;
if (!shape_agnostic_alg && jcp_.layout == TopKLayoutType::topk_blocked && jcp_.topk_innermost)
blk_stride = jcp_.sort_stride * jcp_.blk_size;
@@ -131,6 +132,7 @@ private:
Xbyak::Ymm, Xbyak::Zmm>::type;
size_t vlen = cpu_isa_traits<isa>::vlen;
dnnl::memory::data_type data_type;
Precision precision_in_reg;
Xbyak::Address table_val(int index) { return ptr[reg_table + index * vlen]; }
Xbyak::Address table_bubble_block_idx(int index) { return ptr[reg_bubble_block_idx + index * vlen]; }
@@ -225,11 +227,7 @@ private:
}
inline void load(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num, const int offset = 0) {
emit_load(reg_src, vmm_src, jcp_.precision, Precision::FP32, elt_num, offset);
}
inline void load_i32_f32(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num, const int offset = 0) {
emit_load(reg_src, vmm_src, Precision::I32, Precision::FP32, elt_num, offset);
emit_load(reg_src, vmm_src, jcp_.precision, precision_in_reg, elt_num, offset);
}
inline void load_i32(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num, const int offset = 0) {
@@ -237,11 +235,7 @@ private:
}
inline void store(Vmm vmm_dst, Xbyak::Reg64 reg_dst, const int elt_num, const int offset = 0) {
emit_store(vmm_dst, reg_dst, Precision::FP32, jcp_.precision, elt_num, offset);
}
inline void store_f32_i32(Vmm vmm_dst, Xbyak::Reg64 reg_dst, const int elt_num, const int offset = 0) {
emit_store(vmm_dst, reg_dst, Precision::FP32, Precision::I32, elt_num, offset);
emit_store(vmm_dst, reg_dst, precision_in_reg, jcp_.precision, elt_num, offset);
}
inline void store_i32(Vmm vmm_dst, Xbyak::Reg64 reg_dst, const int elt_num, const int offset = 0) {
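In the two hunks above, load() and store() stop hard-coding Precision::FP32 as the in-register type and route through precision_in_reg instead, so int32 values are sorted as integers end to end. A rough scalar analogue of that routing, with made-up names and only the two precisions the kernel distinguishes (the real emit_load/emit_store are JIT emitters):

#include <cstdint>
#include <cstring>

enum class Precision { FP32, I32 };

// One lane of a vector register; which member is meaningful is decided by
// precision_in_reg, mirroring how the kernel reinterprets Vmm contents.
union Lane {
    float   f32;
    int32_t i32;
};

// Scalar stand-in for emit_load(src, dst, src_prc, reg_prc, ...): a value is
// converted only when the memory precision and register precision differ.
inline Lane load_lane(const void* src, Precision src_prc, Precision reg_prc) {
    Lane lane{};
    if (src_prc == Precision::I32) {
        int32_t v;
        std::memcpy(&v, src, sizeof(v));
        if (reg_prc == Precision::I32)
            lane.i32 = v;                       // new int32 path: no conversion
        else
            lane.f32 = static_cast<float>(v);   // old behaviour: i32 -> f32
    } else {
        std::memcpy(&lane.f32, src, sizeof(float));
    }
    return lane;
}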
@@ -443,7 +437,7 @@ private:
store_scalar(ptr[reg_prc + (i * jcp_.sort_stride + j) * jcp_.data_size], xmm_tmp, data_type);
uni_vmovdqu(xmm_tmp, table_val(i));
store_scalar(ptr[reg_prc_idx + (i * jcp_.sort_stride + j) * sizeof(int)], xmm_tmp, memory::data_type::s32, false);
store_scalar(ptr[reg_prc_idx + (i * jcp_.sort_stride + j) * sizeof(int)], xmm_tmp, memory::data_type::s32);
}
}
}
@@ -476,10 +470,10 @@ private:
load(reg_aux_idx, vmm_val_r, elt_num);
bitonic_get_addr(reg_prc_idx, sizeof(int), 0);
load_i32_f32(reg_aux_idx, vmm_idx_l, elt_num);
load_i32(reg_aux_idx, vmm_idx_l, elt_num);
bitonic_get_addr(reg_prc_idx, sizeof(int), sizeof(int));
load_i32_f32(reg_aux_idx, vmm_idx_r, elt_num);
load_i32(reg_aux_idx, vmm_idx_r, elt_num);
swap_vector(vmm_val_l, vmm_idx_l, vmm_val_r, vmm_idx_r, cmp_val);
@@ -490,10 +484,10 @@ private:
store(vmm_val_r, reg_aux_idx, elt_num);
bitonic_get_addr(reg_prc_idx, sizeof(int), 0);
store_f32_i32(vmm_idx_l, reg_aux_idx, elt_num);
store_i32(vmm_idx_l, reg_aux_idx, elt_num);
bitonic_get_addr(reg_prc_idx, sizeof(int), sizeof(int));
store_f32_i32(vmm_idx_r, reg_aux_idx, elt_num);
store_i32(vmm_idx_r, reg_aux_idx, elt_num);
}
inline void topk_heap_sorting() {
@@ -788,20 +782,63 @@ private:
add(rsp, sizeof(int));
}
inline void heap_cmp_node(Xmm xmm_val_a, Xmm xmm_idx_a, Xmm xmm_val_b, Xmm xmm_idx_b, bool cmp_val = true) {
if (isa == cpu::x64::avx512_core) {
if (cmp_val)
vcmpps(k_mask, xmm_val_a, xmm_val_b, heap_cmp_flg);
else
vcmpps(k_mask, xmm_idx_a, xmm_idx_b, _cmp_lt_os);
inline bool is_valid_isa(cpu_isa_t cpu_isa) {
return cpu::x64::is_subset(cpu_isa, isa_all) && mayiuse(cpu_isa);
}
inline void uni_vpcmpgtd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
const Xbyak::Operand &op) {
if (is_valid_isa(cpu::x64::avx)) {
vpcmpgtd(x1, x2, op);
} else {
if (cmp_val)
uni_vcmpps(xmm_mask, xmm_val_a, xmm_val_b, heap_cmp_flg);
else
uni_vcmpps(xmm_mask, xmm_idx_a, xmm_idx_b, _cmp_lt_os);
if (x1.getIdx() != x2.getIdx())
uni_vmovups(x1, x2);
pcmpgtd(x1, op);
}
}
inline void uni_vpcmpgtd(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
const Xbyak::Operand &op) {
vpcmpgtd(x1, x2, op);
}
inline void compare_node_xmm(Xmm xmm_val_a, Xmm xmm_idx_a, Xmm xmm_val_b, Xmm xmm_idx_b, Xmm mask,
unsigned char val_cmp_flg, unsigned char idx_cmp_flg, bool cmp_val) {
if (isa == cpu::x64::avx512_core) {
if (cmp_val) {
if (isFloatCompatible(data_type)) {
vcmpps(k_mask, xmm_val_a, xmm_val_b, val_cmp_flg);
} else {
vpcmpd(k_mask, xmm_val_a, xmm_val_b, val_cmp_flg);
}
} else {
vpcmpd(k_mask, xmm_idx_a, xmm_idx_b, idx_cmp_flg);
}
} else {
if (cmp_val) {
if (isFloatCompatible(data_type)) {
uni_vcmpps(mask, xmm_val_a, xmm_val_b, val_cmp_flg);
} else {
if (val_cmp_flg == _cmp_nle_us) {
uni_vpcmpgtd(mask, xmm_val_a, xmm_val_b);
} else {
uni_vpcmpgtd(mask, xmm_val_b, xmm_val_a);
}
}
} else {
if (idx_cmp_flg == _cmp_nle_us) {
uni_vpcmpgtd(mask, xmm_idx_a, xmm_idx_b);
} else {
uni_vpcmpgtd(mask, xmm_idx_b, xmm_idx_a);
}
}
}
}
inline void heap_cmp_node(Xmm xmm_val_a, Xmm xmm_idx_a, Xmm xmm_val_b, Xmm xmm_idx_b, bool cmp_val = true) {
compare_node_xmm(xmm_val_a, xmm_idx_a, xmm_val_b, xmm_idx_b, xmm_mask, heap_cmp_flg, _cmp_lt_os, cmp_val);
}
// n: node, c: child
inline void heap_swap_node(Xmm xmm_val_n, Xmm xmm_idx_n, Xmm xmm_val_c, Xmm xmm_idx_c) {
// swap store
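compare_node_xmm above folds the previously duplicated compare sequences into one helper and adds an integer path: AVX512 can use vpcmpd with the requested predicate, but the SSE/AVX fallback only has a greater-than primitive (pcmpgtd), so _cmp_nle_us is realised as a > b and the other predicate by swapping the operands. A per-lane scalar sketch of that fallback branch, with illustrative names only:

#include <cstdint>

// The two predicates the kernel passes down; the numeric values of the
// dnnl _cmp_* flags are irrelevant for this sketch.
enum CmpFlag { CMP_NLE_US, CMP_LT_OS };

// Per-lane scalar sketch of the non-AVX512 branch of compare_node_xmm:
// floats honour the requested predicate directly, while int32 data is
// expressed through the single greater-than primitive (pcmpgtd), swapping
// operands when a less-than is wanted.
inline bool lane_mask(bool data_is_float, CmpFlag flg,
                      float fa, float fb, int32_t ia, int32_t ib) {
    if (data_is_float)
        return flg == CMP_NLE_US ? !(fa <= fb) : (fa < fb);
    return flg == CMP_NLE_US ? (ia > ib) : (ib > ia);
}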
@@ -1070,7 +1107,6 @@ private:
load(reg_tmp, vmm_val_r, elt_num);
table_to_vmm(vmm_idx_r, reg_bubble_block_idx, reg_i, 0, vlen);
uni_vcvtdq2ps(vmm_idx_r, vmm_idx_r);
sub(rsp, sizeof(int64_t));
mov(ptr[rsp], reg_bubble_block_idx);
@@ -1162,7 +1198,6 @@ private:
for (int i = 0; i < jcp_.top_k; i++) {
load(reg_src, vmm_val(i), elt_num, i * jcp_.sort_stride * jcp_.data_size);
uni_vmovdqu(vmm_idx(i), table_val(i));
uni_vcvtdq2ps(vmm_idx(i), vmm_idx(i));
}
// sort
for (int i = 0; i < jcp_.top_k - 1; i++) {
@@ -1173,7 +1208,6 @@ private:
for (int i = jcp_.top_k; i < jcp_.axis_dim; i++) {
load(reg_src, vmm_val(jcp_.top_k), elt_num, i * jcp_.sort_stride * jcp_.data_size);
uni_vmovdqu(vmm_idx(jcp_.top_k), table_val(i));
uni_vcvtdq2ps(vmm_idx(jcp_.top_k), vmm_idx(jcp_.top_k));
for (int j = jcp_.top_k; j > 0; j--) {
swap_vector(vmm_val(j - 1), vmm_idx(j - 1), vmm_val(j), vmm_idx(j));
}
@@ -1188,7 +1222,7 @@ private:
// store
for (int i = 0; i < jcp_.top_k; i++) {
store(vmm_val(i), reg_dst, elt_num, i * jcp_.sort_stride * jcp_.data_size);
store_f32_i32(vmm_idx(i), reg_dst_idx, elt_num, i * jcp_.sort_stride * sizeof(int));
store_i32(vmm_idx(i), reg_dst_idx, elt_num, i * jcp_.sort_stride * sizeof(int));
}
}
@@ -1206,18 +1240,15 @@ private:
load_scalar(xmm_val(0), ptr[reg_src], data_type);
uni_vmovss(xmm_idx(0), table_bubble_seq_idx(0));
uni_vcvtdq2ps(xmm_idx(0), xmm_idx(0));
jmp(topk_load_sort_end_label, T_NEAR);
L(topk_load_sort_label);
{
load(reg_src, vmm_val(0), vector_step, 0);
uni_vmovdqu(vmm_idx(0), table_bubble_seq_idx(0));
uni_vcvtdq2ps(vmm_idx(0), vmm_idx(0));
if (isa == cpu::x64::sse41) {
load(reg_src, vmm_val(1), vector_step, 4 * jcp_.data_size);
uni_vmovdqu(vmm_idx(1), table_bubble_seq_idx(4));
uni_vcvtdq2ps(vmm_idx(1), vmm_idx(1));
swap_vector(vmm_val(0), vmm_idx(0), vmm_val(1), vmm_idx(1));
}
@@ -1233,13 +1264,11 @@ private:
get_addr_by_reg_idx(reg_aux, reg_src, reg_i, jcp_.data_size, reg_seq_sort_stride);
load(reg_aux, vmm_val(1), vector_step);
table_to_vmm(vmm_idx(1), reg_bubble_seq_idx, reg_i, 0, sizeof(int));
uni_vcvtdq2ps(vmm_idx(1), vmm_idx(1));
swap_vector(vmm_val(0), vmm_idx(0), vmm_val(1), vmm_idx(1));
if (isa == cpu::x64::sse41) {
add(reg_aux, 4 * jcp_.data_size);
load(reg_aux, vmm_val(1), vector_step);
table_to_vmm(vmm_idx(1), reg_bubble_seq_idx, reg_i, 4, sizeof(int));
uni_vcvtdq2ps(vmm_idx(1), vmm_idx(1));
swap_vector(vmm_val(0), vmm_idx(0), vmm_val(1), vmm_idx(1));
}
@@ -1263,7 +1292,6 @@ private:
load_scalar(xmm_val(1), ptr[reg_aux], data_type);
table_to_xmm(xmm_idx(1), reg_bubble_seq_idx, reg_i, 0, sizeof(int));
uni_vcvtdq2ps(xmm_idx(1), xmm_idx(1));
bubble_swap_xmm(xmm_val(0), xmm_idx(0), xmm_val(1), xmm_idx(1));
add(reg_i, 1);
@@ -1359,7 +1387,7 @@ private:
table_to_xmm(xmm_tmp, reg_bubble_seq_idx, reg_i, 0, sizeof(int));
get_addr_by_reg_idx(reg_tmp, reg_dst_idx, reg_aux, sizeof(int) / jcp_.data_size);
store_scalar(ptr[reg_tmp], xmm_tmp, memory::data_type::s32, false);
store_scalar(ptr[reg_tmp], xmm_tmp, memory::data_type::s32);
add(reg_sub_idx, 1);
cmp(reg_sub_idx, jcp_.blk_size);
@@ -1396,7 +1424,6 @@ private:
load_scalar(xmm_val_r, ptr[reg_aux], data_type);
table_to_xmm(xmm_idx_r, reg_bubble_seq_idx, reg_i, 0, sizeof(int));
uni_vcvtdq2ps(xmm_idx_r, xmm_idx_r);
sub(rsp, sizeof(int));
mov(ptr[rsp], reg_prc.cvt32());
@@ -1491,7 +1518,6 @@ private:
int offset = i / jcp_.blk_size * blk_stride + i % jcp_.blk_size;
load_scalar(xmm_val(i), ptr[reg_src + offset * jcp_.data_size], data_type);
uni_vmovdqu(xmm_idx(i), table_val(i));
uni_vcvtdq2ps(xmm_idx(i), xmm_idx(i));
}
// sort
for (int i = 0; i < jcp_.top_k - 1; i++) {
@@ -1503,7 +1529,6 @@ private:
int offset = i / jcp_.blk_size * blk_stride + i % jcp_.blk_size;
load_scalar(xmm_val(jcp_.top_k), ptr[reg_src + offset * jcp_.data_size], data_type);
uni_vmovdqu(xmm_idx(jcp_.top_k), table_val(i));
uni_vcvtdq2ps(xmm_idx(jcp_.top_k), xmm_idx(jcp_.top_k));
for (int j = jcp_.top_k; j > 0; j--) {
bubble_swap_xmm(xmm_val(j - 1), xmm_idx(j - 1), xmm_val(j), xmm_idx(j));
}
@@ -1536,7 +1561,7 @@ private:
mov(reg_tmp, reg_tmp_64);
add(reg_tmp, reg_dst_idx);
reg_shr(reg_tmp_64, sizeof(int) / jcp_.data_size);
load_i32_f32(reg_tmp, vmm_idx_l, elt_num);
load_i32(reg_tmp, vmm_idx_l, elt_num);
// load r
Xbyak::Label topk_load_jmp_label;
@@ -1552,7 +1577,7 @@ private:
mov(reg_tmp, reg_tmp_64);
add(reg_tmp, reg_dst_idx);
reg_shr(reg_tmp_64, sizeof(int) / jcp_.data_size);
load_i32_f32(reg_tmp, vmm_idx_r, elt_num);
load_i32(reg_tmp, vmm_idx_r, elt_num);
sub(reg_tmp_64, reg_block_sort_stride_byte);
}
@@ -1569,7 +1594,7 @@ private:
mov(reg_tmp, reg_tmp_64);
add(reg_tmp, reg_dst_idx);
reg_shr(reg_tmp_64, sizeof(int) / jcp_.data_size);
store_f32_i32(vmm_idx_l, reg_tmp, elt_num);
store_i32(vmm_idx_l, reg_tmp, elt_num);
// store r
Xbyak::Label topk_store_jmp_label;
@@ -1585,18 +1610,15 @@ private:
mov(reg_tmp, reg_tmp_64);
add(reg_tmp, reg_dst_idx);
reg_shr(reg_tmp_64, sizeof(int) / jcp_.data_size);
store_f32_i32(vmm_idx_r, reg_tmp, elt_num);
store_i32(vmm_idx_r, reg_tmp, elt_num);
}
L(topk_store_jmp_label);
}
inline void swap_vector(Vmm vmm_val_a, Vmm vmm_idx_a, Vmm vmm_val_b, Vmm vmm_idx_b, bool cmp_val = true) {
if (isa == cpu::x64::avx512_core) {
if (cmp_val)
vcmpps(k_mask, vmm_val_a, vmm_val_b, cmp_flg);
else
vcmpps(k_mask, vmm_idx_a, vmm_idx_b, _cmp_nle_us);
compare_node_xmm(vmm_val_a, vmm_idx_a, vmm_val_b, vmm_idx_b, vmm_mask, cmp_flg, _cmp_nle_us, cmp_val);
if (isa == cpu::x64::avx512_core) {
uni_vmovups(vmm_tmp, vmm_val_a);
vblendmps(vmm_val_a | k_mask, vmm_val_a, vmm_val_b);
vblendmps(vmm_val_b | k_mask, vmm_val_b, vmm_tmp);
@@ -1605,11 +1627,6 @@ private:
vblendmps(vmm_idx_a | k_mask, vmm_idx_a, vmm_idx_b);
vblendmps(vmm_idx_b | k_mask, vmm_idx_b, vmm_tmp);
} else {
if (cmp_val)
uni_vcmpps(vmm_mask, vmm_val_a, vmm_val_b, cmp_flg);
else
uni_vcmpps(vmm_mask, vmm_idx_a, vmm_idx_b, _cmp_nle_us);
uni_vmovups(vmm_tmp, vmm_val_a);
uni_vblendvps(vmm_val_a, vmm_val_a, vmm_val_b, vmm_mask);
uni_vblendvps(vmm_val_b, vmm_val_b, vmm_tmp, vmm_mask);
@@ -1675,12 +1692,9 @@ private:
}
inline void bubble_swap_xmm(Xmm xmm_val_a, Xmm xmm_idx_a, Xmm xmm_val_b, Xmm xmm_idx_b, bool cmp_val = true) {
if (isa == cpu::x64::avx512_core) {
if (cmp_val)
vcmpps(k_mask, xmm_val_a, xmm_val_b, cmp_flg);
else
vcmpps(k_mask, xmm_idx_a, xmm_idx_b, _cmp_nle_us);
compare_node_xmm(xmm_val_a, xmm_idx_a, xmm_val_b, xmm_idx_b, xmm_mask, cmp_flg, _cmp_nle_us, cmp_val);
if (isa == cpu::x64::avx512_core) {
uni_vmovups(xmm_tmp, xmm_val_a);
vblendmps(xmm_val_a | k_mask, xmm_val_a, xmm_val_b);
vblendmps(xmm_val_b | k_mask, xmm_val_b, xmm_tmp);
@@ -1689,11 +1703,6 @@ private:
vblendmps(xmm_idx_a | k_mask, xmm_idx_a, xmm_idx_b);
vblendmps(xmm_idx_b | k_mask, xmm_idx_b, xmm_tmp);
} else {
if (cmp_val)
uni_vcmpps(xmm_mask, xmm_val_a, xmm_val_b, cmp_flg);
else
uni_vcmpps(xmm_mask, xmm_idx_a, xmm_idx_b, _cmp_nle_us);
uni_vmovups(xmm_tmp, xmm_val_a);
uni_vblendvps(xmm_val_a, xmm_val_a, xmm_val_b, xmm_mask);
uni_vblendvps(xmm_val_b, xmm_val_b, xmm_tmp, xmm_mask);
@@ -1704,7 +1713,7 @@ private:
}
}
inline void load_scalar(Xmm xmm_src, const Xbyak::Address &op, memory::data_type src_dt, bool cvt_dt = true) {
inline void load_scalar(Xmm xmm_src, const Xbyak::Address &op, memory::data_type src_dt, bool cvt_dt = false) {
switch (src_dt) {
case memory::data_type::f32:
case memory::data_type::s32:
@@ -1731,7 +1740,7 @@ private:
}
}
inline void store_scalar(const Xbyak::Address &op, Xmm xmm_dst, memory::data_type dst_dt, bool cvt_dt = true) {
inline void store_scalar(const Xbyak::Address &op, Xmm xmm_dst, memory::data_type dst_dt, bool cvt_dt = false) {
if (cvt_dt && !isFloatCompatible(dst_dt)) {
uni_vcvtps2dq(xmm_dst, xmm_dst);
}


@@ -148,16 +148,25 @@ protected:
tensor = ov::test::utils::create_and_fill_tensor(funcInputs[0].get_element_type(), shape);
size_t size = tensor.get_size();
if (netPrecision == ElementType::f32) {
if (netPrecision == ElementType::f32 || netPrecision == ElementType::i32) {
std::vector<int> data(size);
int start = - static_cast<int>(size / 2);
// For int32, deliberately set big numbers which are not accurately representable in fp32
int start = netPrecision == ElementType::i32 ? pow(2, 30) + 1 : - static_cast<int>(size / 2);
std::iota(data.begin(), data.end(), start);
std::mt19937 gen(0);
std::shuffle(data.begin(), data.end(), gen);
auto *rawBlobDataPtr = static_cast<float *>(tensor.data());
for (size_t i = 0; i < size; ++i) {
rawBlobDataPtr[i] = static_cast<float>(data[i]);
if (netPrecision == ElementType::f32) {
auto *rawBlobDataPtr = static_cast<float *>(tensor.data());
for (size_t i = 0; i < size; ++i) {
rawBlobDataPtr[i] = static_cast<float>(data[i]);
}
} else {
auto *rawBlobDataPtr = static_cast<int32_t *>(tensor.data());
for (size_t i = 0; i < size; ++i) {
rawBlobDataPtr[i] = static_cast<int32_t>(data[i]);
}
}
} else if (netPrecision == ElementType::bf16) {
size_t O = 1, A = 1, I = 1;
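The widened fill path above is what makes the new int32 coverage meaningful: the values start above 2^24, so they are distinguishable as int32 but would collapse under an fp32 round trip. A reduced standalone sketch of that fill strategy (the helper name is made up):

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <random>
#include <vector>

// Mirrors the test's int32 branch: consecutive values starting at 2^30 + 1,
// then shuffled with a fixed seed. Adjacent values differ by 1 as int32 but
// map to the same fp32 value, which is exactly what trips an fp32-based sort.
std::vector<int32_t> make_int32_topk_input(size_t size) {
    std::vector<int32_t> data(size);
    std::iota(data.begin(), data.end(), (1 << 30) + 1);
    std::mt19937 gen(0);
    std::shuffle(data.begin(), data.end(), gen);
    return data;
}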
@@ -283,6 +292,46 @@ INSTANTIATE_TEST_CASE_P(smoke_TopK_dynamic, TopKLayerCPUTest,
::testing::ValuesIn(additionalConfig)),
TopKLayerCPUTest::getTestCaseName);
const std::vector<int64_t> k_int32 = {1, 5, 7, 9};
std::vector<ov::test::InputShape> inputShapes_int32 = {
{{}, {{9, 9, 9, 9}}},
};
std::vector<ov::test::InputShape> inputShapesDynamic_int32 = {
{{9, {5, 10}, 9, {5, 10}}, {{9, 9, 9, 9}, {9, 10, 9, 10}}}
};
INSTANTIATE_TEST_CASE_P(smoke_TopK_int32, TopKLayerCPUTest,
::testing::Combine(
::testing::Combine(
::testing::ValuesIn(k_int32),
::testing::ValuesIn(axes),
::testing::ValuesIn(modes),
::testing::ValuesIn(sortTypes),
::testing::Values(ElementType::i32),
::testing::Values(ElementType::undefined),
::testing::Values(ElementType::undefined),
::testing::ValuesIn(inputShapes_int32)),
::testing::ValuesIn(filterCPUSpecificParams(cpuParams)),
::testing::Values(additionalConfig[0])),
TopKLayerCPUTest::getTestCaseName);
INSTANTIATE_TEST_CASE_P(smoke_TopK_int32_dynamic, TopKLayerCPUTest,
::testing::Combine(
::testing::Combine(
::testing::Values(1),
::testing::ValuesIn(axes),
::testing::ValuesIn(modes),
::testing::ValuesIn(sortTypes),
::testing::Values(ElementType::i32),
::testing::Values(ElementType::undefined),
::testing::Values(ElementType::undefined),
::testing::ValuesIn(inputShapesDynamic_int32)),
::testing::ValuesIn(filterCPUSpecificParams(cpuParams)),
::testing::Values(additionalConfig[0])),
TopKLayerCPUTest::getTestCaseName);
std::vector<ov::test::InputShape> inputShapes_bubble_BLK_on_channel_horiz = {
{{}, {{2, 2, 2, 2}}},
};