[CPU] Improve performance in 5D scenario of Reduce node (#17828)
parent 43bf90f90c
commit 1dafb405fd
@@ -1761,7 +1761,9 @@ Reduce::Reduce(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr
                IE_THROW() << errorPrefix << " second tensor is not constant!";
            raw_axes = reduceConst->cast_vector<int>();
        }
        set_use_aux_kernel = false;
        vec_reduceDH_prc.clear();
        vec_reduceCDW_prc.clear();
        setJITBeyond5D();
    } else {
        IE_THROW(NotImplemented) << errorMessage;
@@ -2044,6 +2046,14 @@ void Reduce::createPrimitive() {

        create_reduce_kernel(reduce_kernel, jcp);

        // set_use_aux_kernel being false means this is a dynamic case, and prepareParams() hasn't been invoked yet.
        // So set use_aux_kernel to true if the precision changes, in case ReduceDH_opt, ReduceCDW_opt or ReduceAll_opt
        // turn out to be true when prepareParams() is invoked, since the aux kernel will be needed then.
        if (!set_use_aux_kernel) {
            use_aux_kernel = precision_change;
            set_use_aux_kernel = true;
        }

        // For scenarios (e.g. when ReduceDH_opt or ReduceAll_opt is true) that apply two stages of kernel invocation
        // to improve parallelism, if the precision is asymmetrical, we apply the aux kernel on the second stage. For
        // example, if the original kernel is bf16-in-fp32-out, then this original kernel will be applied on the first
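Editor's note: the precision reasoning in the comment above can be illustrated with a plain-scalar sketch. Stage 1 reads the source precision and writes the destination precision into scratch; stage 2 then consumes data that is already in the destination precision, so it needs a symmetric kernel — which is exactly why an auxiliary kernel is built when precision_change holds. The sketch below uses float→double in place of bf16→fp32; all names are hypothetical and only the reasoning matches the patch.

    // Minimal sketch of the asymmetric-precision two-stage idea (not the plugin's API).
    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Stage 1 ("original" kernel): reads the source precision (float),
    // accumulates in the destination precision (double).
    double reduce_stage1(const float *src, size_t n) {
        double acc = 0.0;
        for (size_t i = 0; i < n; ++i) acc += src[i];
        return acc;
    }

    // Stage 2 ("aux" kernel): its input is the stage-1 output, which is already
    // in the destination precision, so it must be double-in/double-out.
    double reduce_stage2(const double *partials, size_t n) {
        double acc = 0.0;
        for (size_t i = 0; i < n; ++i) acc += partials[i];
        return acc;
    }

    int main() {
        std::vector<float> src(64, 1.0f);
        // Split the work into 4 partial reductions (stage 1), then combine (stage 2).
        std::vector<double> partials(4);
        for (size_t p = 0; p < 4; ++p)
            partials[p] = reduce_stage1(src.data() + p * 16, 16);
        std::cout << reduce_stage2(partials.data(), partials.size()) << "\n"; // 64
    }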
@@ -2197,14 +2207,38 @@ void Reduce::reduce_PLN(const uint8_t *in_ptr, uint8_t *out_ptr) {
                }
            }
        } else if (!ReduceH && ReduceW) {
            for (size_t ic = 0; ic < IC; ic++) {
                size_t oc = ReduceC ? 0 : ic; GET_PTR_NC_PLN;
                for (size_t id = 0; id < ID; id++) {
                    size_t od = ReduceD ? 0 : id; GET_PTR_NCD_PLN;
                    parallel_for(IH, [&](size_t ih){
                        size_t oh = ih; GET_PTR_NCDH_PLN;
                        reduce_kernel_process(in_ptr_ncdh, out_ptr_ncdh, IW, 1);
                    });
            if (ReduceCDW_opt) {
                // reduce in parallel over the HW plane
                // step1: ReduceC && ReduceD && !ReduceH && !ReduceW
                uint8_t *prc_ptr_n = &vec_reduceCDW_prc[0];
                init_dst_data(prc_ptr_n, prc_size);
                size_t IS = IH * IW;
                reduce_stride = IS;
                parallel_for(IS / blk_size, [&](size_t ibs){
                    size_t pbs = ibs;
                    reduce_kernel_process(in_ptr_n + ibs * blk_size * src_data_size, prc_ptr_n + pbs * blk_size * prc_data_size,
                                          blk_size, 0, IC * ID);
                });
                size_t tail_start = IS / blk_size * blk_size;
                reduce_kernel_process(in_ptr_n + tail_start * src_data_size, prc_ptr_n + tail_start * prc_data_size,
                                      IS - tail_start, 0, IC * ID);
                // step2: ReduceW
                reduce_kernel_reassign();
                parallel_for(PH, [&](size_t ph){
                    size_t oh = ph;
                    reduce_kernel_process(prc_ptr_n + ph * PW * prc_data_size, out_ptr_n + oh * OW * dst_data_size, IW, 1);
                });
                reduce_kernel_restore();
            } else {
                for (size_t ic = 0; ic < IC; ic++) {
                    size_t oc = ReduceC ? 0 : ic; GET_PTR_NC_PLN;
                    for (size_t id = 0; id < ID; id++) {
                        size_t od = ReduceD ? 0 : id; GET_PTR_NCD_PLN;
                        parallel_for(IH, [&](size_t ih){
                            size_t oh = ih; GET_PTR_NCDH_PLN;
                            reduce_kernel_process(in_ptr_ncdh, out_ptr_ncdh, IW, 1);
                        });
                    }
                }
            }
        } else if (ReduceW) {
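Editor's note: the new ReduceCDW branch splits one reduction over C, D and W (ncsp layout) into two passes. Pass 1 collapses the C*D slices — each IH*IW elements apart, hence reduce_stride = IS — into an IH×IW scratch plane, parallel over blocks of that plane; pass 2 collapses W row by row, parallel over H. A plain-scalar sketch with sum as the reduction, hypothetical names, no JIT or blocking:

    // Two-step ReduceCDW sketch: in is [C][D][H][W] per batch, out is [H].
    #include <cstddef>
    #include <vector>

    void reduce_cdw(const float *in, float *out,
                    size_t C, size_t D, size_t H, size_t W) {
        const size_t HW = H * W;
        // step1: collapse C and D into an H*W scratch plane. Every scratch
        // element is independent, which is what the real code exploits with
        // parallel_for over blk_size-sized chunks of the plane.
        std::vector<float> scratch(HW, 0.0f);
        for (size_t cd = 0; cd < C * D; ++cd)
            for (size_t s = 0; s < HW; ++s)
                scratch[s] += in[cd * HW + s];
        // step2: collapse W row by row (parallel over H in the real code).
        for (size_t h = 0; h < H; ++h) {
            float acc = 0.0f;
            for (size_t w = 0; w < W; ++w)
                acc += scratch[h * W + w];
            out[h] = acc;
        }
    }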
@@ -2245,18 +2279,13 @@ void Reduce::reduce_PLN(const uint8_t *in_ptr, uint8_t *out_ptr) {
                });
                // step2: ReduceD
                reduce_stride = PW;
                if (use_aux_kernel) {
                    reduce_tmp_kernel = reduce_kernel;
                    reduce_kernel = reduce_aux_kernel;
                }
                reduce_kernel_reassign();
                parallel_for(IWB, [&](size_t iwb){
                    size_t pwb = iwb, owb = iwb;
                    reduce_kernel_process(prc_ptr_n + pwb * blk_size * prc_data_size,
                                          out_ptr_n + owb * blk_size * dst_data_size, blk_size, 0, ID);
                });
                if (use_aux_kernel) {
                    reduce_kernel = reduce_tmp_kernel;
                }
                reduce_kernel_restore();
            }
            // reduce tail
            reduce_stride = IW;
@@ -2353,14 +2382,9 @@ void Reduce::reduce_BLK(const uint8_t *in_ptr, uint8_t *out_ptr) {
                reduce_kernel_process(in_ptr_nc, out_ptr_nc, ID * IH * IW * blk_size);
            });
            // step2: ReduceC
            if (use_aux_kernel) {
                reduce_tmp_kernel = reduce_kernel;
                reduce_kernel = reduce_aux_kernel;
            }
            reduce_kernel_reassign();
            reduce_kernel_process(out_ptr_n, out_ptr_n_cp, ICB * blk_size);
            if (use_aux_kernel) {
                reduce_kernel = reduce_tmp_kernel;
            }
            reduce_kernel_restore();
        }
    } else if (ReduceW) {
        for (size_t icb = 0; icb < ICB; icb++) {
@@ -2531,6 +2555,18 @@ inline void Reduce::reduce_kernel_post_process(uint8_t *out_ptr) {
    }
}

inline void Reduce::reduce_kernel_reassign() {
    if (use_aux_kernel) {
        reduce_tmp_kernel = reduce_kernel;
        reduce_kernel = reduce_aux_kernel;
    }
}
inline void Reduce::reduce_kernel_restore() {
    if (use_aux_kernel) {
        reduce_kernel = reduce_tmp_kernel;
    }
}

void Reduce::nspc2ncsp(uint8_t *proc_ptr, uint8_t *out_ptr) {
    // dimension reinterpret after nspc reusing routine reduce_PLN
    // demote -- nspc -- ncsp
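Editor's note: these two helpers factor out the swap/restore pairs that previously sat inline at each second-stage call site (see the two hunks above). Since restore() must always follow reassign(), an RAII guard is a common alternative design; the sketch below is an editorial illustration of that alternative, not what this patch does:

    // Hypothetical RAII alternative to the reassign()/restore() pair: the guard
    // swaps in the aux kernel and is guaranteed to swap back on scope exit.
    #include <memory>

    struct jit_uni_reduce_kernel;  // opaque here

    class aux_kernel_guard {
        std::shared_ptr<jit_uni_reduce_kernel> &active_;
        std::shared_ptr<jit_uni_reduce_kernel> saved_;
        bool swapped_;
    public:
        aux_kernel_guard(std::shared_ptr<jit_uni_reduce_kernel> &active,
                         const std::shared_ptr<jit_uni_reduce_kernel> &aux,
                         bool use_aux)
            : active_(active), swapped_(use_aux) {
            if (swapped_) {
                saved_ = active_;   // like reduce_tmp_kernel = reduce_kernel
                active_ = aux;      // like reduce_kernel = reduce_aux_kernel
            }
        }
        ~aux_kernel_guard() {
            if (swapped_) active_ = saved_;  // like reduce_kernel_restore()
        }
    };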
@@ -2734,7 +2770,7 @@ inline void Reduce::init_dst_data(uint8_t *out_ptr, size_t dst_size) {
    }
}

inline void Reduce::create_working_memory() {
inline void Reduce::create_hybrid_working_memory() {
    auto rank = getInputShapeAtPort(REDUCE_DATA).getRank();
    memory::format_tag format = (layout == ReduceLayoutType::reduce_nspc) ? (rank == 4 ? memory::format_tag::nhwc : memory::format_tag::ndhwc)
                                                                          : (rank == 4 ? (mayiuse(cpu::x64::avx512_core) ? memory::format_tag::nChw16c : memory::format_tag::nChw8c)
@@ -2745,17 +2781,30 @@ inline void Reduce::create_working_memory() {
    dst_size = desc.get_size();
}

inline void Reduce::create_DH_working_memory() {
    ReduceDH_opt = layout == ReduceLayoutType::reduce_nspc && !isDynamicNode() && support_split &&
inline void Reduce::create_opt_working_memory() {
    ReduceDH_opt = layout == ReduceLayoutType::reduce_nspc && support_split &&
                   !ReduceC && ReduceD && ReduceH && !ReduceW && IC == 1 && ID > 1;
    if (ReduceDH_opt) {
        PD = ID;
        PW = IW / blk_size * blk_size;
        prc_data_size = dst_data_size;
        prc_size = PD * PW * dst_data_size;
        prc_size = PD * PW * prc_data_size;
        if (prc_size > vec_reduceDH_prc.size()) {
            vec_reduceDH_prc.resize(prc_size);
        }
        return;
    }

    ReduceCDW_opt = layout == ReduceLayoutType::reduce_ncsp && support_split &&
                    ReduceC && ReduceD && !ReduceH && ReduceW;
    if (ReduceCDW_opt) {
        PH = IH;
        PW = IW;
        prc_data_size = dst_data_size;
        prc_size = PH * PW * prc_data_size;
        if (prc_size > vec_reduceCDW_prc.size()) {
            vec_reduceCDW_prc.resize(prc_size);
        }
    }
}
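Editor's note: with the !isDynamicNode() guard dropped, this function now runs per shape on dynamic nodes, so the scratch vectors use a grow-only policy — resize only when the required size exceeds the current one, keeping the allocation across prepareParams() calls. A small sketch of that policy, with hypothetical names:

    // Grow-only scratch buffer: repeated shape changes never shrink or
    // thrash the allocation.
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    struct Scratch {
        std::vector<uint8_t> buf;
        // e.g. prc_size = PH * PW * prc_data_size for the ReduceCDW case
        uint8_t *ensure(size_t prc_size) {
            if (prc_size > buf.size())
                buf.resize(prc_size);   // grow only; reuse capacity across shapes
            return buf.data();
        }
    };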
@@ -2821,7 +2870,7 @@ inline void Reduce::set_reduce_dim_flags() {

    // must be done before the following dimension change
    if (is_hybrid_layout) {
        create_working_memory();
        create_hybrid_working_memory();
    }

    // Reducing a dimension in nspc layout can be treated as reducing another dimension in ncsp layout,
@@ -2842,13 +2891,6 @@ inline void Reduce::set_reduce_dim_flags() {
    ReduceH = IH != OH && OH == 1;
    ReduceW = IW != OW && OW == 1;

    // must be done after the above dimension change
    create_DH_working_memory();

    ReduceAll_opt = layout == ReduceLayoutType::reduce_blocked && !isDynamicNode() && support_split &&
                    ReduceC && ReduceD && ReduceH && ReduceW;
    use_aux_kernel = (ReduceDH_opt || ReduceAll_opt) && precision_change;

    // suit for parallel
    if (ReduceH && IW == 1) {
        ReduceW = true;
@@ -2856,6 +2898,16 @@ inline void Reduce::set_reduce_dim_flags() {
    if (ReduceC && ReduceH && ID == 1) {
        ReduceD = true;
    }

    // must be done after the above dimension change
    create_opt_working_memory();

    ReduceAll_opt = layout == ReduceLayoutType::reduce_blocked && support_split &&
                    ReduceC && ReduceD && ReduceH && ReduceW;
    if (!set_use_aux_kernel) {
        use_aux_kernel = (ReduceDH_opt || ReduceCDW_opt || ReduceAll_opt) && precision_change;
        set_use_aux_kernel = true;
    }
}

inline void Reduce::reduce_ref(const float *in_ptr, float *out_ptr) {
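Editor's note: a concrete example makes the flag computation above tangible. A 5D NCDHW tensor reduced over axes {1, 2, 4} (C, D, W) in ncsp layout gives ReduceC = ReduceD = ReduceW = true and ReduceH = false — exactly the condition create_opt_working_memory() checks for ReduceCDW_opt. The shapes below are hypothetical; the per-dimension test mirrors the IH != OH && OH == 1 pattern in the diff:

    #include <array>
    #include <cstddef>
    #include <iostream>

    int main() {
        // NCDHW input 2x19x2x2x9 reduced over axes {1, 2, 4} with keep_dims:
        std::array<size_t, 5> in  = {2, 19, 2, 2, 9};
        std::array<size_t, 5> out = {2,  1, 1, 2, 1};
        bool ReduceC = in[1] != out[1] && out[1] == 1;  // true
        bool ReduceD = in[2] != out[2] && out[2] == 1;  // true
        bool ReduceH = in[3] != out[3] && out[3] == 1;  // false
        bool ReduceW = in[4] != out[4] && out[4] == 1;  // true
        // ReduceC && ReduceD && !ReduceH && ReduceW -> ReduceCDW_opt candidate
        std::cout << (ReduceC && ReduceD && !ReduceH && ReduceW) << "\n";  // 1
    }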
@@ -112,9 +112,11 @@ private:
    inline void reduce_kernel_process(const uint8_t *in_p, uint8_t *out_p, size_t work_amount,
                                      size_t reduce_w = 2, size_t work_batch = 1, const int *tab_idx = NULL);
    inline void reduce_kernel_post_process(uint8_t *out_ptr);
    inline void reduce_kernel_reassign();
    inline void reduce_kernel_restore();
    inline void init_dst_data(uint8_t *out_ptr, size_t dst_size);
    inline void create_working_memory();
    inline void create_DH_working_memory();
    inline void create_hybrid_working_memory();
    inline void create_opt_working_memory();
    inline void calc_process_dst_dims(std::vector<int> &reduce_axes, const InferenceEngine::SizeVector &dst_dim);
    inline void set_reduce_dim_flags();
    inline void reduce_ref(const float *in_ptr, float *out_ptr);
@@ -142,11 +144,13 @@ private:
    bool precision_change = false;
    bool ReduceAll_opt = false;
    bool ReduceDH_opt = false;
    bool ReduceCDW_opt = false;
    bool use_aux_kernel = false;
    bool set_use_aux_kernel = false;
    bool ReduceN, ReduceC, ReduceD, ReduceH, ReduceW;
    size_t IB, IC, ID, IH, IW;
    size_t OB, OC, OD, OH, OW;
    size_t PD, PW;
    size_t PD, PH, PW;
    size_t src_data_size, dst_data_size, prc_data_size;
    size_t reduce_stride;
    ReduceLayoutType layout;
@@ -165,6 +169,7 @@ private:

    dnnl::memory prc_mem;
    std::vector<uint8_t> vec_reduceDH_prc;
    std::vector<uint8_t> vec_reduceCDW_prc;

    std::shared_ptr<jit_uni_reduce_kernel> reduce_kernel;
    std::shared_ptr<jit_uni_reduce_kernel> reduce_aux_kernel;
@@ -215,6 +215,7 @@ const std::vector<std::vector<int>> axesND = {
const std::vector<std::vector<int>> axes5D = {
    {2, 4},
    {0, 2, 4},
    {1, 2, 4},
    {0, 1, 2, 3, 4},
};
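Editor's note: judging by the +1 line count in the hunk header, {1, 2, 4} appears to be the added axis set; under NCDHW it reduces C, D and W, i.e. the new ReduceCDW_opt pattern, while {0, 1, 2, 3, 4} covers ReduceAll_opt. A tiny hypothetical helper mapping a 5D axes list onto the per-dimension flags:

    #include <algorithm>
    #include <vector>

    struct Flags5D { bool N, C, D, H, W; };

    Flags5D flags_from_axes(const std::vector<int> &axes) {
        auto has = [&](int a) {
            return std::find(axes.begin(), axes.end(), a) != axes.end();
        };
        return {has(0), has(1), has(2), has(3), has(4)};
    }
    // flags_from_axes({1, 2, 4})       -> C, D, W set: the ReduceCDW_opt pattern.
    // flags_from_axes({0, 1, 2, 3, 4}) -> everything set: ReduceAll_opt.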