From 7c1949421f804b53636f5f03877872a327853705 Mon Sep 17 00:00:00 2001 From: Chen Xu Date: Wed, 12 Jul 2023 16:48:38 +0800 Subject: [PATCH] [CPU] Fix performance issue for some cases of Reduce node (#11456) --- src/plugins/intel_cpu/src/nodes/reduce.cpp | 233 ++++++++++++++++----- src/plugins/intel_cpu/src/nodes/reduce.h | 19 +- 2 files changed, 193 insertions(+), 59 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/reduce.cpp b/src/plugins/intel_cpu/src/nodes/reduce.cpp index 9c440c09ef4..bb3992f98cb 100644 --- a/src/plugins/intel_cpu/src/nodes/reduce.cpp +++ b/src/plugins/intel_cpu/src/nodes/reduce.cpp @@ -89,6 +89,7 @@ size_t ReduceKey::hash() const { size_t seed = 0; seed = hash_combine(seed, jcp.layout); seed = hash_combine(seed, jcp.reduce_mode); + seed = hash_combine(seed, jcp.fuse_low_precision); seed = hash_combine(seed, jcp.src_dt); seed = hash_combine(seed, jcp.dst_dt); seed = get_post_op_hash(seed, *postOps.get()); @@ -98,6 +99,7 @@ size_t ReduceKey::hash() const { bool ReduceKey::operator==(const ReduceKey &rhs) const { return jcp.layout == rhs.jcp.layout && jcp.reduce_mode == rhs.jcp.reduce_mode && + jcp.fuse_low_precision == rhs.jcp.fuse_low_precision && jcp.src_dt == rhs.jcp.src_dt && jcp.dst_dt == rhs.jcp.dst_dt && *postOps.get() == *rhs.postOps.get(); } } // namespace @@ -192,6 +194,8 @@ private: Xbyak::Reg64 reg_src_aux = rax; Xbyak::Reg64 reg_work_batch_aux = rbx; + Xbyak::Reg64 reg_can_divide = rbp; + Xbyak::Reg64 reg_divisor = reg_can_divide; Vmm vmm_aux = Vmm(0); Xmm xmm_aux = Xmm(0); @@ -302,6 +306,22 @@ private: // reduce reduce_kernel(); + if (jcp_.reduce_mode == Algorithm::ReduceMean) { + Xbyak::Label reduce_divide_end_label; + mov(reg_can_divide, ptr[reg_params + GET_OFF(can_divide)]); + cmp(reg_can_divide, 0); + je(reduce_divide_end_label, T_NEAR); + { + mov(reg_divisor, ptr[reg_params + GET_OFF(divisor)]); + uni_vbroadcastss(vmm_aux, ptr[reg_divisor]); + uni_vdivps(vmm_dst, vmm_dst, vmm_aux); + if (isa == cpu::x64::sse41) { + uni_vdivps(vmm_dst_aux, vmm_dst_aux, vmm_aux); + } + } + L(reduce_divide_end_label); + } + // store store_dst_vector(); @@ -1124,14 +1144,19 @@ struct jit_uni_reduce_post_kernel_f32 : public jit_uni_reduce_post_kernel, publi this->preamble(); planar_layout = jcp_.layout == ReduceLayoutType::reduce_ncsp || jcp_.layout == ReduceLayoutType::reduce_nspc; + post_reduce = jcp_.reduce_mode == Algorithm::ReduceL2 || jcp_.reduce_mode == Algorithm::ReduceMean || + jcp_.reduce_mode == Algorithm::ReduceLogSum || jcp_.reduce_mode == Algorithm::ReduceLogSumExp; + post_ops_fusing = attr_.post_ops_.len() != 0; mov(reg_dst, ptr[reg_params + GET_OFF_POST(dst)]); mov(reg_work_amount, ptr[reg_params + GET_OFF_POST(work_amount)]); mov(reg_channel_size, ptr[reg_params + GET_OFF_POST(channel_size)]); mov(reg_divisor, ptr[reg_params + GET_OFF_POST(divisor)]); + if (jcp_.fuse_low_precision) + mov(reg_src, ptr[reg_params + GET_OFF_POST(src)]); if (!planar_layout) mov(reg_reduce_c, ptr[reg_params + GET_OFF_POST(reduce_c)]); - if (attr_.post_ops_.len() != 0) { + if (post_ops_fusing) { mov(reg_post_ops_data, ptr[reg_params + GET_OFF_POST(post_op_data)]); mov(reg_oc_off, ptr[reg_params + GET_OFF_POST(oc_off)]); } @@ -1141,7 +1166,7 @@ struct jit_uni_reduce_post_kernel_f32 : public jit_uni_reduce_post_kernel, publi if (jcp_.layout == ReduceLayoutType::reduce_blocked) { reduce_post_main(); - } else if (jcp_.layout == ReduceLayoutType::reduce_nspc && attr_.post_ops_.len() != 0) { + } else if (jcp_.layout == ReduceLayoutType::reduce_nspc && post_ops_fusing) { // 
the tail of channel dimension should always be concerned during post ops fusing for nspc layout Xbyak::Label reduce_nspc_loop_label; Xbyak::Label reduce_nspc_loop_end_label; @@ -1183,7 +1208,10 @@ private: Xbyak::Ymm, Xbyak::Zmm>::type; size_t vlen = cpu_isa_traits::vlen; bool planar_layout = false; + bool post_reduce = true; + bool post_ops_fusing = false; + Xbyak::Reg64 reg_src = rbp; Xbyak::Reg64 reg_dst = r8; Xbyak::Reg64 reg_work_amount = r9; Xbyak::Reg64 reg_total_work_amount = r10; @@ -1246,9 +1274,9 @@ private: jl(reduce_loop_end_label, T_NEAR); // load - load_vector(vmm_dst, ptr[reg_dst], jcp_.dst_dt); + wrap_load_vector(vmm_dst, 0); if (isa == cpu::x64::sse41) - load_vector(vmm_dst_aux, ptr[reg_dst + 4 * jcp_.dst_data_size], jcp_.dst_dt); + wrap_load_vector(vmm_dst_aux, 4); // reduce and store horiz_reduce_store(vmm_dst, jcp_.dst_dt); @@ -1256,22 +1284,27 @@ private: horiz_reduce_store(vmm_dst_aux, jcp_.dst_dt, true); add(reg_dst, step * jcp_.dst_data_size); + if (jcp_.fuse_low_precision) + add(reg_src, step * sizeof(float)); sub(reg_work_amount, step); jmp(reduce_loop_label, T_NEAR); } L(reduce_loop_end_label); - mov(reg_dst, ptr[reg_params + GET_OFF_POST(dst)]); - mov(reg_work_amount, ptr[reg_params + GET_OFF_POST(work_amount)]); + if (post_reduce || post_ops_fusing) { + mov(reg_dst, ptr[reg_params + GET_OFF_POST(dst)]); + if (jcp_.fuse_low_precision) + mov(reg_src, ptr[reg_params + GET_OFF_POST(src)]); + mov(reg_work_amount, ptr[reg_params + GET_OFF_POST(work_amount)]); + } } // reduce map for value in dst memory // cases: [ReduceL2] [ReduceLogSum] [ReduceLogSumExp] [ReduceMean] L(reduce_map_label); { - if (jcp_.reduce_mode == Algorithm::ReduceL2 || jcp_.reduce_mode == Algorithm::ReduceMean || - jcp_.reduce_mode == Algorithm::ReduceLogSum || jcp_.reduce_mode == Algorithm::ReduceLogSumExp) { + if (post_reduce) { if (jcp_.reduce_mode == Algorithm::ReduceMean) uni_vbroadcastss(vmm_aux, ptr[reg_divisor]); @@ -1284,16 +1317,16 @@ private: cmp(reg_work_amount, step); jl(reduce_loop_end_label, T_NEAR); - load_vector(vmm_dst, ptr[reg_dst], jcp_.dst_dt); + wrap_load_vector(vmm_dst, 0); reduce_map_kernel(vmm_dst); - if (attr_.post_ops_.len() != 0) + if (post_ops_fusing) apply_post_ops(jcp_.dst_dt, jcp_.layout == ReduceLayoutType::reduce_ncsp); store_vector(ptr[reg_dst], vmm_dst, jcp_.dst_dt); if (isa == cpu::x64::sse41) { - load_vector(vmm_dst, ptr[reg_dst + 4 * jcp_.dst_data_size], jcp_.dst_dt); + wrap_load_vector(vmm_dst, 4); reduce_map_kernel(vmm_dst); - if (attr_.post_ops_.len() != 0) { + if (post_ops_fusing) { if (jcp_.layout != ReduceLayoutType::reduce_ncsp) add(reg_oc_off, 4 * sizeof(float)); apply_post_ops(jcp_.dst_dt, jcp_.layout == ReduceLayoutType::reduce_ncsp); @@ -1304,7 +1337,9 @@ private: } add(reg_dst, step * jcp_.dst_data_size); - if (jcp_.layout == ReduceLayoutType::reduce_nspc && attr_.post_ops_.len() != 0) + if (jcp_.fuse_low_precision) + add(reg_src, step * sizeof(float)); + if (jcp_.layout == ReduceLayoutType::reduce_nspc && post_ops_fusing) add(reg_oc_off, step * sizeof(float)); sub(reg_work_amount, step); @@ -1312,7 +1347,7 @@ private: } L(reduce_loop_end_label); } else { - if (attr_.post_ops_.len() != 0) { + if (post_ops_fusing) { Xbyak::Label reduce_loop_label; Xbyak::Label reduce_loop_end_label; @@ -1322,12 +1357,12 @@ private: cmp(reg_work_amount, step); jl(reduce_loop_end_label, T_NEAR); - load_vector(vmm_dst, ptr[reg_dst], jcp_.dst_dt); + wrap_load_vector(vmm_dst, 0); apply_post_ops(jcp_.dst_dt, jcp_.layout == ReduceLayoutType::reduce_ncsp); 
store_vector(ptr[reg_dst], vmm_dst, jcp_.dst_dt); if (isa == cpu::x64::sse41) { - load_vector(vmm_dst, ptr[reg_dst + 4 * jcp_.dst_data_size], jcp_.dst_dt); + wrap_load_vector(vmm_dst, 4); if (jcp_.layout != ReduceLayoutType::reduce_ncsp) add(reg_oc_off, 4 * sizeof(float)); apply_post_ops(jcp_.dst_dt, jcp_.layout == ReduceLayoutType::reduce_ncsp); @@ -1337,7 +1372,9 @@ private: } add(reg_dst, step * jcp_.dst_data_size); - if (jcp_.layout == ReduceLayoutType::reduce_nspc && attr_.post_ops_.len() != 0) + if (jcp_.fuse_low_precision) + add(reg_src, step * sizeof(float)); + if (jcp_.layout == ReduceLayoutType::reduce_nspc && post_ops_fusing) add(reg_oc_off, step * sizeof(float)); sub(reg_work_amount, step); @@ -1352,8 +1389,7 @@ private: inline void reduce_post_tail() { // reduce map for tail in dst memory // cases: [ReduceL2] [ReduceLogSum] [ReduceLogSumExp] [ReduceMean] in planar layout - if (jcp_.reduce_mode == Algorithm::ReduceL2 || jcp_.reduce_mode == Algorithm::ReduceMean || - jcp_.reduce_mode == Algorithm::ReduceLogSum || jcp_.reduce_mode == Algorithm::ReduceLogSumExp) { + if (post_reduce) { if (jcp_.reduce_mode == Algorithm::ReduceMean) uni_vbroadcastss(xmm_aux, ptr[reg_divisor]); @@ -1367,18 +1403,20 @@ private: jl(reduce_loop_end_label, T_NEAR); // load - load_scalar(xmm_dst, ptr[reg_dst], jcp_.dst_dt); + wrap_load_scalar(xmm_dst, 0); // reduce reduce_map_kernel_scalar(xmm_dst); // store - if (attr_.post_ops_.len() != 0) + if (post_ops_fusing) apply_post_ops(jcp_.dst_dt, jcp_.layout == ReduceLayoutType::reduce_ncsp); store_scalar(ptr[reg_dst], xmm_dst, jcp_.dst_dt); add(reg_dst, step * jcp_.dst_data_size); - if (jcp_.layout == ReduceLayoutType::reduce_nspc && attr_.post_ops_.len() != 0) + if (jcp_.fuse_low_precision) + add(reg_src, step * sizeof(float)); + if (jcp_.layout == ReduceLayoutType::reduce_nspc && post_ops_fusing) add(reg_oc_off, step * sizeof(float)); sub(reg_work_amount, step); @@ -1386,7 +1424,7 @@ private: } L(reduce_loop_end_label); } else { - if (attr_.post_ops_.len() != 0) { + if (post_ops_fusing) { Xbyak::Label reduce_loop_label; Xbyak::Label reduce_loop_end_label; @@ -1397,14 +1435,16 @@ private: jl(reduce_loop_end_label, T_NEAR); // load - load_scalar(xmm_dst, ptr[reg_dst], jcp_.dst_dt); + wrap_load_scalar(xmm_dst, 0); // store apply_post_ops(jcp_.dst_dt, jcp_.layout == ReduceLayoutType::reduce_ncsp); store_scalar(ptr[reg_dst], xmm_dst, jcp_.dst_dt); add(reg_dst, step * jcp_.dst_data_size); - if (jcp_.layout == ReduceLayoutType::reduce_nspc && attr_.post_ops_.len() != 0) + if (jcp_.fuse_low_precision) + add(reg_src, step * sizeof(float)); + if (jcp_.layout == ReduceLayoutType::reduce_nspc && post_ops_fusing) add(reg_oc_off, step * sizeof(float)); sub(reg_work_amount, step); @@ -1476,6 +1516,20 @@ private: log_injector->compute_vector_range(xmm_dst.getIdx(), xmm_dst.getIdx() + 1); } + inline void wrap_load_vector(Vmm vmm_val, size_t offset) { + if (jcp_.fuse_low_precision) + load_vector(vmm_val, ptr[reg_src + offset * sizeof(float)], memory::data_type::f32); + else + load_vector(vmm_val, ptr[reg_dst + offset * jcp_.dst_data_size], jcp_.dst_dt); + } + + inline void wrap_load_scalar(Xmm xmm_val, size_t offset) { + if (jcp_.fuse_low_precision) + load_scalar(xmm_val, ptr[reg_src + offset * sizeof(float)], memory::data_type::f32); + else + load_scalar(xmm_val, ptr[reg_dst + offset * jcp_.dst_data_size], jcp_.dst_dt); + } + inline void load_vector(Vmm vmm_src, const Xbyak::Address &op, memory::data_type src_dt) { switch (src_dt) { case memory::data_type::f32: @@ -1636,11 
+1690,19 @@ private: horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2),f(2,2),f(3,4),f(4,4) uni_vmovhlps(xmm_aux3, xmm_aux3, xmm_dst); // aux3:f(3,4),f(4,4),4,4 horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2,3,4),... - if (load_embedded) { - load_scalar(xmm_aux3, ptr[reg_dst], dst_dt); - horiz_ps(xmm_dst, xmm_aux3); + if (jcp_.fuse_low_precision && (post_reduce || post_ops_fusing)) { + if (load_embedded) { + load_scalar(xmm_aux3, ptr[reg_src], memory::data_type::f32); + horiz_ps(xmm_dst, xmm_aux3); + } + store_scalar(ptr[reg_src], xmm_dst, memory::data_type::f32); + } else { + if (load_embedded) { + load_scalar(xmm_aux3, ptr[reg_dst], dst_dt); + horiz_ps(xmm_dst, xmm_aux3); + } + store_scalar(ptr[reg_dst], xmm_dst, dst_dt); } - store_scalar(ptr[reg_dst], xmm_dst, dst_dt); } inline void horiz_ps(const Xmm& xmm, const Operand& op) { @@ -1762,6 +1824,7 @@ Reduce::Reduce(const std::shared_ptr& op, const GraphContext::CPtr raw_axes = reduceConst->cast_vector(); } set_use_aux_kernel = false; + fuse_low_precision = false; vec_reduceDH_prc.clear(); vec_reduceCDW_prc.clear(); setJITBeyond5D(); @@ -1800,7 +1863,17 @@ void Reduce::initSupportedPrimitiveDescriptors() { output_prec = getOriginalOutputPrecisionAtPort(0); if (!fusedWith.empty()) { - output_prec = fusedWith[fusedWith.size() - 1]->getOriginalOutputPrecisionAtPort(0); + // In jit mode we use the output memory as an intermediate accumulator for certain reduce modes. + // If the post ops node has a lower precision for such modes, working buffer with original precision is needed, + // in order to avoid accuracy loss. + auto fused_prec = fusedWith[fusedWith.size() - 1]->getOriginalOutputPrecisionAtPort(0); + if (output_prec == Precision::FP32 && fused_prec != Precision::FP32) { + if (algorithm != Algorithm::ReduceAnd && algorithm != Algorithm::ReduceOr && + algorithm != Algorithm::ReduceMin && algorithm != Algorithm::ReduceMax) { + fuse_low_precision = true; + } + } + output_prec = fused_prec; } jit_mode = canApplyJIT(input_prec, output_prec); @@ -1818,12 +1891,14 @@ void Reduce::initSupportedPrimitiveDescriptors() { } } - precision_change = input_prec != output_prec; + intermediate_prec = fuse_low_precision ? Precision(Precision::FP32) : output_prec; + precision_change = input_prec != intermediate_prec; support_split = algorithm != Algorithm::ReduceL2 && algorithm != Algorithm::ReduceLogSumExp && algorithm != Algorithm::ReduceSumSquare; src_data_size = input_prec.size(); dst_data_size = output_prec.size(); + intermediate_data_size = intermediate_prec.size(); NodeConfig config; config.inConfs.resize(2); @@ -1950,6 +2025,9 @@ void Reduce::prepareParams() { set_reduce_dim_flags(); } + apply_post_kernel = true; + apply_division = false; + auto builder = [&](const ReduceKey& key) -> std::shared_ptr { std::shared_ptr post_kernel; #if defined(OPENVINO_ARCH_X86_64) @@ -2021,6 +2099,7 @@ void Reduce::createPrimitive() { jcp.dst_data_size = DnnlExtensionUtils::sizeOfDataType(jcp.dst_dt); jcp.layout = layout; jcp.reduce_mode = getAlgorithm(); + jcp.fuse_low_precision = fuse_low_precision; #if defined(OPENVINO_ARCH_X86_64) compile_post_kernel = true; @@ -2040,7 +2119,10 @@ void Reduce::createPrimitive() { updateLastInputDims(); } - create_reduce_kernel(reduce_kernel, jcp); + auto reduce_jcp = jcp; + reduce_jcp.dst_dt = fuse_low_precision ? 
DnnlExtensionUtils::IEPrecisionToDataType(intermediate_prec) : jcp.dst_dt; + jcp.dst_data_size = DnnlExtensionUtils::sizeOfDataType(reduce_jcp.dst_dt); + create_reduce_kernel(reduce_kernel, reduce_jcp); // set_use_aux_kernel being false means this is a dynamic case, and prepareParams() hasn't been invoked yet. // So set use_aux_kernel true if precision changes, in case ReduceDH_opt, ReduceCDW_opt or ReduceAll_opt @@ -2056,9 +2138,9 @@ void Reduce::createPrimitive() { // stage to reduce some dimensions, and an extra fp32-in-fp32-out aux kernel will be applied on the second // stage to reduce the rest dimensions. if (use_aux_kernel) { - aux_jcp = jcp; - aux_jcp.src_dt = jcp.dst_dt; - aux_jcp.src_data_size = jcp.dst_data_size; + aux_jcp = reduce_jcp; + aux_jcp.src_dt = reduce_jcp.dst_dt; + aux_jcp.src_data_size = reduce_jcp.dst_data_size; create_reduce_kernel(reduce_aux_kernel, aux_jcp); } } @@ -2093,7 +2175,7 @@ void Reduce::execute(dnnl::stream strm) { if (is_hybrid_layout) { dst_data = reinterpret_cast(prc_mem.get_data_handle()); } - reduce_type(src_data, dst_data, dst_size); + reduce_type(src_data, dst_data); } else if (aclExecPtr) { std::vector srcMemory; for (size_t i = 0; i < getParentEdges().size(); i++) { @@ -2114,8 +2196,7 @@ void Reduce::execute(dnnl::stream strm) { } } -void Reduce::reduce_type(const uint8_t *in_ptr, uint8_t *out_ptr, size_t dst_size) { - init_dst_data(out_ptr, dst_size); +void Reduce::reduce_type(const uint8_t *in_ptr, uint8_t *out_ptr) { reduce_stride = IW; if (layout == ReduceLayoutType::reduce_ncsp || layout == ReduceLayoutType::reduce_nspc) { @@ -2141,6 +2222,9 @@ void Reduce::reduce_type(const uint8_t *in_ptr, uint8_t *out_ptr, size_t dst_siz } void Reduce::reduce_PLN(const uint8_t *in_ptr, uint8_t *out_ptr) { + output_info_reassign(out_ptr); + init_dst_data(out_ptr, dst_size); + if (ReduceN && !ReduceC && !ReduceD && !ReduceH && !ReduceW) { size_t IA = IC * ID * IH * IW; reduce_stride = IA; @@ -2348,24 +2432,29 @@ void Reduce::reduce_PLN(const uint8_t *in_ptr, uint8_t *out_ptr) { } } + output_info_restore(&out_ptr); reduce_kernel_post_process(out_ptr); } void Reduce::reduce_BLK(const uint8_t *in_ptr, uint8_t *out_ptr) { size_t ICB = div_up(IC, blk_size); size_t OCB = div_up(OC, blk_size); + output_info_reassign(out_ptr); + init_dst_data(out_ptr, dst_size); for (size_t ib = 0; ib < IB; ib++) { size_t ob = ReduceN ? 
0 : ib; GET_PTR_N_BLK; if (!ReduceC && !ReduceD && ReduceH && ReduceW) { + if (!ReduceN || (ReduceN && ib == IB - 1)) { + apply_division = getAlgorithm() == Algorithm::ReduceMean && attr.get()->post_ops_.len() == 0; + apply_post_kernel = !apply_division; + } parallel_for2d(ICB, ID, [&](size_t icb, size_t id) { size_t ocb = icb, od = id; GET_PTR_NCD_BASE_PTR_N_BLK; reduce_kernel_process(in_ptr_ncd, out_ptr_ncd, IH * IW * blk_size); }); } else if (ReduceC && ReduceD && ReduceH && ReduceW) { - if (!ReduceAll_opt) { - reduce_kernel_process(in_ptr_n, out_ptr_n, ICB * ID * IH * IW * blk_size); - } else { + if (ReduceAll_opt) { // reduce parallelly // step1: !ReduceC && ReduceD && ReduceH && ReduceW size_t prc_size = ICB * blk_size * dst_data_size; @@ -2381,6 +2470,8 @@ void Reduce::reduce_BLK(const uint8_t *in_ptr, uint8_t *out_ptr) { reduce_kernel_reassign(); reduce_kernel_process(out_ptr_n, out_ptr_n_cp, ICB * blk_size); reduce_kernel_restore(); + } else { + reduce_kernel_process(in_ptr_n, out_ptr_n, ICB * ID * IH * IW * blk_size); } } else if (ReduceW) { for (size_t icb = 0; icb < ICB; icb++) { @@ -2419,12 +2510,17 @@ void Reduce::reduce_BLK(const uint8_t *in_ptr, uint8_t *out_ptr) { } } - reduce_kernel_post_process(out_ptr); + output_info_restore(&out_ptr); + if (apply_post_kernel) { + reduce_kernel_post_process(out_ptr); + } } void Reduce::reduce_BLK_concern_padding(const uint8_t *in_ptr, uint8_t *out_ptr) { size_t ICB = div_up(IC, blk_size); size_t OCB = div_up(OC, blk_size); + output_info_reassign(out_ptr); + init_dst_data(out_ptr, dst_size); auto reduceSkipPadding = [&](const uint8_t *in_ptr_ncd, uint8_t *out_ptr_ncd, size_t ic) { size_t blk_valid_size = IC - ic; @@ -2503,11 +2599,13 @@ void Reduce::reduce_BLK_concern_padding(const uint8_t *in_ptr, uint8_t *out_ptr) } } + output_info_restore(&out_ptr); reduce_kernel_post_process(out_ptr); } inline void Reduce::reduce_kernel_process(const uint8_t *in_p, uint8_t *out_p, size_t work_amount, size_t reduce_w, size_t work_batch, const int *tab_idx) { + const float divisor = apply_division ? static_cast(IB * IC * ID * IH * IW / (OB * OC * OD * OH * OW)) : 1; auto arg = jit_reduce_call_args(); arg.src = static_cast(in_p); arg.idx = tab_idx; @@ -2516,17 +2614,22 @@ inline void Reduce::reduce_kernel_process(const uint8_t *in_p, uint8_t *out_p, s arg.work_batch = work_batch; arg.reduce_w = reduce_w; arg.reduce_stride = reduce_stride; + arg.can_divide = apply_division ? 1 : 0; + arg.divisor = &divisor; (*reduce_kernel)(&arg); } inline void Reduce::reduce_kernel_post_process(uint8_t *out_ptr) { + const uint8_t *in_ptr = fuse_low_precision ? 
static_cast(&intermediate_buf[0]) : nullptr; const size_t integerDivisor = IB * IC * ID * IH * IW / (OB * OC * OD * OH * OW); const float divisor = static_cast(integerDivisor); if (layout == ReduceLayoutType::reduce_ncsp) { parallel_for2d(OB, OC, [&](size_t ob, size_t oc) { + const uint8_t *in_p = in_ptr + (ob * OC + oc) * OD * OH * OW * intermediate_data_size; uint8_t *out_p = out_ptr + (ob * OC + oc) * OD * OH * OW * dst_data_size; auto arg = jit_reduce_post_call_args(); + arg.src = static_cast(in_p); arg.dst = static_cast(out_p); arg.oc_off = oc * sizeof(float); arg.channel_size = OC; @@ -2542,8 +2645,10 @@ inline void Reduce::reduce_kernel_post_process(uint8_t *out_ptr) { OP *= OH; size_t work_amount = OB * OC * OD * OH * OW / OP; parallel_for(OP, [&](size_t op) { + const uint8_t *in_p = in_ptr + op * work_amount * intermediate_data_size; uint8_t *out_p = out_ptr + op * work_amount * dst_data_size; auto arg = jit_reduce_post_call_args(); + arg.src = static_cast(in_p); arg.dst = static_cast(out_p); arg.oc_off = 0; arg.channel_size = OW; // OW is related to nspc-ncsp dimension reinterpret @@ -2555,8 +2660,10 @@ inline void Reduce::reduce_kernel_post_process(uint8_t *out_ptr) { } else { size_t OCB = div_up(OC, blk_size); parallel_for2d(OB, OCB, [&](size_t ob, size_t ocb) { + const uint8_t *in_p = in_ptr + (ob * OCB + ocb) * OD * OH * OW * blk_size * intermediate_data_size; uint8_t *out_p = out_ptr + (ob * OCB + ocb) * OD * OH * OW * blk_size * dst_data_size; auto arg = jit_reduce_post_call_args(); + arg.src = static_cast(in_p); arg.dst = static_cast(out_p); arg.reduce_c = ReduceC ? 1 : 0; arg.oc_off = ocb * blk_size * sizeof(float); @@ -2580,6 +2687,27 @@ inline void Reduce::reduce_kernel_restore() { } } +inline void Reduce::output_info_reassign(uint8_t *out_ptr) { + if (fuse_low_precision) { + tmp_ptr = out_ptr; + out_ptr = static_cast(&intermediate_buf[0]); + tmp_prec = output_prec; + output_prec = intermediate_prec; + tmp_data_size = dst_data_size; + dst_data_size = intermediate_data_size; + tmp_size = dst_size; + dst_size = intermediate_size; + } +} +inline void Reduce::output_info_restore(uint8_t **out_ptr) { + if (fuse_low_precision) { + *out_ptr = tmp_ptr; + output_prec = tmp_prec; + dst_data_size = tmp_data_size; + dst_size = tmp_size; + } +} + void Reduce::nspc2ncsp(uint8_t *proc_ptr, uint8_t *out_ptr) { // dimension reinterpret after nspc reusing routine reduce_PLN // demote -- nspc -- ncsp @@ -2795,12 +2923,19 @@ inline void Reduce::create_hybrid_working_memory() { } inline void Reduce::create_opt_working_memory() { + if (fuse_low_precision) { + intermediate_size = dst_size * sizeof(float) / dst_data_size; + if (intermediate_size > intermediate_buf.size()) { + intermediate_buf.resize(intermediate_size); + } + } + ReduceDH_opt = layout == ReduceLayoutType::reduce_nspc && support_split && !ReduceC && ReduceD && ReduceH && !ReduceW && IC == 1 && ID > 1; if (ReduceDH_opt) { PD = ID; PW = IW / blk_size * blk_size; - prc_data_size = dst_data_size; + prc_data_size = intermediate_data_size; prc_size = PD * PW * prc_data_size; if (prc_size > vec_reduceDH_prc.size()) { vec_reduceDH_prc.resize(prc_size); @@ -2813,7 +2948,7 @@ inline void Reduce::create_opt_working_memory() { if (ReduceCDW_opt) { PH = IH; PW = IW; - prc_data_size = dst_data_size; + prc_data_size = intermediate_data_size; prc_size = PH * PW * prc_data_size; if (prc_size > vec_reduceCDW_prc.size()) { vec_reduceCDW_prc.resize(prc_size); @@ -3168,16 +3303,6 @@ bool Reduce::canFuse(const NodePtr& node) const { return false; } - 
// In jit mode we use the output memory as an intermediate accumulator for certain reduce modes. - // If the post ops node has a lower precision for such modes, post ops fusing won't be supposted, in order to avoid accuracy loss. - if (output_prec == Precision::FP32 && - !node->getOriginalOutputPrecisions().empty() && node->getOriginalOutputPrecisionAtPort(0) != Precision::FP32) { - if (algorithm != Algorithm::ReduceAnd && algorithm != Algorithm::ReduceOr && - algorithm != Algorithm::ReduceMin && algorithm != Algorithm::ReduceMax) { - return false; - } - } - return canFuseSimpleOperation(node); } diff --git a/src/plugins/intel_cpu/src/nodes/reduce.h b/src/plugins/intel_cpu/src/nodes/reduce.h index d3c8de102eb..2f07cb196a7 100644 --- a/src/plugins/intel_cpu/src/nodes/reduce.h +++ b/src/plugins/intel_cpu/src/nodes/reduce.h @@ -24,6 +24,7 @@ enum ReduceLayoutType { struct jit_reduce_config_params { ReduceLayoutType layout; Algorithm reduce_mode; + bool fuse_low_precision; dnnl::memory::data_type src_dt; dnnl::memory::data_type dst_dt; int src_data_size; @@ -38,6 +39,8 @@ struct jit_reduce_call_args { size_t work_batch; size_t reduce_w = 2; // only used in planar layout [1: reduce width dimension] [0: reduce other dimension] [other value: N/A] size_t reduce_stride; // only used in planar layout while reducing dimensions except for width + size_t can_divide; // if apply division in reduce_kernel [1: Yes] [0: No] + const float *divisor; // mean = sum / divisor }; struct jit_reduce_post_call_args { @@ -105,7 +108,7 @@ public: static bool isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept; private: - void reduce_type(const uint8_t *in_ptr, uint8_t *out_ptr, size_t dst_size); + void reduce_type(const uint8_t *in_ptr, uint8_t *out_ptr); void reduce_PLN(const uint8_t *in_ptr, uint8_t *out_ptr); void reduce_BLK(const uint8_t *in_ptr, uint8_t *out_ptr); void reduce_BLK_concern_padding(const uint8_t *in_ptr, uint8_t *out_ptr); @@ -114,6 +117,8 @@ private: inline void reduce_kernel_post_process(uint8_t *out_ptr); inline void reduce_kernel_reassign(); inline void reduce_kernel_restore(); + inline void output_info_reassign(uint8_t *out_ptr); + inline void output_info_restore(uint8_t **out_ptr); inline void init_dst_data(uint8_t *out_ptr, size_t dst_size); inline void create_hybrid_working_memory(); inline void create_opt_working_memory(); @@ -131,8 +136,6 @@ private: bool canApplyJIT(const InferenceEngine::Precision &input_prec, const InferenceEngine::Precision &output_prec) const; size_t blk_size; - size_t dst_size; - size_t prc_size; static const size_t REDUCE_DATA = 0; static const size_t REDUCE_INDEXES = 1; bool jit_beyond_5D = false; @@ -140,6 +143,9 @@ private: bool keep_dims = true; bool is_hybrid_layout = false; bool compile_post_kernel = true; + bool apply_post_kernel = true; + bool apply_division = false; + bool fuse_low_precision = false; bool support_split = false; bool precision_change = false; bool ReduceAll_opt = false; @@ -151,14 +157,17 @@ private: size_t IB, IC, ID, IH, IW; size_t OB, OC, OD, OH, OW; size_t PD, PH, PW; - size_t src_data_size, dst_data_size, prc_data_size; + size_t src_data_size, dst_data_size, prc_data_size, intermediate_data_size, tmp_data_size; + size_t dst_size, prc_size, intermediate_size, tmp_size; size_t reduce_stride; + uint8_t *tmp_ptr; ReduceLayoutType layout; - InferenceEngine::Precision input_prec, output_prec; + InferenceEngine::Precision input_prec, output_prec, intermediate_prec, tmp_prec; InferenceEngine::SizeVector 
src_dims; InferenceEngine::SizeVector process_dst_dims; InferenceEngine::SizeVector axes_for_reduction; std::vector<int> raw_axes; + std::vector<uint8_t> intermediate_buf; jit_reduce_config_params jcp; jit_reduce_config_params aux_jcp;
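
For context on the accuracy rationale behind fuse_low_precision, the following is a minimal standalone sketch, not plugin code: all names and the mantissa-truncation helper below are invented for illustration only. It shows why accumulating ReduceMean directly in a low-precision destination (the pre-patch behavior when a fused post-op lowers the output precision) degrades the result, while accumulating in an fp32 intermediate buffer and converting once at the end, as the intermediate_buf / reduce_jcp.dst_dt changes enable, does not.

// Minimal standalone sketch (not OpenVINO code). Low precision is emulated
// here by truncating a float to ~10 mantissa bits; the helper is invented.
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

// Emulate a store to a low-precision destination by keeping ~10 mantissa bits.
static float to_low_precision(float x) {
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    bits &= 0xFFFFE000u;
    float y;
    std::memcpy(&y, &bits, sizeof(y));
    return y;
}

int main() {
    const std::vector<float> src(4096, 1.001f);

    // Destination doubles as the accumulator, so every partial sum is
    // rounded to the low output precision (pre-fix behavior under fusion).
    float acc_low = 0.0f;
    for (float v : src) acc_low = to_low_precision(acc_low + v);
    const float mean_low = to_low_precision(acc_low / src.size());

    // Accumulate in fp32 and convert only once at the end
    // (the fuse_low_precision path with the fp32 intermediate buffer).
    float acc_f32 = 0.0f;
    for (float v : src) acc_f32 += v;
    const float mean_f32 = to_low_precision(acc_f32 / src.size());

    std::printf("accumulate in low precision: %f\n", mean_low);
    std::printf("accumulate in fp32:          %f\n", mean_f32);  // close to 1.001
    return 0;
}

This is also why the patch removes the low-precision rejection from canFuse() and instead sets fuse_low_precision in initSupportedPrimitiveDescriptors(), with output_info_reassign()/output_info_restore() redirecting the accumulation into intermediate_buf before the post kernel writes the final low-precision output.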