[CPU] Fix performance issue for some cases of Reduce node (#11456)

Chen Xu 2023-07-12 16:48:38 +08:00 committed by GitHub
parent af6c2b0671
commit 7c1949421f
2 changed files with 193 additions and 59 deletions


@@ -89,6 +89,7 @@ size_t ReduceKey::hash() const {
size_t seed = 0;
seed = hash_combine(seed, jcp.layout);
seed = hash_combine(seed, jcp.reduce_mode);
seed = hash_combine(seed, jcp.fuse_low_precision);
seed = hash_combine(seed, jcp.src_dt);
seed = hash_combine(seed, jcp.dst_dt);
seed = get_post_op_hash(seed, *postOps.get());
@@ -98,6 +99,7 @@ size_t ReduceKey::hash() const {
bool ReduceKey::operator==(const ReduceKey &rhs) const {
return jcp.layout == rhs.jcp.layout && jcp.reduce_mode == rhs.jcp.reduce_mode &&
jcp.fuse_low_precision == rhs.jcp.fuse_low_precision &&
jcp.src_dt == rhs.jcp.src_dt && jcp.dst_dt == rhs.jcp.dst_dt && *postOps.get() == *rhs.postOps.get();
}
} // namespace
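
For context on the two hunks above: the plugin caches compiled kernels by ReduceKey, so the new fuse_low_precision flag has to take part in both hash() and operator==, otherwise a cached kernel compiled for the other variant could be returned. A minimal, self-contained sketch of that pattern (a hypothetical MiniKey type, not the plugin's actual cache key):

#include <cstddef>
#include <functional>

struct MiniKey {
    int reduce_mode;            // stands in for jcp.reduce_mode
    bool fuse_low_precision;    // the newly added field

    size_t hash() const {
        // same hash_combine idea as above: fold every distinguishing field into the seed
        size_t seed = std::hash<int>()(reduce_mode);
        seed ^= std::hash<bool>()(fuse_low_precision) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
        return seed;
    }
    bool operator==(const MiniKey& rhs) const {
        // equality must compare exactly the fields that were hashed
        return reduce_mode == rhs.reduce_mode && fuse_low_precision == rhs.fuse_low_precision;
    }
};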
@@ -192,6 +194,8 @@ private:
Xbyak::Reg64 reg_src_aux = rax;
Xbyak::Reg64 reg_work_batch_aux = rbx;
Xbyak::Reg64 reg_can_divide = rbp;
Xbyak::Reg64 reg_divisor = reg_can_divide;
Vmm vmm_aux = Vmm(0);
Xmm xmm_aux = Xmm(0);
@@ -302,6 +306,22 @@ private:
// reduce
reduce_kernel();
if (jcp_.reduce_mode == Algorithm::ReduceMean) {
Xbyak::Label reduce_divide_end_label;
mov(reg_can_divide, ptr[reg_params + GET_OFF(can_divide)]);
cmp(reg_can_divide, 0);
je(reduce_divide_end_label, T_NEAR);
{
mov(reg_divisor, ptr[reg_params + GET_OFF(divisor)]);
uni_vbroadcastss(vmm_aux, ptr[reg_divisor]);
uni_vdivps(vmm_dst, vmm_dst, vmm_aux);
if (isa == cpu::x64::sse41) {
uni_vdivps(vmm_dst_aux, vmm_dst_aux, vmm_aux);
}
}
L(reduce_divide_end_label);
}
// store
store_dst_vector();
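
The block above is the core of the optimization in the main kernel: for ReduceMean, when the caller sets can_divide, the accumulated sum is divided by the broadcast divisor right before the store, so the separate post-process pass over the output can be skipped. A scalar sketch of the same idea, with illustrative names only:

#include <cstddef>

void reduce_mean_block(const float* src, float* dst, size_t n, bool can_divide, float divisor) {
    float sum = 0.0f;
    for (size_t i = 0; i < n; ++i)
        sum += src[i];                      // vectorized accumulation in the JIT kernel
    *dst = can_divide ? sum / divisor       // finalize the mean in the same pass
                      : sum;                // otherwise the post kernel divides later
}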
@@ -1124,14 +1144,19 @@ struct jit_uni_reduce_post_kernel_f32 : public jit_uni_reduce_post_kernel, publi
this->preamble();
planar_layout = jcp_.layout == ReduceLayoutType::reduce_ncsp || jcp_.layout == ReduceLayoutType::reduce_nspc;
post_reduce = jcp_.reduce_mode == Algorithm::ReduceL2 || jcp_.reduce_mode == Algorithm::ReduceMean ||
jcp_.reduce_mode == Algorithm::ReduceLogSum || jcp_.reduce_mode == Algorithm::ReduceLogSumExp;
post_ops_fusing = attr_.post_ops_.len() != 0;
mov(reg_dst, ptr[reg_params + GET_OFF_POST(dst)]);
mov(reg_work_amount, ptr[reg_params + GET_OFF_POST(work_amount)]);
mov(reg_channel_size, ptr[reg_params + GET_OFF_POST(channel_size)]);
mov(reg_divisor, ptr[reg_params + GET_OFF_POST(divisor)]);
if (jcp_.fuse_low_precision)
mov(reg_src, ptr[reg_params + GET_OFF_POST(src)]);
if (!planar_layout)
mov(reg_reduce_c, ptr[reg_params + GET_OFF_POST(reduce_c)]);
if (attr_.post_ops_.len() != 0) {
if (post_ops_fusing) {
mov(reg_post_ops_data, ptr[reg_params + GET_OFF_POST(post_op_data)]);
mov(reg_oc_off, ptr[reg_params + GET_OFF_POST(oc_off)]);
}
@@ -1141,7 +1166,7 @@ struct jit_uni_reduce_post_kernel_f32 : public jit_uni_reduce_post_kernel, publi
if (jcp_.layout == ReduceLayoutType::reduce_blocked) {
reduce_post_main();
} else if (jcp_.layout == ReduceLayoutType::reduce_nspc && attr_.post_ops_.len() != 0) {
} else if (jcp_.layout == ReduceLayoutType::reduce_nspc && post_ops_fusing) {
// the tail of channel dimension should always be concerned during post ops fusing for nspc layout
Xbyak::Label reduce_nspc_loop_label;
Xbyak::Label reduce_nspc_loop_end_label;
@@ -1183,7 +1208,10 @@ private:
Xbyak::Ymm, Xbyak::Zmm>::type;
size_t vlen = cpu_isa_traits<isa>::vlen;
bool planar_layout = false;
bool post_reduce = true;
bool post_ops_fusing = false;
Xbyak::Reg64 reg_src = rbp;
Xbyak::Reg64 reg_dst = r8;
Xbyak::Reg64 reg_work_amount = r9;
Xbyak::Reg64 reg_total_work_amount = r10;
@@ -1246,9 +1274,9 @@ private:
jl(reduce_loop_end_label, T_NEAR);
// load
load_vector(vmm_dst, ptr[reg_dst], jcp_.dst_dt);
wrap_load_vector(vmm_dst, 0);
if (isa == cpu::x64::sse41)
load_vector(vmm_dst_aux, ptr[reg_dst + 4 * jcp_.dst_data_size], jcp_.dst_dt);
wrap_load_vector(vmm_dst_aux, 4);
// reduce and store
horiz_reduce_store(vmm_dst, jcp_.dst_dt);
@@ -1256,22 +1284,27 @@ private:
horiz_reduce_store(vmm_dst_aux, jcp_.dst_dt, true);
add(reg_dst, step * jcp_.dst_data_size);
if (jcp_.fuse_low_precision)
add(reg_src, step * sizeof(float));
sub(reg_work_amount, step);
jmp(reduce_loop_label, T_NEAR);
}
L(reduce_loop_end_label);
mov(reg_dst, ptr[reg_params + GET_OFF_POST(dst)]);
mov(reg_work_amount, ptr[reg_params + GET_OFF_POST(work_amount)]);
if (post_reduce || post_ops_fusing) {
mov(reg_dst, ptr[reg_params + GET_OFF_POST(dst)]);
if (jcp_.fuse_low_precision)
mov(reg_src, ptr[reg_params + GET_OFF_POST(src)]);
mov(reg_work_amount, ptr[reg_params + GET_OFF_POST(work_amount)]);
}
}
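
The pointer reloads above exist because the post kernel makes two sweeps: the first loop horizontally reduces vectors into dst (or into the fp32 scratch when fuse_low_precision is set), and the second loop, entered at reduce_map_label, re-walks the same range to apply the final map and any fused post ops, so reg_dst/reg_src/reg_work_amount are rewound in between. A plain C++ sketch of that two-pass shape, assuming a ReduceMean-style map:

#include <cstddef>

void post_kernel_two_pass(const float* partial, float* dst, size_t n, float divisor) {
    // pass 1 (the loop above): store the horizontally reduced partial results
    for (size_t i = 0; i < n; ++i)
        dst[i] = partial[i];
    // pointers are "rewound" here, mirroring the reloads from the call args
    // pass 2: apply the final map (ReduceMean shown; the real kernel also applies post ops)
    for (size_t i = 0; i < n; ++i)
        dst[i] /= divisor;
}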
// reduce map for value in dst memory
// cases: [ReduceL2] [ReduceLogSum] [ReduceLogSumExp] [ReduceMean]
L(reduce_map_label);
{
if (jcp_.reduce_mode == Algorithm::ReduceL2 || jcp_.reduce_mode == Algorithm::ReduceMean ||
jcp_.reduce_mode == Algorithm::ReduceLogSum || jcp_.reduce_mode == Algorithm::ReduceLogSumExp) {
if (post_reduce) {
if (jcp_.reduce_mode == Algorithm::ReduceMean)
uni_vbroadcastss(vmm_aux, ptr[reg_divisor]);
@@ -1284,16 +1317,16 @@ private:
cmp(reg_work_amount, step);
jl(reduce_loop_end_label, T_NEAR);
load_vector(vmm_dst, ptr[reg_dst], jcp_.dst_dt);
wrap_load_vector(vmm_dst, 0);
reduce_map_kernel(vmm_dst);
if (attr_.post_ops_.len() != 0)
if (post_ops_fusing)
apply_post_ops(jcp_.dst_dt, jcp_.layout == ReduceLayoutType::reduce_ncsp);
store_vector(ptr[reg_dst], vmm_dst, jcp_.dst_dt);
if (isa == cpu::x64::sse41) {
load_vector(vmm_dst, ptr[reg_dst + 4 * jcp_.dst_data_size], jcp_.dst_dt);
wrap_load_vector(vmm_dst, 4);
reduce_map_kernel(vmm_dst);
if (attr_.post_ops_.len() != 0) {
if (post_ops_fusing) {
if (jcp_.layout != ReduceLayoutType::reduce_ncsp)
add(reg_oc_off, 4 * sizeof(float));
apply_post_ops(jcp_.dst_dt, jcp_.layout == ReduceLayoutType::reduce_ncsp);
@@ -1304,7 +1337,9 @@ private:
}
add(reg_dst, step * jcp_.dst_data_size);
if (jcp_.layout == ReduceLayoutType::reduce_nspc && attr_.post_ops_.len() != 0)
if (jcp_.fuse_low_precision)
add(reg_src, step * sizeof(float));
if (jcp_.layout == ReduceLayoutType::reduce_nspc && post_ops_fusing)
add(reg_oc_off, step * sizeof(float));
sub(reg_work_amount, step);
@@ -1312,7 +1347,7 @@ private:
}
L(reduce_loop_end_label);
} else {
if (attr_.post_ops_.len() != 0) {
if (post_ops_fusing) {
Xbyak::Label reduce_loop_label;
Xbyak::Label reduce_loop_end_label;
@@ -1322,12 +1357,12 @@ private:
cmp(reg_work_amount, step);
jl(reduce_loop_end_label, T_NEAR);
load_vector(vmm_dst, ptr[reg_dst], jcp_.dst_dt);
wrap_load_vector(vmm_dst, 0);
apply_post_ops(jcp_.dst_dt, jcp_.layout == ReduceLayoutType::reduce_ncsp);
store_vector(ptr[reg_dst], vmm_dst, jcp_.dst_dt);
if (isa == cpu::x64::sse41) {
load_vector(vmm_dst, ptr[reg_dst + 4 * jcp_.dst_data_size], jcp_.dst_dt);
wrap_load_vector(vmm_dst, 4);
if (jcp_.layout != ReduceLayoutType::reduce_ncsp)
add(reg_oc_off, 4 * sizeof(float));
apply_post_ops(jcp_.dst_dt, jcp_.layout == ReduceLayoutType::reduce_ncsp);
@@ -1337,7 +1372,9 @@ private:
}
add(reg_dst, step * jcp_.dst_data_size);
if (jcp_.layout == ReduceLayoutType::reduce_nspc && attr_.post_ops_.len() != 0)
if (jcp_.fuse_low_precision)
add(reg_src, step * sizeof(float));
if (jcp_.layout == ReduceLayoutType::reduce_nspc && post_ops_fusing)
add(reg_oc_off, step * sizeof(float));
sub(reg_work_amount, step);
@@ -1352,8 +1389,7 @@ private:
inline void reduce_post_tail() {
// reduce map for tail in dst memory
// cases: [ReduceL2] [ReduceLogSum] [ReduceLogSumExp] [ReduceMean] in planar layout
if (jcp_.reduce_mode == Algorithm::ReduceL2 || jcp_.reduce_mode == Algorithm::ReduceMean ||
jcp_.reduce_mode == Algorithm::ReduceLogSum || jcp_.reduce_mode == Algorithm::ReduceLogSumExp) {
if (post_reduce) {
if (jcp_.reduce_mode == Algorithm::ReduceMean)
uni_vbroadcastss(xmm_aux, ptr[reg_divisor]);
@@ -1367,18 +1403,20 @@ private:
jl(reduce_loop_end_label, T_NEAR);
// load
load_scalar(xmm_dst, ptr[reg_dst], jcp_.dst_dt);
wrap_load_scalar(xmm_dst, 0);
// reduce
reduce_map_kernel_scalar(xmm_dst);
// store
if (attr_.post_ops_.len() != 0)
if (post_ops_fusing)
apply_post_ops(jcp_.dst_dt, jcp_.layout == ReduceLayoutType::reduce_ncsp);
store_scalar(ptr[reg_dst], xmm_dst, jcp_.dst_dt);
add(reg_dst, step * jcp_.dst_data_size);
if (jcp_.layout == ReduceLayoutType::reduce_nspc && attr_.post_ops_.len() != 0)
if (jcp_.fuse_low_precision)
add(reg_src, step * sizeof(float));
if (jcp_.layout == ReduceLayoutType::reduce_nspc && post_ops_fusing)
add(reg_oc_off, step * sizeof(float));
sub(reg_work_amount, step);
@@ -1386,7 +1424,7 @@ private:
}
L(reduce_loop_end_label);
} else {
if (attr_.post_ops_.len() != 0) {
if (post_ops_fusing) {
Xbyak::Label reduce_loop_label;
Xbyak::Label reduce_loop_end_label;
@@ -1397,14 +1435,16 @@ private:
jl(reduce_loop_end_label, T_NEAR);
// load
load_scalar(xmm_dst, ptr[reg_dst], jcp_.dst_dt);
wrap_load_scalar(xmm_dst, 0);
// store
apply_post_ops(jcp_.dst_dt, jcp_.layout == ReduceLayoutType::reduce_ncsp);
store_scalar(ptr[reg_dst], xmm_dst, jcp_.dst_dt);
add(reg_dst, step * jcp_.dst_data_size);
if (jcp_.layout == ReduceLayoutType::reduce_nspc && attr_.post_ops_.len() != 0)
if (jcp_.fuse_low_precision)
add(reg_src, step * sizeof(float));
if (jcp_.layout == ReduceLayoutType::reduce_nspc && post_ops_fusing)
add(reg_oc_off, step * sizeof(float));
sub(reg_work_amount, step);
@@ -1476,6 +1516,20 @@ private:
log_injector->compute_vector_range(xmm_dst.getIdx(), xmm_dst.getIdx() + 1);
}
inline void wrap_load_vector(Vmm vmm_val, size_t offset) {
if (jcp_.fuse_low_precision)
load_vector(vmm_val, ptr[reg_src + offset * sizeof(float)], memory::data_type::f32);
else
load_vector(vmm_val, ptr[reg_dst + offset * jcp_.dst_data_size], jcp_.dst_dt);
}
inline void wrap_load_scalar(Xmm xmm_val, size_t offset) {
if (jcp_.fuse_low_precision)
load_scalar(xmm_val, ptr[reg_src + offset * sizeof(float)], memory::data_type::f32);
else
load_scalar(xmm_val, ptr[reg_dst + offset * jcp_.dst_data_size], jcp_.dst_dt);
}
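
These two wrappers centralize where intermediate values are re-read from: with low-precision fusing enabled the partially reduced data lives in an fp32 scratch buffer addressed by reg_src, otherwise it is re-loaded from dst in its native data type. A host-side sketch of that selection (assumed semantics; only the f32 dst case is spelled out, the JIT load_scalar converts other dst_dt values on load):

#include <cstddef>

float wrap_load_scalar_ref(const float* fp32_scratch, const float* dst,
                           size_t offset, bool fuse_low_precision) {
    if (fuse_low_precision)
        return fp32_scratch[offset];    // fp32 intermediate buffer (reg_src path)
    return dst[offset];                 // dst in its own precision (reg_dst path)
}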
inline void load_vector(Vmm vmm_src, const Xbyak::Address &op, memory::data_type src_dt) {
switch (src_dt) {
case memory::data_type::f32:
@@ -1636,11 +1690,19 @@ private:
horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2),f(2,2),f(3,4),f(4,4)
uni_vmovhlps(xmm_aux3, xmm_aux3, xmm_dst); // aux3:f(3,4),f(4,4),4,4
horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2,3,4),...
if (load_embedded) {
load_scalar(xmm_aux3, ptr[reg_dst], dst_dt);
horiz_ps(xmm_dst, xmm_aux3);
if (jcp_.fuse_low_precision && (post_reduce || post_ops_fusing)) {
if (load_embedded) {
load_scalar(xmm_aux3, ptr[reg_src], memory::data_type::f32);
horiz_ps(xmm_dst, xmm_aux3);
}
store_scalar(ptr[reg_src], xmm_dst, memory::data_type::f32);
} else {
if (load_embedded) {
load_scalar(xmm_aux3, ptr[reg_dst], dst_dt);
horiz_ps(xmm_dst, xmm_aux3);
}
store_scalar(ptr[reg_dst], xmm_dst, dst_dt);
}
store_scalar(ptr[reg_dst], xmm_dst, dst_dt);
}
inline void horiz_ps(const Xmm& xmm, const Operand& op) {
@@ -1762,6 +1824,7 @@ Reduce::Reduce(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr
raw_axes = reduceConst->cast_vector<int>();
}
set_use_aux_kernel = false;
fuse_low_precision = false;
vec_reduceDH_prc.clear();
vec_reduceCDW_prc.clear();
setJITBeyond5D();
@@ -1800,7 +1863,17 @@ void Reduce::initSupportedPrimitiveDescriptors() {
output_prec = getOriginalOutputPrecisionAtPort(0);
if (!fusedWith.empty()) {
output_prec = fusedWith[fusedWith.size() - 1]->getOriginalOutputPrecisionAtPort(0);
// In jit mode we use the output memory as an intermediate accumulator for certain reduce modes.
// If the fused post ops node has a lower precision for such modes, a working buffer with the original
// precision is needed in order to avoid accuracy loss.
auto fused_prec = fusedWith[fusedWith.size() - 1]->getOriginalOutputPrecisionAtPort(0);
if (output_prec == Precision::FP32 && fused_prec != Precision::FP32) {
if (algorithm != Algorithm::ReduceAnd && algorithm != Algorithm::ReduceOr &&
algorithm != Algorithm::ReduceMin && algorithm != Algorithm::ReduceMax) {
fuse_low_precision = true;
}
}
output_prec = fused_prec;
}
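
Condensed, the new branch enables the fp32 working buffer only when three things hold: the Reduce node itself computes in fp32, the fused consumer narrows the output, and the reduce mode actually accumulates values (And/Or/Min/Max lose nothing in a narrower type). A self-contained sketch of the predicate, using stand-in Prec/Alg enums instead of the plugin's Precision/Algorithm types:

enum class Prec { FP32, BF16, U8 };
enum class Alg  { ReduceAnd, ReduceOr, ReduceMin, ReduceMax, ReduceMean, ReduceSum };

static bool needs_fp32_intermediate(Prec output_prec, Prec fused_prec, Alg alg) {
    const bool narrowing = output_prec == Prec::FP32 && fused_prec != Prec::FP32;
    const bool accumulating = alg != Alg::ReduceAnd && alg != Alg::ReduceOr &&
                              alg != Alg::ReduceMin && alg != Alg::ReduceMax;
    return narrowing && accumulating;
}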
jit_mode = canApplyJIT(input_prec, output_prec);
@@ -1818,12 +1891,14 @@ void Reduce::initSupportedPrimitiveDescriptors() {
}
}
precision_change = input_prec != output_prec;
intermediate_prec = fuse_low_precision ? Precision(Precision::FP32) : output_prec;
precision_change = input_prec != intermediate_prec;
support_split = algorithm != Algorithm::ReduceL2 && algorithm != Algorithm::ReduceLogSumExp &&
algorithm != Algorithm::ReduceSumSquare;
src_data_size = input_prec.size();
dst_data_size = output_prec.size();
intermediate_data_size = intermediate_prec.size();
NodeConfig config;
config.inConfs.resize(2);
@@ -1950,6 +2025,9 @@ void Reduce::prepareParams() {
set_reduce_dim_flags();
}
apply_post_kernel = true;
apply_division = false;
auto builder = [&](const ReduceKey& key) -> std::shared_ptr<jit_uni_reduce_post_kernel> {
std::shared_ptr<jit_uni_reduce_post_kernel> post_kernel;
#if defined(OPENVINO_ARCH_X86_64)
@@ -2021,6 +2099,7 @@ void Reduce::createPrimitive() {
jcp.dst_data_size = DnnlExtensionUtils::sizeOfDataType(jcp.dst_dt);
jcp.layout = layout;
jcp.reduce_mode = getAlgorithm();
jcp.fuse_low_precision = fuse_low_precision;
#if defined(OPENVINO_ARCH_X86_64)
compile_post_kernel = true;
@@ -2040,7 +2119,10 @@ void Reduce::createPrimitive() {
updateLastInputDims();
}
create_reduce_kernel(reduce_kernel, jcp);
auto reduce_jcp = jcp;
reduce_jcp.dst_dt = fuse_low_precision ? DnnlExtensionUtils::IEPrecisionToDataType(intermediate_prec) : jcp.dst_dt;
reduce_jcp.dst_data_size = DnnlExtensionUtils::sizeOfDataType(reduce_jcp.dst_dt);
create_reduce_kernel(reduce_kernel, reduce_jcp);
// set_use_aux_kernel being false means this is a dynamic case, and prepareParams() hasn't been invoked yet.
// So set use_aux_kernel true if precision changes, in case ReduceDH_opt, ReduceCDW_opt or ReduceAll_opt
@@ -2056,9 +2138,9 @@ void Reduce::createPrimitive() {
// stage to reduce some dimensions, and an extra fp32-in-fp32-out aux kernel will be applied on the second
// stage to reduce the rest dimensions.
if (use_aux_kernel) {
aux_jcp = jcp;
aux_jcp.src_dt = jcp.dst_dt;
aux_jcp.src_data_size = jcp.dst_data_size;
aux_jcp = reduce_jcp;
aux_jcp.src_dt = reduce_jcp.dst_dt;
aux_jcp.src_data_size = reduce_jcp.dst_data_size;
create_reduce_kernel(reduce_aux_kernel, aux_jcp);
}
}
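
The aux kernel configuration above implements the two-stage scheme described in the comment: the first kernel reduces part of the dimensions and leaves fp32 intermediates, and the aux kernel is fp32-in/fp32-out over what remains. A scalar sketch of that split, with illustrative shapes only:

#include <cstddef>
#include <vector>

std::vector<float> stage1_partial(const std::vector<float>& src, size_t groups) {
    // first stage: reduce the inner dimensions of each group into one fp32 partial sum
    const size_t per_group = src.size() / groups;
    std::vector<float> partial(groups, 0.0f);
    for (size_t g = 0; g < groups; ++g)
        for (size_t i = 0; i < per_group; ++i)
            partial[g] += src[g * per_group + i];
    return partial;
}

float stage2_aux(const std::vector<float>& partial) {
    // second stage (the fp32-in/fp32-out aux kernel): reduce the remaining dimension
    float sum = 0.0f;
    for (float v : partial)
        sum += v;
    return sum;
}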
@@ -2093,7 +2175,7 @@ void Reduce::execute(dnnl::stream strm) {
if (is_hybrid_layout) {
dst_data = reinterpret_cast<uint8_t *>(prc_mem.get_data_handle());
}
reduce_type(src_data, dst_data, dst_size);
reduce_type(src_data, dst_data);
} else if (aclExecPtr) {
std::vector<MemoryCPtr> srcMemory;
for (size_t i = 0; i < getParentEdges().size(); i++) {
@@ -2114,8 +2196,7 @@ void Reduce::execute(dnnl::stream strm) {
}
}
void Reduce::reduce_type(const uint8_t *in_ptr, uint8_t *out_ptr, size_t dst_size) {
init_dst_data(out_ptr, dst_size);
void Reduce::reduce_type(const uint8_t *in_ptr, uint8_t *out_ptr) {
reduce_stride = IW;
if (layout == ReduceLayoutType::reduce_ncsp || layout == ReduceLayoutType::reduce_nspc) {
@@ -2141,6 +2222,9 @@ void Reduce::reduce_type(const uint8_t *in_ptr, uint8_t *out_ptr, size_t dst_siz
}
void Reduce::reduce_PLN(const uint8_t *in_ptr, uint8_t *out_ptr) {
output_info_reassign(out_ptr);
init_dst_data(out_ptr, dst_size);
if (ReduceN && !ReduceC && !ReduceD && !ReduceH && !ReduceW) {
size_t IA = IC * ID * IH * IW;
reduce_stride = IA;
@@ -2348,24 +2432,29 @@ void Reduce::reduce_PLN(const uint8_t *in_ptr, uint8_t *out_ptr) {
}
}
output_info_restore(&out_ptr);
reduce_kernel_post_process(out_ptr);
}
void Reduce::reduce_BLK(const uint8_t *in_ptr, uint8_t *out_ptr) {
size_t ICB = div_up(IC, blk_size);
size_t OCB = div_up(OC, blk_size);
output_info_reassign(out_ptr);
init_dst_data(out_ptr, dst_size);
for (size_t ib = 0; ib < IB; ib++) {
size_t ob = ReduceN ? 0 : ib; GET_PTR_N_BLK;
if (!ReduceC && !ReduceD && ReduceH && ReduceW) {
if (!ReduceN || (ReduceN && ib == IB - 1)) {
apply_division = getAlgorithm() == Algorithm::ReduceMean && attr.get()->post_ops_.len() == 0;
apply_post_kernel = !apply_division;
}
parallel_for2d(ICB, ID, [&](size_t icb, size_t id) {
size_t ocb = icb, od = id; GET_PTR_NCD_BASE_PTR_N_BLK;
reduce_kernel_process(in_ptr_ncd, out_ptr_ncd, IH * IW * blk_size);
});
} else if (ReduceC && ReduceD && ReduceH && ReduceW) {
if (!ReduceAll_opt) {
reduce_kernel_process(in_ptr_n, out_ptr_n, ICB * ID * IH * IW * blk_size);
} else {
if (ReduceAll_opt) {
// reduce parallelly
// step1: !ReduceC && ReduceD && ReduceH && ReduceW
size_t prc_size = ICB * blk_size * dst_data_size;
@@ -2381,6 +2470,8 @@ void Reduce::reduce_BLK(const uint8_t *in_ptr, uint8_t *out_ptr) {
reduce_kernel_reassign();
reduce_kernel_process(out_ptr_n, out_ptr_n_cp, ICB * blk_size);
reduce_kernel_restore();
} else {
reduce_kernel_process(in_ptr_n, out_ptr_n, ICB * ID * IH * IW * blk_size);
}
} else if (ReduceW) {
for (size_t icb = 0; icb < ICB; icb++) {
@@ -2419,12 +2510,17 @@ void Reduce::reduce_BLK(const uint8_t *in_ptr, uint8_t *out_ptr) {
}
}
reduce_kernel_post_process(out_ptr);
output_info_restore(&out_ptr);
if (apply_post_kernel) {
reduce_kernel_post_process(out_ptr);
}
}
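
The apply_division/apply_post_kernel pair set earlier in this function is the actual performance fix for blocked layouts: for ReduceMean with no fused post ops and a reduce-H/W pattern, the division is folded into the main kernel call and the whole post-process sweep over the output is skipped. A self-contained sketch of that control flow (stub kernels, illustrative only):

#include <cstdio>

static void run_main_reduce_kernel(bool can_divide) {
    std::printf("main reduce kernel, can_divide=%d\n", can_divide);
}
static void run_post_kernel() {
    std::printf("post kernel: divide / map / post ops\n");
}

void reduce_mean_dispatch(bool has_post_ops, bool pattern_supported) {
    const bool apply_division = !has_post_ops && pattern_supported;  // divide inside the main kernel
    const bool apply_post_kernel = !apply_division;                  // second pass only when needed
    run_main_reduce_kernel(/*can_divide=*/apply_division);
    if (apply_post_kernel)
        run_post_kernel();
}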
void Reduce::reduce_BLK_concern_padding(const uint8_t *in_ptr, uint8_t *out_ptr) {
size_t ICB = div_up(IC, blk_size);
size_t OCB = div_up(OC, blk_size);
output_info_reassign(out_ptr);
init_dst_data(out_ptr, dst_size);
auto reduceSkipPadding = [&](const uint8_t *in_ptr_ncd, uint8_t *out_ptr_ncd, size_t ic) {
size_t blk_valid_size = IC - ic;
@@ -2503,11 +2599,13 @@ void Reduce::reduce_BLK_concern_padding(const uint8_t *in_ptr, uint8_t *out_ptr)
}
}
output_info_restore(&out_ptr);
reduce_kernel_post_process(out_ptr);
}
inline void Reduce::reduce_kernel_process(const uint8_t *in_p, uint8_t *out_p, size_t work_amount,
size_t reduce_w, size_t work_batch, const int *tab_idx) {
const float divisor = apply_division ? static_cast<float>(IB * IC * ID * IH * IW / (OB * OC * OD * OH * OW)) : 1;
auto arg = jit_reduce_call_args();
arg.src = static_cast<const void *>(in_p);
arg.idx = tab_idx;
@@ -2516,17 +2614,22 @@ inline void Reduce::reduce_kernel_process(const uint8_t *in_p, uint8_t *out_p, s
arg.work_batch = work_batch;
arg.reduce_w = reduce_w;
arg.reduce_stride = reduce_stride;
arg.can_divide = apply_division ? 1 : 0;
arg.divisor = &divisor;
(*reduce_kernel)(&arg);
}
inline void Reduce::reduce_kernel_post_process(uint8_t *out_ptr) {
const uint8_t *in_ptr = fuse_low_precision ? static_cast<uint8_t *>(&intermediate_buf[0]) : nullptr;
const size_t integerDivisor = IB * IC * ID * IH * IW / (OB * OC * OD * OH * OW);
const float divisor = static_cast<float>(integerDivisor);
if (layout == ReduceLayoutType::reduce_ncsp) {
parallel_for2d(OB, OC, [&](size_t ob, size_t oc) {
const uint8_t *in_p = in_ptr + (ob * OC + oc) * OD * OH * OW * intermediate_data_size;
uint8_t *out_p = out_ptr + (ob * OC + oc) * OD * OH * OW * dst_data_size;
auto arg = jit_reduce_post_call_args();
arg.src = static_cast<const void *>(in_p);
arg.dst = static_cast<void *>(out_p);
arg.oc_off = oc * sizeof(float);
arg.channel_size = OC;
@@ -2542,8 +2645,10 @@ inline void Reduce::reduce_kernel_post_process(uint8_t *out_ptr) {
OP *= OH;
size_t work_amount = OB * OC * OD * OH * OW / OP;
parallel_for(OP, [&](size_t op) {
const uint8_t *in_p = in_ptr + op * work_amount * intermediate_data_size;
uint8_t *out_p = out_ptr + op * work_amount * dst_data_size;
auto arg = jit_reduce_post_call_args();
arg.src = static_cast<const void *>(in_p);
arg.dst = static_cast<void *>(out_p);
arg.oc_off = 0;
arg.channel_size = OW; // OW is related to nspc-ncsp dimension reinterpret
@@ -2555,8 +2660,10 @@ inline void Reduce::reduce_kernel_post_process(uint8_t *out_ptr) {
} else {
size_t OCB = div_up(OC, blk_size);
parallel_for2d(OB, OCB, [&](size_t ob, size_t ocb) {
const uint8_t *in_p = in_ptr + (ob * OCB + ocb) * OD * OH * OW * blk_size * intermediate_data_size;
uint8_t *out_p = out_ptr + (ob * OCB + ocb) * OD * OH * OW * blk_size * dst_data_size;
auto arg = jit_reduce_post_call_args();
arg.src = static_cast<const void *>(in_p);
arg.dst = static_cast<void *>(out_p);
arg.reduce_c = ReduceC ? 1 : 0;
arg.oc_off = ocb * blk_size * sizeof(float);
@@ -2580,6 +2687,27 @@ inline void Reduce::reduce_kernel_restore() {
}
}
inline void Reduce::output_info_reassign(uint8_t *out_ptr) {
if (fuse_low_precision) {
tmp_ptr = out_ptr;
out_ptr = static_cast<uint8_t *>(&intermediate_buf[0]);
tmp_prec = output_prec;
output_prec = intermediate_prec;
tmp_data_size = dst_data_size;
dst_data_size = intermediate_data_size;
tmp_size = dst_size;
dst_size = intermediate_size;
}
}
inline void Reduce::output_info_restore(uint8_t **out_ptr) {
if (fuse_low_precision) {
*out_ptr = tmp_ptr;
output_prec = tmp_prec;
dst_data_size = tmp_data_size;
dst_size = tmp_size;
}
}
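
output_info_reassign/output_info_restore are a manual scoped swap: while the main kernels run, the "output" pointer, precision and sizes are redirected to the fp32 scratch buffer, and the saved values are put back before the post kernel writes the real output. An equivalent RAII sketch (not the plugin's code) that makes the pairing explicit:

#include <cstdint>

struct OutputRedirect {
    uint8_t** out_ptr;
    uint8_t*  saved;
    OutputRedirect(uint8_t** out, uint8_t* fp32_scratch) : out_ptr(out), saved(*out) {
        *out = fp32_scratch;            // main kernels accumulate into the scratch buffer
    }
    ~OutputRedirect() {
        *out_ptr = saved;               // post kernel then converts into the real output
    }
};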
void Reduce::nspc2ncsp(uint8_t *proc_ptr, uint8_t *out_ptr) {
// dimension reinterpret after nspc reusing routine reduce_PLN
// demote -- nspc -- ncsp
@@ -2795,12 +2923,19 @@ inline void Reduce::create_hybrid_working_memory() {
}
inline void Reduce::create_opt_working_memory() {
if (fuse_low_precision) {
intermediate_size = dst_size * sizeof(float) / dst_data_size;
if (intermediate_size > intermediate_buf.size()) {
intermediate_buf.resize(intermediate_size);
}
}
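
The scratch allocation above sizes the buffer for the same number of elements as the destination, but at four bytes each. A small helper showing the same arithmetic (illustrative; e.g. a 1024-byte bf16 output of 512 elements needs a 2048-byte fp32 buffer):

#include <cstddef>

size_t fp32_scratch_bytes(size_t dst_bytes, size_t dst_elem_size) {
    // element count (dst_bytes / dst_elem_size) times sizeof(float)
    return dst_bytes * sizeof(float) / dst_elem_size;
}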
ReduceDH_opt = layout == ReduceLayoutType::reduce_nspc && support_split &&
!ReduceC && ReduceD && ReduceH && !ReduceW && IC == 1 && ID > 1;
if (ReduceDH_opt) {
PD = ID;
PW = IW / blk_size * blk_size;
prc_data_size = dst_data_size;
prc_data_size = intermediate_data_size;
prc_size = PD * PW * prc_data_size;
if (prc_size > vec_reduceDH_prc.size()) {
vec_reduceDH_prc.resize(prc_size);
@@ -2813,7 +2948,7 @@ inline void Reduce::create_opt_working_memory() {
if (ReduceCDW_opt) {
PH = IH;
PW = IW;
prc_data_size = dst_data_size;
prc_data_size = intermediate_data_size;
prc_size = PH * PW * prc_data_size;
if (prc_size > vec_reduceCDW_prc.size()) {
vec_reduceCDW_prc.resize(prc_size);
@@ -3168,16 +3303,6 @@ bool Reduce::canFuse(const NodePtr& node) const {
return false;
}
// In jit mode we use the output memory as an intermediate accumulator for certain reduce modes.
// If the post ops node has a lower precision for such modes, post ops fusing won't be supported, in order to avoid accuracy loss.
if (output_prec == Precision::FP32 &&
!node->getOriginalOutputPrecisions().empty() && node->getOriginalOutputPrecisionAtPort(0) != Precision::FP32) {
if (algorithm != Algorithm::ReduceAnd && algorithm != Algorithm::ReduceOr &&
algorithm != Algorithm::ReduceMin && algorithm != Algorithm::ReduceMax) {
return false;
}
}
return canFuseSimpleOperation(node);
}


@@ -24,6 +24,7 @@ enum ReduceLayoutType {
struct jit_reduce_config_params {
ReduceLayoutType layout;
Algorithm reduce_mode;
bool fuse_low_precision;
dnnl::memory::data_type src_dt;
dnnl::memory::data_type dst_dt;
int src_data_size;
@@ -38,6 +39,8 @@ struct jit_reduce_call_args {
size_t work_batch;
size_t reduce_w = 2; // only used in planar layout [1: reduce width dimension] [0: reduce other dimension] [other value: N/A]
size_t reduce_stride; // only used in planar layout while reducing dimensions except for width
size_t can_divide; // if apply division in reduce_kernel [1: Yes] [0: No]
const float *divisor; // mean = sum / divisor
};
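
For reference, a caller-side sketch of how the two new members might be filled for the in-kernel mean (assumes the jit_reduce_call_args definition above; names and values are illustrative):

static void fill_mean_args(float* divisor_storage) {
    jit_reduce_call_args arg{};
    *divisor_storage = 16.0f;          // e.g. averaging over a 4x4 reduced area
    arg.can_divide = 1;                // ask the main kernel to divide before storing
    arg.divisor = divisor_storage;     // broadcast in the kernel via uni_vbroadcastss
    // arg.src / arg.dst / arg.work_amount etc. are filled as before, then the kernel is invoked
}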
struct jit_reduce_post_call_args {
@@ -105,7 +108,7 @@ public:
static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;
private:
void reduce_type(const uint8_t *in_ptr, uint8_t *out_ptr, size_t dst_size);
void reduce_type(const uint8_t *in_ptr, uint8_t *out_ptr);
void reduce_PLN(const uint8_t *in_ptr, uint8_t *out_ptr);
void reduce_BLK(const uint8_t *in_ptr, uint8_t *out_ptr);
void reduce_BLK_concern_padding(const uint8_t *in_ptr, uint8_t *out_ptr);
@@ -114,6 +117,8 @@ private:
inline void reduce_kernel_post_process(uint8_t *out_ptr);
inline void reduce_kernel_reassign();
inline void reduce_kernel_restore();
inline void output_info_reassign(uint8_t *out_ptr);
inline void output_info_restore(uint8_t **out_ptr);
inline void init_dst_data(uint8_t *out_ptr, size_t dst_size);
inline void create_hybrid_working_memory();
inline void create_opt_working_memory();
@@ -131,8 +136,6 @@ private:
bool canApplyJIT(const InferenceEngine::Precision &input_prec, const InferenceEngine::Precision &output_prec) const;
size_t blk_size;
size_t dst_size;
size_t prc_size;
static const size_t REDUCE_DATA = 0;
static const size_t REDUCE_INDEXES = 1;
bool jit_beyond_5D = false;
@@ -140,6 +143,9 @@ private:
bool keep_dims = true;
bool is_hybrid_layout = false;
bool compile_post_kernel = true;
bool apply_post_kernel = true;
bool apply_division = false;
bool fuse_low_precision = false;
bool support_split = false;
bool precision_change = false;
bool ReduceAll_opt = false;
@@ -151,14 +157,17 @@ private:
size_t IB, IC, ID, IH, IW;
size_t OB, OC, OD, OH, OW;
size_t PD, PH, PW;
size_t src_data_size, dst_data_size, prc_data_size;
size_t src_data_size, dst_data_size, prc_data_size, intermediate_data_size, tmp_data_size;
size_t dst_size, prc_size, intermediate_size, tmp_size;
size_t reduce_stride;
uint8_t *tmp_ptr;
ReduceLayoutType layout;
InferenceEngine::Precision input_prec, output_prec;
InferenceEngine::Precision input_prec, output_prec, intermediate_prec, tmp_prec;
InferenceEngine::SizeVector src_dims;
InferenceEngine::SizeVector process_dst_dims;
InferenceEngine::SizeVector axes_for_reduction;
std::vector<int> raw_axes;
std::vector<uint8_t> intermediate_buf;
jit_reduce_config_params jcp;
jit_reduce_config_params aux_jcp;