From cc5d30b26c3d4a07ff0915c61e2822beb324609d Mon Sep 17 00:00:00 2001 From: Chen Xu Date: Mon, 31 Jan 2022 00:11:47 +0800 Subject: [PATCH] [CPU] Fixed TopK node optimization leftovers (#9862) --- .../intel_cpu/src/nodes/mkldnn_topk_node.cpp | 165 ++++++++++-------- .../intel_cpu/src/nodes/mkldnn_topk_node.h | 5 +- 2 files changed, 98 insertions(+), 72 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/mkldnn_topk_node.cpp b/src/plugins/intel_cpu/src/nodes/mkldnn_topk_node.cpp index 09ee24531a3..ce77c885ac2 100644 --- a/src/plugins/intel_cpu/src/nodes/mkldnn_topk_node.cpp +++ b/src/plugins/intel_cpu/src/nodes/mkldnn_topk_node.cpp @@ -1842,6 +1842,7 @@ MKLDNNTopKNode::MKLDNNTopKNode(const std::shared_ptr& op, const mk sort_index = topKOp->get_sort_type() == ngraph::op::TopKSortType::SORT_INDICES; top_k = 0; + preset_params_done = false; vec_idx_seq.clear(); vec_idx_block.clear(); @@ -1884,7 +1885,6 @@ void MKLDNNTopKNode::initSupportedPrimitiveDescriptors() { } jit_mode = mayiuse(cpu::x64::sse41); - compile_kernel = jit_mode; static const Precision supportedPrecision[] = { Precision::FP32, @@ -1935,6 +1935,39 @@ bool MKLDNNTopKNode::needPrepareParams() const { return inputShapesModified() || top_k != src_k; } +void MKLDNNTopKNode::preset_params() { + auto &srcMemPtr = getParentEdgeAt(TOPK_DATA)->getMemoryPtr(); + if (srcMemPtr->getDesc().hasLayoutType(LayoutType::ncsp)) { + layout = TopKLayoutType::topk_ncsp; + } else if (srcMemPtr->getDesc().hasLayoutType(LayoutType::nspc)) { + layout = TopKLayoutType::topk_nspc; + } else { + layout = TopKLayoutType::topk_blocked; + } + + auto selectedPD = getSelectedPrimitiveDescriptor(); + auto data_type = MKLDNNExtensionUtils::IEPrecisionToDataType(selectedPD->getConfig().inConfs[TOPK_DATA].desc->getPrecision()); + data_size = MKLDNNExtensionUtils::sizeOfDataType(data_type); + + topk_innermost = (layout == TopKLayoutType::topk_ncsp && axis == static_cast(getOutputShapeAtPort(TOPK_DATA).getRank() - 1)) || + ((layout == TopKLayoutType::topk_nspc || layout == TopKLayoutType::topk_blocked) && axis == 1); + + if (mayiuse(cpu::x64::avx512_common)) { + blk_size = 16; + } else if (mayiuse(cpu::x64::sse41)) { + blk_size = 8; + } + + if (isDynamicNode()) { + if ((layout == TopKLayoutType::topk_ncsp || layout == TopKLayoutType::topk_nspc) && topk_innermost) { + algorithm = TopKAlgorithm::topk_heap_sort; + } else { + algorithm = TopKAlgorithm::topk_bubble_sort; + bubble_inplace = false; + } + } +} + void MKLDNNTopKNode::prepareParams() { auto &dstMemPtr = getChildEdgeAt(TOPK_DATA)->getMemoryPtr(); auto &srcMemPtr = getParentEdgeAt(TOPK_DATA)->getMemoryPtr(); @@ -1960,27 +1993,9 @@ void MKLDNNTopKNode::prepareParams() { } if (jit_mode) { - if (srcMemPtr->getDesc().hasLayoutType(LayoutType::ncsp)) { - layout = TopKLayoutType::topk_ncsp; - } else if (srcMemPtr->getDesc().hasLayoutType(LayoutType::nspc)) { - layout = TopKLayoutType::topk_nspc; - } else { - layout = TopKLayoutType::topk_blocked; - } - - auto selectedPD = getSelectedPrimitiveDescriptor(); - auto data_type = MKLDNNExtensionUtils::IEPrecisionToDataType(selectedPD->getConfig().inConfs[TOPK_DATA].desc->getPrecision()); - data_size = MKLDNNExtensionUtils::sizeOfDataType(data_type); - - topk_innermost = (layout == TopKLayoutType::topk_ncsp && axis == static_cast(dst_dims.size() - 1)) || - ((layout == TopKLayoutType::topk_nspc || layout == TopKLayoutType::topk_blocked) && axis == 1); - - if (mayiuse(cpu::x64::avx512_common)) { - blk_size = 16; - count_xmm = 16; // only 16 vector registers are valid in sse instructions - } else if (mayiuse(cpu::x64::sse41)) { - blk_size = 8; - count_xmm = 16; + if (!preset_params_done) { + preset_params(); + preset_params_done = true; } auto layout_dims = dstMemPtr->GetDescWithType()->getBlockDims(); @@ -1999,6 +2014,7 @@ void MKLDNNTopKNode::prepareParams() { // the above two alg_costs are not the exact implementation costs, yet it's proper to use them to decide // which algorithm should be used for specific N and K. if (!isDynamicNode()) { + const size_t count_xmm = 16; // only 16 vector registers are valid in sse instructions even for avx512_common if (top_k <= count_xmm / 2 - 2) { algorithm = TopKAlgorithm::topk_bubble_sort; bubble_inplace = layout == TopKLayoutType::topk_blocked && topk_innermost && top_k == 1 ? false : true; @@ -2015,57 +2031,9 @@ void MKLDNNTopKNode::prepareParams() { bubble_inplace = false; } } - } else { - if ((layout == TopKLayoutType::topk_ncsp || layout == TopKLayoutType::topk_nspc) && topk_innermost) { - algorithm = TopKAlgorithm::topk_heap_sort; - } else { - algorithm = TopKAlgorithm::topk_bubble_sort; - bubble_inplace = false; - } } prepare_original_idx(); - - if (compile_kernel) { - auto jcp = jit_topk_config_params(); - jcp.precision = selectedPD->getConfig().inConfs[TOPK_DATA].desc->getPrecision(); - jcp.data_size = data_size; - jcp.blk_size = blk_size; - jcp.layout = layout; - jcp.top_k = top_k; - jcp.axis_dim = axis_dim; - jcp.mode_max = mode_max; - jcp.sort_index = sort_index; - jcp.topk_innermost = topk_innermost; - jcp.algorithm = algorithm; - jcp.bubble_inplace = bubble_inplace; - jcp.sort_stride = static_cast(I); - jcp.work_amount = static_cast(I); - - if (algorithm == TopKAlgorithm::topk_bitonic_sort) { - size_t src_count = srcMemPtr->GetDescWithType()->getPaddedElementsCount(); - vec_process_ptr.resize(src_count * data_size); - vec_process_idx_ptr.resize(src_count * sizeof(int32_t)); - - calc_bitonic_idx(axis_dim, jcp.bitonic_idx_cnt, true); - if (sort_index) { - calc_bitonic_idx(top_k, jcp.bitonic_k_idx_cnt, false); - } - } - - if (mayiuse(cpu::x64::avx512_common)) { - topk_kernel.reset(new jit_uni_topk_kernel_f32(jcp)); - } else if (mayiuse(cpu::x64::avx2)) { - topk_kernel.reset(new jit_uni_topk_kernel_f32(jcp)); - } else if (mayiuse(cpu::x64::sse41)) { - topk_kernel.reset(new jit_uni_topk_kernel_f32(jcp)); - } - - if (topk_kernel) - topk_kernel->create_ker(); - - compile_kernel = false; - } } else { //reference mode int j; for (j = src_dims.size() - 1; j >= 0; j--) { @@ -2077,6 +2045,63 @@ void MKLDNNTopKNode::prepareParams() { } } +void MKLDNNTopKNode::createPrimitive() { + if (inputShapesDefined() && isExecutable()) { + if (needPrepareParams()) + prepareParams(); + updateLastInputDims(); + } + + if (jit_mode) { + if (!preset_params_done) { + preset_params(); + preset_params_done = true; + } + + // Shape related config params will only be used for static shape sorting algorithms. + // Such params are useless for dynamic shapes, instead their jit_topk_call_args counterparts + // will be used. These params are: top_k, axis_dim, sort_stride, work_amount + auto jcp = jit_topk_config_params(); + auto selectedPD = getSelectedPrimitiveDescriptor(); + jcp.precision = selectedPD->getConfig().inConfs[TOPK_DATA].desc->getPrecision(); + jcp.data_size = data_size; + jcp.blk_size = blk_size; + jcp.layout = layout; + jcp.top_k = top_k; + jcp.axis_dim = axis_dim; + jcp.mode_max = mode_max; + jcp.sort_index = sort_index; + jcp.topk_innermost = topk_innermost; + jcp.algorithm = algorithm; + jcp.bubble_inplace = bubble_inplace; + jcp.sort_stride = static_cast(I); + jcp.work_amount = static_cast(I); + + if (algorithm == TopKAlgorithm::topk_bitonic_sort) { + auto &srcMemPtr = getParentEdgeAt(TOPK_DATA)->getMemoryPtr(); + size_t src_count = srcMemPtr->GetDescWithType()->getPaddedElementsCount(); + vec_process_ptr.resize(src_count * data_size); + vec_process_idx_ptr.resize(src_count * sizeof(int32_t)); + + calc_bitonic_idx(axis_dim, jcp.bitonic_idx_cnt, true); + if (sort_index) { + calc_bitonic_idx(top_k, jcp.bitonic_k_idx_cnt, false); + } + } + + if (mayiuse(cpu::x64::avx512_common)) { + topk_kernel.reset(new jit_uni_topk_kernel_f32(jcp)); + } else if (mayiuse(cpu::x64::avx2)) { + topk_kernel.reset(new jit_uni_topk_kernel_f32(jcp)); + } else if (mayiuse(cpu::x64::sse41)) { + topk_kernel.reset(new jit_uni_topk_kernel_f32(jcp)); + } + + if (topk_kernel) + topk_kernel->create_ker(); + } +} + void MKLDNNTopKNode::executeDynamicImpl(mkldnn::stream strm) { execute(strm); } diff --git a/src/plugins/intel_cpu/src/nodes/mkldnn_topk_node.h b/src/plugins/intel_cpu/src/nodes/mkldnn_topk_node.h index f74cb3a5811..16261f8c951 100644 --- a/src/plugins/intel_cpu/src/nodes/mkldnn_topk_node.h +++ b/src/plugins/intel_cpu/src/nodes/mkldnn_topk_node.h @@ -85,6 +85,7 @@ public: std::vector shapeInfer() const override; bool needPrepareParams() const override; void prepareParams() override; + void createPrimitive() override; bool created() const override; void execute(mkldnn::stream strm) override; void executeDynamicImpl(mkldnn::stream strm) override; @@ -106,6 +107,7 @@ private: void calc_dims_size(const InferenceEngine::SizeVector &layout_dims); void topk_ref_process(const float* src_data, float* dst_data, int32_t* dst_idx, const InferenceEngine::SizeVector &in_dims, std::function compare) const; + void preset_params(); void prepare_original_idx(); bool topk_innermost; @@ -118,13 +120,12 @@ private: static const size_t TOPK_INDEX = 1; size_t O, A, I; size_t blk_size; - size_t count_xmm; size_t data_size; size_t axis_dim; int top_k; int dim, before_num; bool bubble_inplace; - bool compile_kernel; + bool preset_params_done; InferenceEngine::SizeVector src_dims, dst_dims; TopKLayoutType layout;