[CPU] Improve performance in 5D scenario of Reduce node (#17828)

Chen Xu 2023-06-05 18:52:18 +08:00 committed by GitHub
parent 43bf90f90c
commit 1dafb405fd
3 changed files with 95 additions and 37 deletions
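Summary of the change: reduce_PLN() gains a ReduceCDW_opt path for 5D planar (ncsp) inputs where C, D and W are reduced but H is kept. Instead of looping over IC and ID sequentially and parallelizing only over IH, the reduction is split into two stages: first the IC * ID input planes are combined in parallel over blocks of the H*W plane into an intermediate buffer (vec_reduceCDW_prc), then W is reduced per kept row, also in parallel. The aux-kernel swap used by the existing two-stage paths is factored into reduce_kernel_reassign() / reduce_kernel_restore(), create_working_memory() / create_DH_working_memory() become create_hybrid_working_memory() / create_opt_working_memory(), and the new set_use_aux_kernel flag keeps aux-kernel selection correct for dynamic shapes. A rough standalone sketch of the two-stage idea follows (illustrative only: names and the plain loops are not from the plugin, and sum stands in for the configured reduce mode):

#include <cstddef>
#include <vector>

// Sketch of the ReduceCDW split: reduce C and D first into an H*W scratch plane,
// then reduce W, keeping H. Input layout is assumed ncsp (C, D, H, W) for one batch.
void reduce_cdw_keep_h(const float* in, float* out,
                       std::size_t C, std::size_t D, std::size_t H, std::size_t W) {
    const std::size_t plane = H * W;               // one spatial plane
    std::vector<float> scratch(plane, 0.0f);       // intermediate buffer (PH * PW)

    // Stage 1: accumulate the C*D input planes element-wise.
    // Every position of the H*W plane is independent, so this loop can be
    // parallelized over chunks of the plane (the patch uses parallel_for over
    // IS / blk_size plus a tail call).
    for (std::size_t cd = 0; cd < C * D; ++cd)
        for (std::size_t p = 0; p < plane; ++p)
            scratch[p] += in[cd * plane + p];

    // Stage 2: reduce W for every kept H row (parallel over H in the patch).
    for (std::size_t h = 0; h < H; ++h) {
        float acc = 0.0f;
        for (std::size_t w = 0; w < W; ++w)
            acc += scratch[h * W + w];
        out[h] = acc;
    }
}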


@@ -1761,7 +1761,9 @@ Reduce::Reduce(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr
IE_THROW() << errorPrefix << " second tensor is not constant!";
raw_axes = reduceConst->cast_vector<int>();
}
set_use_aux_kernel = false;
vec_reduceDH_prc.clear();
vec_reduceCDW_prc.clear();
setJITBeyond5D();
} else {
IE_THROW(NotImplemented) << errorMessage;
@@ -2044,6 +2046,14 @@ void Reduce::createPrimitive() {
create_reduce_kernel(reduce_kernel, jcp);
// set_use_aux_kernel still being false means this is a dynamic case and prepareParams() has not been invoked yet.
// In that case, set use_aux_kernel whenever the precision changes, so the aux kernel is ready in case
// ReduceDH_opt, ReduceCDW_opt or ReduceAll_opt turns out to be true once prepareParams() runs.
if (!set_use_aux_kernel) {
use_aux_kernel = precision_change;
set_use_aux_kernel = true;
}
// For scenarios (e.g. when ReduceDH_opt, ReduceCDW_opt or ReduceAll_opt is true) that apply two stages of kernel invocation
// to improve parallelism, if the precision is asymmetrical, we apply the aux kernel on the second stage. For
// example, if the original kernel is bf16-in-fp32-out, then this original kernel will be applied on first
@@ -2197,14 +2207,38 @@ void Reduce::reduce_PLN(const uint8_t *in_ptr, uint8_t *out_ptr) {
}
}
} else if (!ReduceH && ReduceW) {
for (size_t ic = 0; ic < IC; ic++) {
size_t oc = ReduceC ? 0 : ic; GET_PTR_NC_PLN;
for (size_t id = 0; id < ID; id++) {
size_t od = ReduceD ? 0 : id; GET_PTR_NCD_PLN;
parallel_for(IH, [&](size_t ih){
size_t oh = ih; GET_PTR_NCDH_PLN;
reduce_kernel_process(in_ptr_ncdh, out_ptr_ncdh, IW, 1);
});
if (ReduceCDW_opt) {
// reduce in parallel over the HW dimensions
// step1: ReduceC && ReduceD && !ReduceH && !ReduceW
uint8_t *prc_ptr_n = &vec_reduceCDW_prc[0];
init_dst_data(prc_ptr_n, prc_size);
size_t IS = IH * IW;
reduce_stride = IS;
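// reduce_stride equals the H*W plane size, so each kernel call below steps across the
// IC * ID input planes (work_batch) and combines them, according to the reduce mode, into
// the intermediate buffer blk_size positions at a time; leftover positions are handled by
// the tail call after the loop.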
parallel_for(IS / blk_size, [&](size_t ibs){
size_t pbs = ibs;
reduce_kernel_process(in_ptr_n + ibs * blk_size * src_data_size, prc_ptr_n + pbs * blk_size * prc_data_size,
blk_size, 0, IC * ID);
});
size_t tail_start = IS / blk_size * blk_size;
reduce_kernel_process(in_ptr_n + tail_start * src_data_size, prc_ptr_n + tail_start * prc_data_size,
IS - tail_start, 0, IC * ID);
// step2: ReduceW
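// if input and output precisions differ, swap in the aux kernel so this final stage
// writes the destination precision (see reduce_kernel_reassign/restore defined later in this file)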
reduce_kernel_reassign();
parallel_for(PH, [&](size_t ph){
size_t oh = ph;
reduce_kernel_process(prc_ptr_n + ph * PW * prc_data_size, out_ptr_n + oh * OW * dst_data_size, IW, 1);
});
reduce_kernel_restore();
} else {
for (size_t ic = 0; ic < IC; ic++) {
size_t oc = ReduceC ? 0 : ic; GET_PTR_NC_PLN;
for (size_t id = 0; id < ID; id++) {
size_t od = ReduceD ? 0 : id; GET_PTR_NCD_PLN;
parallel_for(IH, [&](size_t ih){
size_t oh = ih; GET_PTR_NCDH_PLN;
reduce_kernel_process(in_ptr_ncdh, out_ptr_ncdh, IW, 1);
});
}
}
}
} else if (ReduceW) {
@@ -2245,18 +2279,13 @@ void Reduce::reduce_PLN(const uint8_t *in_ptr, uint8_t *out_ptr) {
});
// step2: ReduceD
reduce_stride = PW;
if (use_aux_kernel) {
reduce_tmp_kernel = reduce_kernel;
reduce_kernel = reduce_aux_kernel;
}
reduce_kernel_reassign();
parallel_for(IWB, [&](size_t iwb){
size_t pwb = iwb, owb = iwb;
reduce_kernel_process(prc_ptr_n + pwb * blk_size * prc_data_size,
out_ptr_n + owb * blk_size * dst_data_size, blk_size, 0, ID);
});
if (use_aux_kernel) {
reduce_kernel = reduce_tmp_kernel;
}
reduce_kernel_restore();
}
// reduce tail
reduce_stride = IW;
@@ -2353,14 +2382,9 @@ void Reduce::reduce_BLK(const uint8_t *in_ptr, uint8_t *out_ptr) {
reduce_kernel_process(in_ptr_nc, out_ptr_nc, ID * IH * IW * blk_size);
});
// step2: ReduceC
if (use_aux_kernel) {
reduce_tmp_kernel = reduce_kernel;
reduce_kernel = reduce_aux_kernel;
}
reduce_kernel_reassign();
reduce_kernel_process(out_ptr_n, out_ptr_n_cp, ICB * blk_size);
if (use_aux_kernel) {
reduce_kernel = reduce_tmp_kernel;
}
reduce_kernel_restore();
}
} else if (ReduceW) {
for (size_t icb = 0; icb < ICB; icb++) {
@@ -2531,6 +2555,18 @@ inline void Reduce::reduce_kernel_post_process(uint8_t *out_ptr) {
}
}
inline void Reduce::reduce_kernel_reassign() {
if (use_aux_kernel) {
reduce_tmp_kernel = reduce_kernel;
reduce_kernel = reduce_aux_kernel;
}
}
inline void Reduce::reduce_kernel_restore() {
if (use_aux_kernel) {
reduce_kernel = reduce_tmp_kernel;
}
}
void Reduce::nspc2ncsp(uint8_t *proc_ptr, uint8_t *out_ptr) {
// dimension reinterpret after nspc reusing routine reduce_PLN
// demote -- nspc -- ncsp
@@ -2734,7 +2770,7 @@ inline void Reduce::init_dst_data(uint8_t *out_ptr, size_t dst_size) {
}
}
inline void Reduce::create_working_memory() {
inline void Reduce::create_hybrid_working_memory() {
auto rank = getInputShapeAtPort(REDUCE_DATA).getRank();
memory::format_tag format = (layout == ReduceLayoutType::reduce_nspc) ? (rank == 4 ? memory::format_tag::nhwc : memory::format_tag::ndhwc)
: (rank == 4 ? (mayiuse(cpu::x64::avx512_core) ? memory::format_tag::nChw16c : memory::format_tag::nChw8c)
@@ -2745,17 +2781,30 @@ inline void Reduce::create_working_memory() {
dst_size = desc.get_size();
}
inline void Reduce::create_DH_working_memory() {
ReduceDH_opt = layout == ReduceLayoutType::reduce_nspc && !isDynamicNode() && support_split &&
inline void Reduce::create_opt_working_memory() {
ReduceDH_opt = layout == ReduceLayoutType::reduce_nspc && support_split &&
!ReduceC && ReduceD && ReduceH && !ReduceW && IC == 1 && ID > 1;
if (ReduceDH_opt) {
PD = ID;
PW = IW / blk_size * blk_size;
prc_data_size = dst_data_size;
prc_size = PD * PW * dst_data_size;
prc_size = PD * PW * prc_data_size;
if (prc_size > vec_reduceDH_prc.size()) {
vec_reduceDH_prc.resize(prc_size);
}
return;
}
ReduceCDW_opt = layout == ReduceLayoutType::reduce_ncsp && support_split &&
ReduceC && ReduceD && !ReduceH && ReduceW;
if (ReduceCDW_opt) {
PH = IH;
PW = IW;
prc_data_size = dst_data_size;
prc_size = PH * PW * prc_data_size;
if (prc_size > vec_reduceCDW_prc.size()) {
vec_reduceCDW_prc.resize(prc_size);
}
}
}
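Note on the renamed helper: create_opt_working_memory() sizes the scratch buffer for whichever two-stage path applies, PD * PW elements for ReduceDH_opt (nspc layout) or PH * PW elements for ReduceCDW_opt (ncsp layout), always in the destination precision. The std::vector buffers are only grown, never shrunk, so they can be reused across shape changes now that the isDynamicNode() restriction has been dropped from both conditions.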
@@ -2821,7 +2870,7 @@ inline void Reduce::set_reduce_dim_flags() {
// must be done before the following dimension change
if (is_hybrid_layout) {
create_working_memory();
create_hybrid_working_memory();
}
// Reducing a dimension in nspc layout can be treated as reducing another dimension in ncsp layout,
@@ -2842,13 +2891,6 @@ inline void Reduce::set_reduce_dim_flags() {
ReduceH = IH != OH && OH == 1;
ReduceW = IW != OW && OW == 1;
// must be done after the above dimension change
create_DH_working_memory();
ReduceAll_opt = layout == ReduceLayoutType::reduce_blocked && !isDynamicNode() && support_split &&
ReduceC && ReduceD && ReduceH && ReduceW;
use_aux_kernel = (ReduceDH_opt || ReduceAll_opt) && precision_change;
// extend flags to better suit parallel execution
if (ReduceH && IW == 1) {
ReduceW = true;
@@ -2856,6 +2898,16 @@ inline void Reduce::set_reduce_dim_flags() {
if (ReduceC && ReduceH && ID == 1) {
ReduceD = true;
}
// must be done after the above dimension change
create_opt_working_memory();
ReduceAll_opt = layout == ReduceLayoutType::reduce_blocked && support_split &&
ReduceC && ReduceD && ReduceH && ReduceW;
if (!set_use_aux_kernel) {
use_aux_kernel = (ReduceDH_opt || ReduceCDW_opt || ReduceAll_opt) && precision_change;
set_use_aux_kernel = true;
}
}
inline void Reduce::reduce_ref(const float *in_ptr, float *out_ptr) {


@@ -112,9 +112,11 @@ private:
inline void reduce_kernel_process(const uint8_t *in_p, uint8_t *out_p, size_t work_amount,
size_t reduce_w = 2, size_t work_batch = 1, const int *tab_idx = NULL);
inline void reduce_kernel_post_process(uint8_t *out_ptr);
inline void reduce_kernel_reassign();
inline void reduce_kernel_restore();
inline void init_dst_data(uint8_t *out_ptr, size_t dst_size);
inline void create_working_memory();
inline void create_DH_working_memory();
inline void create_hybrid_working_memory();
inline void create_opt_working_memory();
inline void calc_process_dst_dims(std::vector<int> &reduce_axes, const InferenceEngine::SizeVector &dst_dim);
inline void set_reduce_dim_flags();
inline void reduce_ref(const float *in_ptr, float *out_ptr);
@@ -142,11 +144,13 @@ private:
bool precision_change = false;
bool ReduceAll_opt = false;
bool ReduceDH_opt = false;
bool ReduceCDW_opt = false;
bool use_aux_kernel = false;
bool set_use_aux_kernel = false;
bool ReduceN, ReduceC, ReduceD, ReduceH, ReduceW;
size_t IB, IC, ID, IH, IW;
size_t OB, OC, OD, OH, OW;
size_t PD, PW;
size_t PD, PH, PW;
size_t src_data_size, dst_data_size, prc_data_size;
size_t reduce_stride;
ReduceLayoutType layout;
@@ -165,6 +169,7 @@ private:
dnnl::memory prc_mem;
std::vector<uint8_t> vec_reduceDH_prc;
std::vector<uint8_t> vec_reduceCDW_prc;
std::shared_ptr<jit_uni_reduce_kernel> reduce_kernel;
std::shared_ptr<jit_uni_reduce_kernel> reduce_aux_kernel;


@@ -215,6 +215,7 @@ const std::vector<std::vector<int>> axesND = {
const std::vector<std::vector<int>> axes5D = {
{2, 4},
{0, 2, 4},
{1, 2, 4},
{0, 1, 2, 3, 4},
};
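The added {1, 2, 4} axes case reduces C, D and W of the 5D input while keeping N and H, which is the pattern the new ReduceCDW_opt path is meant to cover.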