[CPU] Topk node extend to support horizontal sort for all layouts for top_k == 1 (#10700)
This commit is contained in:
@@ -226,10 +226,12 @@ private:
|
||||
if (jcp_.algorithm == TopKAlgorithm::topk_bubble_sort) {
|
||||
if (jcp_.layout == TopKLayoutType::topk_blocked && jcp_.topk_innermost) {
|
||||
if (jcp_.top_k == 1) {
|
||||
topk_bubble_BLK_on_channel_horiz();
|
||||
topk_bubble_horiz();
|
||||
} else {
|
||||
topk_bubble_BLK_on_channel_verti();
|
||||
}
|
||||
} else if (jcp_.topk_innermost && jcp_.top_k == 1) {
|
||||
topk_bubble_horiz();
|
||||
} else {
|
||||
topk_bubble_vector();
|
||||
}
|
||||
@@ -1178,7 +1180,7 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
inline void topk_bubble_BLK_on_channel_horiz() {
|
||||
inline void topk_bubble_horiz() {
|
||||
mov(reg_bubble_axis_dim, ptr[reg_params + GET_OFF(axis_dim)]);
|
||||
mov(reg_seq_sort_stride, ptr[reg_params + GET_OFF(sort_stride)]);
|
||||
mov(reg_bubble_seq_idx, ptr[reg_params + GET_OFF(idx_seq_buf)]);
|
||||
@@ -2017,7 +2019,7 @@ void MKLDNNTopKNode::prepareParams() {
|
||||
const size_t count_xmm = 16; // only 16 vector registers are valid in sse instructions even for avx512_common
|
||||
if (top_k <= count_xmm / 2 - 2) {
|
||||
algorithm = TopKAlgorithm::topk_bubble_sort;
|
||||
bubble_inplace = layout == TopKLayoutType::topk_blocked && topk_innermost && top_k == 1 ? false : true;
|
||||
bubble_inplace = topk_innermost && top_k == 1 ? false : true;
|
||||
} else if ((layout == TopKLayoutType::topk_ncsp || layout == TopKLayoutType::topk_nspc) && topk_innermost) {
|
||||
algorithm = TopKAlgorithm::topk_heap_sort;
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user