[CPU] Topk node extend to support horizontal sort for all layouts for top_k == 1 (#10700)

This commit is contained in:
Chen Xu
2022-03-16 19:09:14 +08:00
committed by GitHub
parent 4d8adabcaa
commit 49fb48e744

View File

@@ -226,10 +226,12 @@ private:
if (jcp_.algorithm == TopKAlgorithm::topk_bubble_sort) {
if (jcp_.layout == TopKLayoutType::topk_blocked && jcp_.topk_innermost) {
if (jcp_.top_k == 1) {
topk_bubble_BLK_on_channel_horiz();
topk_bubble_horiz();
} else {
topk_bubble_BLK_on_channel_verti();
}
} else if (jcp_.topk_innermost && jcp_.top_k == 1) {
topk_bubble_horiz();
} else {
topk_bubble_vector();
}
@@ -1178,7 +1180,7 @@ private:
}
}
inline void topk_bubble_BLK_on_channel_horiz() {
inline void topk_bubble_horiz() {
mov(reg_bubble_axis_dim, ptr[reg_params + GET_OFF(axis_dim)]);
mov(reg_seq_sort_stride, ptr[reg_params + GET_OFF(sort_stride)]);
mov(reg_bubble_seq_idx, ptr[reg_params + GET_OFF(idx_seq_buf)]);
@@ -2017,7 +2019,7 @@ void MKLDNNTopKNode::prepareParams() {
const size_t count_xmm = 16; // only 16 vector registers are valid in sse instructions even for avx512_common
if (top_k <= count_xmm / 2 - 2) {
algorithm = TopKAlgorithm::topk_bubble_sort;
bubble_inplace = layout == TopKLayoutType::topk_blocked && topk_innermost && top_k == 1 ? false : true;
bubble_inplace = topk_innermost && top_k == 1 ? false : true;
} else if ((layout == TopKLayoutType::topk_ncsp || layout == TopKLayoutType::topk_nspc) && topk_innermost) {
algorithm = TopKAlgorithm::topk_heap_sort;
} else {