[CPU] Fixed TopK node optimization leftovers (#9862)

This commit is contained in:
Chen Xu
2022-01-31 00:11:47 +08:00
committed by GitHub
parent aaf02ab4c7
commit cc5d30b26c
2 changed files with 98 additions and 72 deletions

View File

@@ -1842,6 +1842,7 @@ MKLDNNTopKNode::MKLDNNTopKNode(const std::shared_ptr<ngraph::Node>& op, const mk
sort_index = topKOp->get_sort_type() == ngraph::op::TopKSortType::SORT_INDICES;
top_k = 0;
preset_params_done = false;
vec_idx_seq.clear();
vec_idx_block.clear();
@@ -1884,7 +1885,6 @@ void MKLDNNTopKNode::initSupportedPrimitiveDescriptors() {
}
jit_mode = mayiuse(cpu::x64::sse41);
compile_kernel = jit_mode;
static const Precision supportedPrecision[] = {
Precision::FP32,
@@ -1935,6 +1935,39 @@ bool MKLDNNTopKNode::needPrepareParams() const {
return inputShapesModified() || top_k != src_k;
}
// Computes the JIT parameters that are set up only once per node: layout, data_size,
// topk_innermost, blk_size and, for dynamic nodes, the sorting algorithm.
// Callers guard this function with the preset_params_done flag.
void MKLDNNTopKNode::preset_params() {
    // Translate the memory layout of the data input into the TopK layout enum.
    auto &srcMemPtr = getParentEdgeAt(TOPK_DATA)->getMemoryPtr();
    const auto &srcDesc = srcMemPtr->getDesc();
    if (srcDesc.hasLayoutType(LayoutType::ncsp))
        layout = TopKLayoutType::topk_ncsp;
    else if (srcDesc.hasLayoutType(LayoutType::nspc))
        layout = TopKLayoutType::topk_nspc;
    else
        layout = TopKLayoutType::topk_blocked;

    // Element size in bytes, taken from the precision of the selected input configuration.
    const auto inPrecision = getSelectedPrimitiveDescriptor()->getConfig().inConfs[TOPK_DATA].desc->getPrecision();
    data_size = MKLDNNExtensionUtils::sizeOfDataType(MKLDNNExtensionUtils::IEPrecisionToDataType(inPrecision));

    // For ncsp the sorted axis must be the last dimension; for nspc/blocked it must be axis 1.
    topk_innermost = (layout == TopKLayoutType::topk_ncsp)
                         ? (axis == static_cast<int>(getOutputShapeAtPort(TOPK_DATA).getRank() - 1))
                         : (axis == 1);

    // Block size follows the widest ISA reported by mayiuse().
    if (mayiuse(cpu::x64::avx512_common))
        blk_size = 16;
    else if (mayiuse(cpu::x64::sse41))
        blk_size = 8;

    // Static nodes choose their algorithm elsewhere; only dynamic nodes are handled below.
    if (!isDynamicNode())
        return;

    const bool planarOrChannelsLast = layout == TopKLayoutType::topk_ncsp || layout == TopKLayoutType::topk_nspc;
    if (planarOrChannelsLast && topk_innermost) {
        algorithm = TopKAlgorithm::topk_heap_sort;
        return;
    }
    algorithm = TopKAlgorithm::topk_bubble_sort;
    bubble_inplace = false;
}
void MKLDNNTopKNode::prepareParams() {
auto &dstMemPtr = getChildEdgeAt(TOPK_DATA)->getMemoryPtr();
auto &srcMemPtr = getParentEdgeAt(TOPK_DATA)->getMemoryPtr();
@@ -1960,27 +1993,9 @@ void MKLDNNTopKNode::prepareParams() {
}
if (jit_mode) {
if (srcMemPtr->getDesc().hasLayoutType(LayoutType::ncsp)) {
layout = TopKLayoutType::topk_ncsp;
} else if (srcMemPtr->getDesc().hasLayoutType(LayoutType::nspc)) {
layout = TopKLayoutType::topk_nspc;
} else {
layout = TopKLayoutType::topk_blocked;
}
auto selectedPD = getSelectedPrimitiveDescriptor();
auto data_type = MKLDNNExtensionUtils::IEPrecisionToDataType(selectedPD->getConfig().inConfs[TOPK_DATA].desc->getPrecision());
data_size = MKLDNNExtensionUtils::sizeOfDataType(data_type);
topk_innermost = (layout == TopKLayoutType::topk_ncsp && axis == static_cast<int>(dst_dims.size() - 1)) ||
((layout == TopKLayoutType::topk_nspc || layout == TopKLayoutType::topk_blocked) && axis == 1);
if (mayiuse(cpu::x64::avx512_common)) {
blk_size = 16;
count_xmm = 16; // only 16 vector registers are valid in sse instructions
} else if (mayiuse(cpu::x64::sse41)) {
blk_size = 8;
count_xmm = 16;
if (!preset_params_done) {
preset_params();
preset_params_done = true;
}
auto layout_dims = dstMemPtr->GetDescWithType<BlockedMemoryDesc>()->getBlockDims();
@@ -1999,6 +2014,7 @@ void MKLDNNTopKNode::prepareParams() {
// the above two alg_costs are not the exact implementation costs, yet it's proper to use them to decide
// which algorithm should be used for specific N and K.
if (!isDynamicNode()) {
const size_t count_xmm = 16; // only 16 vector registers are valid in sse instructions even for avx512_common
if (top_k <= count_xmm / 2 - 2) {
algorithm = TopKAlgorithm::topk_bubble_sort;
bubble_inplace = layout == TopKLayoutType::topk_blocked && topk_innermost && top_k == 1 ? false : true;
@@ -2015,57 +2031,9 @@ void MKLDNNTopKNode::prepareParams() {
bubble_inplace = false;
}
}
} else {
if ((layout == TopKLayoutType::topk_ncsp || layout == TopKLayoutType::topk_nspc) && topk_innermost) {
algorithm = TopKAlgorithm::topk_heap_sort;
} else {
algorithm = TopKAlgorithm::topk_bubble_sort;
bubble_inplace = false;
}
}
prepare_original_idx();
if (compile_kernel) {
auto jcp = jit_topk_config_params();
jcp.precision = selectedPD->getConfig().inConfs[TOPK_DATA].desc->getPrecision();
jcp.data_size = data_size;
jcp.blk_size = blk_size;
jcp.layout = layout;
jcp.top_k = top_k;
jcp.axis_dim = axis_dim;
jcp.mode_max = mode_max;
jcp.sort_index = sort_index;
jcp.topk_innermost = topk_innermost;
jcp.algorithm = algorithm;
jcp.bubble_inplace = bubble_inplace;
jcp.sort_stride = static_cast<int>(I);
jcp.work_amount = static_cast<int>(I);
if (algorithm == TopKAlgorithm::topk_bitonic_sort) {
size_t src_count = srcMemPtr->GetDescWithType<BlockedMemoryDesc>()->getPaddedElementsCount();
vec_process_ptr.resize(src_count * data_size);
vec_process_idx_ptr.resize(src_count * sizeof(int32_t));
calc_bitonic_idx(axis_dim, jcp.bitonic_idx_cnt, true);
if (sort_index) {
calc_bitonic_idx(top_k, jcp.bitonic_k_idx_cnt, false);
}
}
if (mayiuse(cpu::x64::avx512_common)) {
topk_kernel.reset(new jit_uni_topk_kernel_f32<cpu::x64::avx512_common>(jcp));
} else if (mayiuse(cpu::x64::avx2)) {
topk_kernel.reset(new jit_uni_topk_kernel_f32<cpu::x64::avx2>(jcp));
} else if (mayiuse(cpu::x64::sse41)) {
topk_kernel.reset(new jit_uni_topk_kernel_f32<cpu::x64::sse41>(jcp));
}
if (topk_kernel)
topk_kernel->create_ker();
compile_kernel = false;
}
} else { //reference mode
int j;
for (j = src_dims.size() - 1; j >= 0; j--) {
@@ -2077,6 +2045,63 @@ void MKLDNNTopKNode::prepareParams() {
}
}
void MKLDNNTopKNode::createPrimitive() {
if (inputShapesDefined() && isExecutable()) {
if (needPrepareParams())
prepareParams();
updateLastInputDims();
}
if (jit_mode) {
if (!preset_params_done) {
preset_params();
preset_params_done = true;
}
// Shape related config params will only be used for static shape sorting algorithms.
// Such params are useless for dynamic shapes, instead their jit_topk_call_args counterparts
// will be used. These params are: top_k, axis_dim, sort_stride, work_amount
auto jcp = jit_topk_config_params();
auto selectedPD = getSelectedPrimitiveDescriptor();
jcp.precision = selectedPD->getConfig().inConfs[TOPK_DATA].desc->getPrecision();
jcp.data_size = data_size;
jcp.blk_size = blk_size;
jcp.layout = layout;
jcp.top_k = top_k;
jcp.axis_dim = axis_dim;
jcp.mode_max = mode_max;
jcp.sort_index = sort_index;
jcp.topk_innermost = topk_innermost;
jcp.algorithm = algorithm;
jcp.bubble_inplace = bubble_inplace;
jcp.sort_stride = static_cast<int>(I);
jcp.work_amount = static_cast<int>(I);
if (algorithm == TopKAlgorithm::topk_bitonic_sort) {
auto &srcMemPtr = getParentEdgeAt(TOPK_DATA)->getMemoryPtr();
size_t src_count = srcMemPtr->GetDescWithType<BlockedMemoryDesc>()->getPaddedElementsCount();
vec_process_ptr.resize(src_count * data_size);
vec_process_idx_ptr.resize(src_count * sizeof(int32_t));
calc_bitonic_idx(axis_dim, jcp.bitonic_idx_cnt, true);
if (sort_index) {
calc_bitonic_idx(top_k, jcp.bitonic_k_idx_cnt, false);
}
}
if (mayiuse(cpu::x64::avx512_common)) {
topk_kernel.reset(new jit_uni_topk_kernel_f32<cpu::x64::avx512_common>(jcp));
} else if (mayiuse(cpu::x64::avx2)) {
topk_kernel.reset(new jit_uni_topk_kernel_f32<cpu::x64::avx2>(jcp));
} else if (mayiuse(cpu::x64::sse41)) {
topk_kernel.reset(new jit_uni_topk_kernel_f32<cpu::x64::sse41>(jcp));
}
if (topk_kernel)
topk_kernel->create_ker();
}
}
// Dynamic-shape execution entry point: forwards to execute() with the same stream.
void MKLDNNTopKNode::executeDynamicImpl(mkldnn::stream strm) {
execute(strm);
}

View File

@@ -85,6 +85,7 @@ public:
std::vector<VectorDims> shapeInfer() const override;
bool needPrepareParams() const override;
void prepareParams() override;
void createPrimitive() override;
bool created() const override;
void execute(mkldnn::stream strm) override;
void executeDynamicImpl(mkldnn::stream strm) override;
@@ -106,6 +107,7 @@ private:
void calc_dims_size(const InferenceEngine::SizeVector &layout_dims);
void topk_ref_process(const float* src_data, float* dst_data, int32_t* dst_idx,
const InferenceEngine::SizeVector &in_dims, std::function<float(float, float)> compare) const;
void preset_params();
void prepare_original_idx();
bool topk_innermost;
@@ -118,13 +120,12 @@ private:
static const size_t TOPK_INDEX = 1;
size_t O, A, I;
size_t blk_size;
size_t count_xmm;
size_t data_size;
size_t axis_dim;
int top_k;
int dim, before_num;
bool bubble_inplace;
bool compile_kernel;
bool preset_params_done;
InferenceEngine::SizeVector src_dims, dst_dims;
TopKLayoutType layout;