[CPU] Fixed TopK node optimization leftovers (#9862)
This commit is contained in:
@@ -1842,6 +1842,7 @@ MKLDNNTopKNode::MKLDNNTopKNode(const std::shared_ptr<ngraph::Node>& op, const mk
|
||||
sort_index = topKOp->get_sort_type() == ngraph::op::TopKSortType::SORT_INDICES;
|
||||
|
||||
top_k = 0;
|
||||
preset_params_done = false;
|
||||
vec_idx_seq.clear();
|
||||
vec_idx_block.clear();
|
||||
|
||||
@@ -1884,7 +1885,6 @@ void MKLDNNTopKNode::initSupportedPrimitiveDescriptors() {
|
||||
}
|
||||
|
||||
jit_mode = mayiuse(cpu::x64::sse41);
|
||||
compile_kernel = jit_mode;
|
||||
|
||||
static const Precision supportedPrecision[] = {
|
||||
Precision::FP32,
|
||||
@@ -1935,6 +1935,39 @@ bool MKLDNNTopKNode::needPrepareParams() const {
|
||||
return inputShapesModified() || top_k != src_k;
|
||||
}
|
||||
|
||||
void MKLDNNTopKNode::preset_params() {
|
||||
auto &srcMemPtr = getParentEdgeAt(TOPK_DATA)->getMemoryPtr();
|
||||
if (srcMemPtr->getDesc().hasLayoutType(LayoutType::ncsp)) {
|
||||
layout = TopKLayoutType::topk_ncsp;
|
||||
} else if (srcMemPtr->getDesc().hasLayoutType(LayoutType::nspc)) {
|
||||
layout = TopKLayoutType::topk_nspc;
|
||||
} else {
|
||||
layout = TopKLayoutType::topk_blocked;
|
||||
}
|
||||
|
||||
auto selectedPD = getSelectedPrimitiveDescriptor();
|
||||
auto data_type = MKLDNNExtensionUtils::IEPrecisionToDataType(selectedPD->getConfig().inConfs[TOPK_DATA].desc->getPrecision());
|
||||
data_size = MKLDNNExtensionUtils::sizeOfDataType(data_type);
|
||||
|
||||
topk_innermost = (layout == TopKLayoutType::topk_ncsp && axis == static_cast<int>(getOutputShapeAtPort(TOPK_DATA).getRank() - 1)) ||
|
||||
((layout == TopKLayoutType::topk_nspc || layout == TopKLayoutType::topk_blocked) && axis == 1);
|
||||
|
||||
if (mayiuse(cpu::x64::avx512_common)) {
|
||||
blk_size = 16;
|
||||
} else if (mayiuse(cpu::x64::sse41)) {
|
||||
blk_size = 8;
|
||||
}
|
||||
|
||||
if (isDynamicNode()) {
|
||||
if ((layout == TopKLayoutType::topk_ncsp || layout == TopKLayoutType::topk_nspc) && topk_innermost) {
|
||||
algorithm = TopKAlgorithm::topk_heap_sort;
|
||||
} else {
|
||||
algorithm = TopKAlgorithm::topk_bubble_sort;
|
||||
bubble_inplace = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void MKLDNNTopKNode::prepareParams() {
|
||||
auto &dstMemPtr = getChildEdgeAt(TOPK_DATA)->getMemoryPtr();
|
||||
auto &srcMemPtr = getParentEdgeAt(TOPK_DATA)->getMemoryPtr();
|
||||
@@ -1960,27 +1993,9 @@ void MKLDNNTopKNode::prepareParams() {
|
||||
}
|
||||
|
||||
if (jit_mode) {
|
||||
if (srcMemPtr->getDesc().hasLayoutType(LayoutType::ncsp)) {
|
||||
layout = TopKLayoutType::topk_ncsp;
|
||||
} else if (srcMemPtr->getDesc().hasLayoutType(LayoutType::nspc)) {
|
||||
layout = TopKLayoutType::topk_nspc;
|
||||
} else {
|
||||
layout = TopKLayoutType::topk_blocked;
|
||||
}
|
||||
|
||||
auto selectedPD = getSelectedPrimitiveDescriptor();
|
||||
auto data_type = MKLDNNExtensionUtils::IEPrecisionToDataType(selectedPD->getConfig().inConfs[TOPK_DATA].desc->getPrecision());
|
||||
data_size = MKLDNNExtensionUtils::sizeOfDataType(data_type);
|
||||
|
||||
topk_innermost = (layout == TopKLayoutType::topk_ncsp && axis == static_cast<int>(dst_dims.size() - 1)) ||
|
||||
((layout == TopKLayoutType::topk_nspc || layout == TopKLayoutType::topk_blocked) && axis == 1);
|
||||
|
||||
if (mayiuse(cpu::x64::avx512_common)) {
|
||||
blk_size = 16;
|
||||
count_xmm = 16; // only 16 vector registers are valid in sse instructions
|
||||
} else if (mayiuse(cpu::x64::sse41)) {
|
||||
blk_size = 8;
|
||||
count_xmm = 16;
|
||||
if (!preset_params_done) {
|
||||
preset_params();
|
||||
preset_params_done = true;
|
||||
}
|
||||
|
||||
auto layout_dims = dstMemPtr->GetDescWithType<BlockedMemoryDesc>()->getBlockDims();
|
||||
@@ -1999,6 +2014,7 @@ void MKLDNNTopKNode::prepareParams() {
|
||||
// the above two alg_costs are not the exact implementation costs, yet it's proper to use them to decide
|
||||
// which algorithm should be used for specific N and K.
|
||||
if (!isDynamicNode()) {
|
||||
const size_t count_xmm = 16; // only 16 vector registers are valid in sse instructions even for avx512_common
|
||||
if (top_k <= count_xmm / 2 - 2) {
|
||||
algorithm = TopKAlgorithm::topk_bubble_sort;
|
||||
bubble_inplace = layout == TopKLayoutType::topk_blocked && topk_innermost && top_k == 1 ? false : true;
|
||||
@@ -2015,57 +2031,9 @@ void MKLDNNTopKNode::prepareParams() {
|
||||
bubble_inplace = false;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if ((layout == TopKLayoutType::topk_ncsp || layout == TopKLayoutType::topk_nspc) && topk_innermost) {
|
||||
algorithm = TopKAlgorithm::topk_heap_sort;
|
||||
} else {
|
||||
algorithm = TopKAlgorithm::topk_bubble_sort;
|
||||
bubble_inplace = false;
|
||||
}
|
||||
}
|
||||
|
||||
prepare_original_idx();
|
||||
|
||||
if (compile_kernel) {
|
||||
auto jcp = jit_topk_config_params();
|
||||
jcp.precision = selectedPD->getConfig().inConfs[TOPK_DATA].desc->getPrecision();
|
||||
jcp.data_size = data_size;
|
||||
jcp.blk_size = blk_size;
|
||||
jcp.layout = layout;
|
||||
jcp.top_k = top_k;
|
||||
jcp.axis_dim = axis_dim;
|
||||
jcp.mode_max = mode_max;
|
||||
jcp.sort_index = sort_index;
|
||||
jcp.topk_innermost = topk_innermost;
|
||||
jcp.algorithm = algorithm;
|
||||
jcp.bubble_inplace = bubble_inplace;
|
||||
jcp.sort_stride = static_cast<int>(I);
|
||||
jcp.work_amount = static_cast<int>(I);
|
||||
|
||||
if (algorithm == TopKAlgorithm::topk_bitonic_sort) {
|
||||
size_t src_count = srcMemPtr->GetDescWithType<BlockedMemoryDesc>()->getPaddedElementsCount();
|
||||
vec_process_ptr.resize(src_count * data_size);
|
||||
vec_process_idx_ptr.resize(src_count * sizeof(int32_t));
|
||||
|
||||
calc_bitonic_idx(axis_dim, jcp.bitonic_idx_cnt, true);
|
||||
if (sort_index) {
|
||||
calc_bitonic_idx(top_k, jcp.bitonic_k_idx_cnt, false);
|
||||
}
|
||||
}
|
||||
|
||||
if (mayiuse(cpu::x64::avx512_common)) {
|
||||
topk_kernel.reset(new jit_uni_topk_kernel_f32<cpu::x64::avx512_common>(jcp));
|
||||
} else if (mayiuse(cpu::x64::avx2)) {
|
||||
topk_kernel.reset(new jit_uni_topk_kernel_f32<cpu::x64::avx2>(jcp));
|
||||
} else if (mayiuse(cpu::x64::sse41)) {
|
||||
topk_kernel.reset(new jit_uni_topk_kernel_f32<cpu::x64::sse41>(jcp));
|
||||
}
|
||||
|
||||
if (topk_kernel)
|
||||
topk_kernel->create_ker();
|
||||
|
||||
compile_kernel = false;
|
||||
}
|
||||
} else { //reference mode
|
||||
int j;
|
||||
for (j = src_dims.size() - 1; j >= 0; j--) {
|
||||
@@ -2077,6 +2045,63 @@ void MKLDNNTopKNode::prepareParams() {
|
||||
}
|
||||
}
|
||||
|
||||
void MKLDNNTopKNode::createPrimitive() {
|
||||
if (inputShapesDefined() && isExecutable()) {
|
||||
if (needPrepareParams())
|
||||
prepareParams();
|
||||
updateLastInputDims();
|
||||
}
|
||||
|
||||
if (jit_mode) {
|
||||
if (!preset_params_done) {
|
||||
preset_params();
|
||||
preset_params_done = true;
|
||||
}
|
||||
|
||||
// Shape related config params will only be used for static shape sorting algorithms.
|
||||
// Such params are useless for dynamic shapes, instead their jit_topk_call_args counterparts
|
||||
// will be used. These params are: top_k, axis_dim, sort_stride, work_amount
|
||||
auto jcp = jit_topk_config_params();
|
||||
auto selectedPD = getSelectedPrimitiveDescriptor();
|
||||
jcp.precision = selectedPD->getConfig().inConfs[TOPK_DATA].desc->getPrecision();
|
||||
jcp.data_size = data_size;
|
||||
jcp.blk_size = blk_size;
|
||||
jcp.layout = layout;
|
||||
jcp.top_k = top_k;
|
||||
jcp.axis_dim = axis_dim;
|
||||
jcp.mode_max = mode_max;
|
||||
jcp.sort_index = sort_index;
|
||||
jcp.topk_innermost = topk_innermost;
|
||||
jcp.algorithm = algorithm;
|
||||
jcp.bubble_inplace = bubble_inplace;
|
||||
jcp.sort_stride = static_cast<int>(I);
|
||||
jcp.work_amount = static_cast<int>(I);
|
||||
|
||||
if (algorithm == TopKAlgorithm::topk_bitonic_sort) {
|
||||
auto &srcMemPtr = getParentEdgeAt(TOPK_DATA)->getMemoryPtr();
|
||||
size_t src_count = srcMemPtr->GetDescWithType<BlockedMemoryDesc>()->getPaddedElementsCount();
|
||||
vec_process_ptr.resize(src_count * data_size);
|
||||
vec_process_idx_ptr.resize(src_count * sizeof(int32_t));
|
||||
|
||||
calc_bitonic_idx(axis_dim, jcp.bitonic_idx_cnt, true);
|
||||
if (sort_index) {
|
||||
calc_bitonic_idx(top_k, jcp.bitonic_k_idx_cnt, false);
|
||||
}
|
||||
}
|
||||
|
||||
if (mayiuse(cpu::x64::avx512_common)) {
|
||||
topk_kernel.reset(new jit_uni_topk_kernel_f32<cpu::x64::avx512_common>(jcp));
|
||||
} else if (mayiuse(cpu::x64::avx2)) {
|
||||
topk_kernel.reset(new jit_uni_topk_kernel_f32<cpu::x64::avx2>(jcp));
|
||||
} else if (mayiuse(cpu::x64::sse41)) {
|
||||
topk_kernel.reset(new jit_uni_topk_kernel_f32<cpu::x64::sse41>(jcp));
|
||||
}
|
||||
|
||||
if (topk_kernel)
|
||||
topk_kernel->create_ker();
|
||||
}
|
||||
}
|
||||
|
||||
void MKLDNNTopKNode::executeDynamicImpl(mkldnn::stream strm) {
|
||||
execute(strm);
|
||||
}
|
||||
|
||||
@@ -85,6 +85,7 @@ public:
|
||||
std::vector<VectorDims> shapeInfer() const override;
|
||||
bool needPrepareParams() const override;
|
||||
void prepareParams() override;
|
||||
void createPrimitive() override;
|
||||
bool created() const override;
|
||||
void execute(mkldnn::stream strm) override;
|
||||
void executeDynamicImpl(mkldnn::stream strm) override;
|
||||
@@ -106,6 +107,7 @@ private:
|
||||
void calc_dims_size(const InferenceEngine::SizeVector &layout_dims);
|
||||
void topk_ref_process(const float* src_data, float* dst_data, int32_t* dst_idx,
|
||||
const InferenceEngine::SizeVector &in_dims, std::function<float(float, float)> compare) const;
|
||||
void preset_params();
|
||||
void prepare_original_idx();
|
||||
|
||||
bool topk_innermost;
|
||||
@@ -118,13 +120,12 @@ private:
|
||||
static const size_t TOPK_INDEX = 1;
|
||||
size_t O, A, I;
|
||||
size_t blk_size;
|
||||
size_t count_xmm;
|
||||
size_t data_size;
|
||||
size_t axis_dim;
|
||||
int top_k;
|
||||
int dim, before_num;
|
||||
bool bubble_inplace;
|
||||
bool compile_kernel;
|
||||
bool preset_params_done;
|
||||
|
||||
InferenceEngine::SizeVector src_dims, dst_dims;
|
||||
TopKLayoutType layout;
|
||||
|
||||
Reference in New Issue
Block a user