[CPU] Fix performance issue for some cases of Reduce node (#11456)
This commit is contained in:
parent af6c2b0671
commit 7c1949421f
@@ -89,6 +89,7 @@ size_t ReduceKey::hash() const {
     size_t seed = 0;
     seed = hash_combine(seed, jcp.layout);
    seed = hash_combine(seed, jcp.reduce_mode);
+    seed = hash_combine(seed, jcp.fuse_low_precision);
    seed = hash_combine(seed, jcp.src_dt);
    seed = hash_combine(seed, jcp.dst_dt);
    seed = get_post_op_hash(seed, *postOps.get());
@@ -98,6 +99,7 @@ size_t ReduceKey::hash() const {
 
 bool ReduceKey::operator==(const ReduceKey &rhs) const {
     return jcp.layout == rhs.jcp.layout && jcp.reduce_mode == rhs.jcp.reduce_mode &&
+           jcp.fuse_low_precision == rhs.jcp.fuse_low_precision &&
            jcp.src_dt == rhs.jcp.src_dt && jcp.dst_dt == rhs.jcp.dst_dt && *postOps.get() == *rhs.postOps.get();
 }
 } // namespace
@@ -192,6 +194,8 @@ private:
 
     Xbyak::Reg64 reg_src_aux = rax;
     Xbyak::Reg64 reg_work_batch_aux = rbx;
+    Xbyak::Reg64 reg_can_divide = rbp;
+    Xbyak::Reg64 reg_divisor = reg_can_divide;
 
     Vmm vmm_aux = Vmm(0);
     Xmm xmm_aux = Xmm(0);
@@ -302,6 +306,22 @@ private:
             // reduce
             reduce_kernel();
 
+            if (jcp_.reduce_mode == Algorithm::ReduceMean) {
+                Xbyak::Label reduce_divide_end_label;
+                mov(reg_can_divide, ptr[reg_params + GET_OFF(can_divide)]);
+                cmp(reg_can_divide, 0);
+                je(reduce_divide_end_label, T_NEAR);
+                {
+                    mov(reg_divisor, ptr[reg_params + GET_OFF(divisor)]);
+                    uni_vbroadcastss(vmm_aux, ptr[reg_divisor]);
+                    uni_vdivps(vmm_dst, vmm_dst, vmm_aux);
+                    if (isa == cpu::x64::sse41) {
+                        uni_vdivps(vmm_dst_aux, vmm_dst_aux, vmm_aux);
+                    }
+                }
+                L(reduce_divide_end_label);
+            }
+
             // store
             store_dst_vector();
 
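Editor's note: the block added above folds the ReduceMean division into the main JIT kernel. When the caller sets can_divide, the accumulated sums are divided by *divisor right before the store, so no separate post-process pass over the output is needed for this case. A scalar analogue of what the generated code does, illustrative only (the function and its names are hypothetical, not part of the commit):

#include <cstddef>

// Scalar sketch of the JIT block above: divide the accumulated sums in place
// when the caller requests in-kernel division (can_divide != 0).
void reduce_mean_divide(float *dst, std::size_t len, const float *divisor, bool can_divide) {
    if (!can_divide)
        return;  // the post kernel performs the division instead
    for (std::size_t i = 0; i < len; ++i)
        dst[i] /= *divisor;  // mean = sum / divisor (uni_vdivps above)
}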
@@ -1124,14 +1144,19 @@ struct jit_uni_reduce_post_kernel_f32 : public jit_uni_reduce_post_kernel, publi
         this->preamble();
 
         planar_layout = jcp_.layout == ReduceLayoutType::reduce_ncsp || jcp_.layout == ReduceLayoutType::reduce_nspc;
+        post_reduce = jcp_.reduce_mode == Algorithm::ReduceL2 || jcp_.reduce_mode == Algorithm::ReduceMean ||
+                      jcp_.reduce_mode == Algorithm::ReduceLogSum || jcp_.reduce_mode == Algorithm::ReduceLogSumExp;
+        post_ops_fusing = attr_.post_ops_.len() != 0;
 
         mov(reg_dst, ptr[reg_params + GET_OFF_POST(dst)]);
         mov(reg_work_amount, ptr[reg_params + GET_OFF_POST(work_amount)]);
         mov(reg_channel_size, ptr[reg_params + GET_OFF_POST(channel_size)]);
         mov(reg_divisor, ptr[reg_params + GET_OFF_POST(divisor)]);
+        if (jcp_.fuse_low_precision)
+            mov(reg_src, ptr[reg_params + GET_OFF_POST(src)]);
         if (!planar_layout)
             mov(reg_reduce_c, ptr[reg_params + GET_OFF_POST(reduce_c)]);
-        if (attr_.post_ops_.len() != 0) {
+        if (post_ops_fusing) {
             mov(reg_post_ops_data, ptr[reg_params + GET_OFF_POST(post_op_data)]);
             mov(reg_oc_off, ptr[reg_params + GET_OFF_POST(oc_off)]);
         }
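Editor's note: the booleans cached above factor out conditions that recur throughout the post kernel: planar_layout groups the ncsp/nspc code paths, post_reduce marks the modes whose accumulated value still needs a final mapping (square root, division, logarithm), and post_ops_fusing records whether fused post ops are attached. A standalone restatement of the post_reduce test, illustrative only (the enum here lists just the members this sketch needs):

enum class Algorithm { ReduceL2, ReduceMean, ReduceLogSum, ReduceLogSumExp, ReduceSum };

// Mirrors the post_reduce flag computed above: these modes keep a partial
// result (sum, sum of squares, sum of exps) that must still be mapped.
bool needs_post_reduce(Algorithm mode) {
    return mode == Algorithm::ReduceL2 || mode == Algorithm::ReduceMean ||
           mode == Algorithm::ReduceLogSum || mode == Algorithm::ReduceLogSumExp;
}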
@@ -1141,7 +1166,7 @@ struct jit_uni_reduce_post_kernel_f32 : public jit_uni_reduce_post_kernel, publi
 
         if (jcp_.layout == ReduceLayoutType::reduce_blocked) {
             reduce_post_main();
-        } else if (jcp_.layout == ReduceLayoutType::reduce_nspc && attr_.post_ops_.len() != 0) {
+        } else if (jcp_.layout == ReduceLayoutType::reduce_nspc && post_ops_fusing) {
             // the tail of channel dimension should always be concerned during post ops fusing for nspc layout
             Xbyak::Label reduce_nspc_loop_label;
             Xbyak::Label reduce_nspc_loop_end_label;
@@ -1183,7 +1208,10 @@ private:
                                    Xbyak::Ymm, Xbyak::Zmm>::type;
     size_t vlen = cpu_isa_traits<isa>::vlen;
     bool planar_layout = false;
+    bool post_reduce = true;
+    bool post_ops_fusing = false;
 
+    Xbyak::Reg64 reg_src = rbp;
     Xbyak::Reg64 reg_dst = r8;
     Xbyak::Reg64 reg_work_amount = r9;
     Xbyak::Reg64 reg_total_work_amount = r10;
@@ -1246,9 +1274,9 @@ private:
                 jl(reduce_loop_end_label, T_NEAR);
 
                 // load
-                load_vector(vmm_dst, ptr[reg_dst], jcp_.dst_dt);
+                wrap_load_vector(vmm_dst, 0);
                 if (isa == cpu::x64::sse41)
-                    load_vector(vmm_dst_aux, ptr[reg_dst + 4 * jcp_.dst_data_size], jcp_.dst_dt);
+                    wrap_load_vector(vmm_dst_aux, 4);
 
                 // reduce and store
                 horiz_reduce_store(vmm_dst, jcp_.dst_dt);
@@ -1256,22 +1284,27 @@ private:
                     horiz_reduce_store(vmm_dst_aux, jcp_.dst_dt, true);
 
                 add(reg_dst, step * jcp_.dst_data_size);
+                if (jcp_.fuse_low_precision)
+                    add(reg_src, step * sizeof(float));
                 sub(reg_work_amount, step);
 
                 jmp(reduce_loop_label, T_NEAR);
             }
             L(reduce_loop_end_label);
 
-            mov(reg_dst, ptr[reg_params + GET_OFF_POST(dst)]);
-            mov(reg_work_amount, ptr[reg_params + GET_OFF_POST(work_amount)]);
+            if (post_reduce || post_ops_fusing) {
+                mov(reg_dst, ptr[reg_params + GET_OFF_POST(dst)]);
+                if (jcp_.fuse_low_precision)
+                    mov(reg_src, ptr[reg_params + GET_OFF_POST(src)]);
+                mov(reg_work_amount, ptr[reg_params + GET_OFF_POST(work_amount)]);
+            }
         }
 
         // reduce map for value in dst memory
         // cases: [ReduceL2] [ReduceLogSum] [ReduceLogSumExp] [ReduceMean]
         L(reduce_map_label);
         {
-            if (jcp_.reduce_mode == Algorithm::ReduceL2 || jcp_.reduce_mode == Algorithm::ReduceMean ||
-                jcp_.reduce_mode == Algorithm::ReduceLogSum || jcp_.reduce_mode == Algorithm::ReduceLogSumExp) {
+            if (post_reduce) {
                 if (jcp_.reduce_mode == Algorithm::ReduceMean)
                     uni_vbroadcastss(vmm_aux, ptr[reg_divisor]);
 
@@ -1284,16 +1317,16 @@ private:
                 cmp(reg_work_amount, step);
                 jl(reduce_loop_end_label, T_NEAR);
 
-                load_vector(vmm_dst, ptr[reg_dst], jcp_.dst_dt);
+                wrap_load_vector(vmm_dst, 0);
                 reduce_map_kernel(vmm_dst);
-                if (attr_.post_ops_.len() != 0)
+                if (post_ops_fusing)
                     apply_post_ops(jcp_.dst_dt, jcp_.layout == ReduceLayoutType::reduce_ncsp);
                 store_vector(ptr[reg_dst], vmm_dst, jcp_.dst_dt);
 
                 if (isa == cpu::x64::sse41) {
-                    load_vector(vmm_dst, ptr[reg_dst + 4 * jcp_.dst_data_size], jcp_.dst_dt);
+                    wrap_load_vector(vmm_dst, 4);
                     reduce_map_kernel(vmm_dst);
-                    if (attr_.post_ops_.len() != 0) {
+                    if (post_ops_fusing) {
                         if (jcp_.layout != ReduceLayoutType::reduce_ncsp)
                             add(reg_oc_off, 4 * sizeof(float));
                         apply_post_ops(jcp_.dst_dt, jcp_.layout == ReduceLayoutType::reduce_ncsp);
@@ -1304,7 +1337,9 @@ private:
                 }
 
                 add(reg_dst, step * jcp_.dst_data_size);
-                if (jcp_.layout == ReduceLayoutType::reduce_nspc && attr_.post_ops_.len() != 0)
+                if (jcp_.fuse_low_precision)
+                    add(reg_src, step * sizeof(float));
+                if (jcp_.layout == ReduceLayoutType::reduce_nspc && post_ops_fusing)
                     add(reg_oc_off, step * sizeof(float));
                 sub(reg_work_amount, step);
 
@@ -1312,7 +1347,7 @@ private:
             }
             L(reduce_loop_end_label);
         } else {
-            if (attr_.post_ops_.len() != 0) {
+            if (post_ops_fusing) {
                 Xbyak::Label reduce_loop_label;
                 Xbyak::Label reduce_loop_end_label;
 
@@ -1322,12 +1357,12 @@ private:
                 cmp(reg_work_amount, step);
                 jl(reduce_loop_end_label, T_NEAR);
 
-                load_vector(vmm_dst, ptr[reg_dst], jcp_.dst_dt);
+                wrap_load_vector(vmm_dst, 0);
                 apply_post_ops(jcp_.dst_dt, jcp_.layout == ReduceLayoutType::reduce_ncsp);
                 store_vector(ptr[reg_dst], vmm_dst, jcp_.dst_dt);
 
                 if (isa == cpu::x64::sse41) {
-                    load_vector(vmm_dst, ptr[reg_dst + 4 * jcp_.dst_data_size], jcp_.dst_dt);
+                    wrap_load_vector(vmm_dst, 4);
                     if (jcp_.layout != ReduceLayoutType::reduce_ncsp)
                         add(reg_oc_off, 4 * sizeof(float));
                     apply_post_ops(jcp_.dst_dt, jcp_.layout == ReduceLayoutType::reduce_ncsp);
@@ -1337,7 +1372,9 @@ private:
                 }
 
                 add(reg_dst, step * jcp_.dst_data_size);
-                if (jcp_.layout == ReduceLayoutType::reduce_nspc && attr_.post_ops_.len() != 0)
+                if (jcp_.fuse_low_precision)
+                    add(reg_src, step * sizeof(float));
+                if (jcp_.layout == ReduceLayoutType::reduce_nspc && post_ops_fusing)
                     add(reg_oc_off, step * sizeof(float));
                 sub(reg_work_amount, step);
 
@@ -1352,8 +1389,7 @@ private:
     inline void reduce_post_tail() {
         // reduce map for tail in dst memory
         // cases: [ReduceL2] [ReduceLogSum] [ReduceLogSumExp] [ReduceMean] in planar layout
-        if (jcp_.reduce_mode == Algorithm::ReduceL2 || jcp_.reduce_mode == Algorithm::ReduceMean ||
-            jcp_.reduce_mode == Algorithm::ReduceLogSum || jcp_.reduce_mode == Algorithm::ReduceLogSumExp) {
+        if (post_reduce) {
            if (jcp_.reduce_mode == Algorithm::ReduceMean)
                uni_vbroadcastss(xmm_aux, ptr[reg_divisor]);
 
@@ -1367,18 +1403,20 @@ private:
                 jl(reduce_loop_end_label, T_NEAR);
 
                 // load
-                load_scalar(xmm_dst, ptr[reg_dst], jcp_.dst_dt);
+                wrap_load_scalar(xmm_dst, 0);
 
                 // reduce
                 reduce_map_kernel_scalar(xmm_dst);
 
                 // store
-                if (attr_.post_ops_.len() != 0)
+                if (post_ops_fusing)
                     apply_post_ops(jcp_.dst_dt, jcp_.layout == ReduceLayoutType::reduce_ncsp);
                 store_scalar(ptr[reg_dst], xmm_dst, jcp_.dst_dt);
 
                 add(reg_dst, step * jcp_.dst_data_size);
-                if (jcp_.layout == ReduceLayoutType::reduce_nspc && attr_.post_ops_.len() != 0)
+                if (jcp_.fuse_low_precision)
+                    add(reg_src, step * sizeof(float));
+                if (jcp_.layout == ReduceLayoutType::reduce_nspc && post_ops_fusing)
                     add(reg_oc_off, step * sizeof(float));
                 sub(reg_work_amount, step);
 
@@ -1386,7 +1424,7 @@ private:
             }
             L(reduce_loop_end_label);
         } else {
-            if (attr_.post_ops_.len() != 0) {
+            if (post_ops_fusing) {
                 Xbyak::Label reduce_loop_label;
                 Xbyak::Label reduce_loop_end_label;
 
@@ -1397,14 +1435,16 @@ private:
                 jl(reduce_loop_end_label, T_NEAR);
 
                 // load
-                load_scalar(xmm_dst, ptr[reg_dst], jcp_.dst_dt);
+                wrap_load_scalar(xmm_dst, 0);
 
                 // store
                 apply_post_ops(jcp_.dst_dt, jcp_.layout == ReduceLayoutType::reduce_ncsp);
                 store_scalar(ptr[reg_dst], xmm_dst, jcp_.dst_dt);
 
                 add(reg_dst, step * jcp_.dst_data_size);
-                if (jcp_.layout == ReduceLayoutType::reduce_nspc && attr_.post_ops_.len() != 0)
+                if (jcp_.fuse_low_precision)
+                    add(reg_src, step * sizeof(float));
+                if (jcp_.layout == ReduceLayoutType::reduce_nspc && post_ops_fusing)
                     add(reg_oc_off, step * sizeof(float));
                 sub(reg_work_amount, step);
 
@@ -1476,6 +1516,20 @@ private:
             log_injector->compute_vector_range(xmm_dst.getIdx(), xmm_dst.getIdx() + 1);
     }
 
+    inline void wrap_load_vector(Vmm vmm_val, size_t offset) {
+        if (jcp_.fuse_low_precision)
+            load_vector(vmm_val, ptr[reg_src + offset * sizeof(float)], memory::data_type::f32);
+        else
+            load_vector(vmm_val, ptr[reg_dst + offset * jcp_.dst_data_size], jcp_.dst_dt);
+    }
+
+    inline void wrap_load_scalar(Xmm xmm_val, size_t offset) {
+        if (jcp_.fuse_low_precision)
+            load_scalar(xmm_val, ptr[reg_src + offset * sizeof(float)], memory::data_type::f32);
+        else
+            load_scalar(xmm_val, ptr[reg_dst + offset * jcp_.dst_data_size], jcp_.dst_dt);
+    }
+
     inline void load_vector(Vmm vmm_src, const Xbyak::Address &op, memory::data_type src_dt) {
         switch (src_dt) {
             case memory::data_type::f32:
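Editor's note: the wrappers added above decide where the post kernel reads its input: from the fp32 intermediate buffer (via reg_src) when low-precision fusing is active, otherwise in place from the destination memory in its own data type. A scalar sketch of the same dispatch, illustrative only (names are hypothetical, and for brevity it assumes the destination holds f32, whereas the real load_scalar converts from jcp_.dst_dt):

#include <cstddef>
#include <cstdint>
#include <cstring>

// Scalar sketch of wrap_load_scalar: read fp32 from the intermediate buffer
// when low-precision fusing is on, otherwise re-load from dst memory.
float load_accumulator(bool fuse_low_precision, const float *src_f32,
                       const std::uint8_t *dst, std::size_t dst_elt_size, std::size_t i) {
    if (fuse_low_precision)
        return src_f32[i];  // accumulator kept in the fp32 working buffer
    float v = 0.f;
    std::memcpy(&v, dst + i * dst_elt_size, sizeof(float));  // simplification: dst assumed f32
    return v;
}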
@@ -1636,11 +1690,19 @@ private:
         horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2),f(2,2),f(3,4),f(4,4)
         uni_vmovhlps(xmm_aux3, xmm_aux3, xmm_dst); // aux3:f(3,4),f(4,4),4,4
         horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2,3,4),...
-        if (load_embedded) {
-            load_scalar(xmm_aux3, ptr[reg_dst], dst_dt);
-            horiz_ps(xmm_dst, xmm_aux3);
+        if (jcp_.fuse_low_precision && (post_reduce || post_ops_fusing)) {
+            if (load_embedded) {
+                load_scalar(xmm_aux3, ptr[reg_src], memory::data_type::f32);
+                horiz_ps(xmm_dst, xmm_aux3);
+            }
+            store_scalar(ptr[reg_src], xmm_dst, memory::data_type::f32);
+        } else {
+            if (load_embedded) {
+                load_scalar(xmm_aux3, ptr[reg_dst], dst_dt);
+                horiz_ps(xmm_dst, xmm_aux3);
+            }
+            store_scalar(ptr[reg_dst], xmm_dst, dst_dt);
         }
-        store_scalar(ptr[reg_dst], xmm_dst, dst_dt);
     }
 
     inline void horiz_ps(const Xmm& xmm, const Operand& op) {
@@ -1762,6 +1824,7 @@ Reduce::Reduce(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr
         raw_axes = reduceConst->cast_vector<int>();
     }
     set_use_aux_kernel = false;
+    fuse_low_precision = false;
     vec_reduceDH_prc.clear();
     vec_reduceCDW_prc.clear();
     setJITBeyond5D();
@@ -1800,7 +1863,17 @@ void Reduce::initSupportedPrimitiveDescriptors() {
     output_prec = getOriginalOutputPrecisionAtPort(0);
 
     if (!fusedWith.empty()) {
-        output_prec = fusedWith[fusedWith.size() - 1]->getOriginalOutputPrecisionAtPort(0);
+        // In jit mode we use the output memory as an intermediate accumulator for certain reduce modes.
+        // If the post ops node has a lower precision for such modes, a working buffer with the original
+        // precision is needed in order to avoid accuracy loss.
+        auto fused_prec = fusedWith[fusedWith.size() - 1]->getOriginalOutputPrecisionAtPort(0);
+        if (output_prec == Precision::FP32 && fused_prec != Precision::FP32) {
+            if (algorithm != Algorithm::ReduceAnd && algorithm != Algorithm::ReduceOr &&
+                algorithm != Algorithm::ReduceMin && algorithm != Algorithm::ReduceMax) {
+                fuse_low_precision = true;
+            }
+        }
+        output_prec = fused_prec;
     }
 
     jit_mode = canApplyJIT(input_prec, output_prec);
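Editor's note: the logic introduced above can be read as a small predicate: low-precision fusing kicks in only when the Reduce itself computes in FP32, the fused tail demotes the output, and the reduce mode actually accumulates values (And/Or/Min/Max lose nothing by being computed in the lower precision directly). A hedged restatement, not the node's API:

enum class Algorithm { ReduceAnd, ReduceOr, ReduceMin, ReduceMax, ReduceMean, ReduceSum };

// Restates the condition above: keep an FP32 working buffer only for the
// accumulating modes whose result would be degraded by a low-precision output.
bool should_fuse_low_precision(bool reduce_is_fp32, bool fused_out_is_fp32, Algorithm alg) {
    const bool accumulating_mode =
        alg != Algorithm::ReduceAnd && alg != Algorithm::ReduceOr &&
        alg != Algorithm::ReduceMin && alg != Algorithm::ReduceMax;
    return reduce_is_fp32 && !fused_out_is_fp32 && accumulating_mode;
}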
@@ -1818,12 +1891,14 @@ void Reduce::initSupportedPrimitiveDescriptors() {
         }
     }
 
-    precision_change = input_prec != output_prec;
+    intermediate_prec = fuse_low_precision ? Precision(Precision::FP32) : output_prec;
+    precision_change = input_prec != intermediate_prec;
     support_split = algorithm != Algorithm::ReduceL2 && algorithm != Algorithm::ReduceLogSumExp &&
                     algorithm != Algorithm::ReduceSumSquare;
 
     src_data_size = input_prec.size();
     dst_data_size = output_prec.size();
+    intermediate_data_size = intermediate_prec.size();
 
     NodeConfig config;
     config.inConfs.resize(2);
@@ -1950,6 +2025,9 @@ void Reduce::prepareParams() {
         set_reduce_dim_flags();
     }
 
+    apply_post_kernel = true;
+    apply_division = false;
+
     auto builder = [&](const ReduceKey& key) -> std::shared_ptr<jit_uni_reduce_post_kernel> {
         std::shared_ptr<jit_uni_reduce_post_kernel> post_kernel;
 #if defined(OPENVINO_ARCH_X86_64)
@@ -2021,6 +2099,7 @@ void Reduce::createPrimitive() {
     jcp.dst_data_size = DnnlExtensionUtils::sizeOfDataType(jcp.dst_dt);
     jcp.layout = layout;
     jcp.reduce_mode = getAlgorithm();
+    jcp.fuse_low_precision = fuse_low_precision;
 
 #if defined(OPENVINO_ARCH_X86_64)
     compile_post_kernel = true;
@@ -2040,7 +2119,10 @@ void Reduce::createPrimitive() {
         updateLastInputDims();
     }
 
-    create_reduce_kernel(reduce_kernel, jcp);
+    auto reduce_jcp = jcp;
+    reduce_jcp.dst_dt = fuse_low_precision ? DnnlExtensionUtils::IEPrecisionToDataType(intermediate_prec) : jcp.dst_dt;
+    jcp.dst_data_size = DnnlExtensionUtils::sizeOfDataType(reduce_jcp.dst_dt);
+    create_reduce_kernel(reduce_kernel, reduce_jcp);
 
     // set_use_aux_kernel being false means this is a dynamic case, and prepareParams() hasn't been invoked yet.
     // So set use_aux_kernel true if precision changes, in case ReduceDH_opt, ReduceCDW_opt or ReduceAll_opt
@ -2056,9 +2138,9 @@ void Reduce::createPrimitive() {
|
||||
// stage to reduce some dimensions, and an extra fp32-in-fp32-out aux kernel will be applied on the second
|
||||
// stage to reduce the rest dimensions.
|
||||
if (use_aux_kernel) {
|
||||
aux_jcp = jcp;
|
||||
aux_jcp.src_dt = jcp.dst_dt;
|
||||
aux_jcp.src_data_size = jcp.dst_data_size;
|
||||
aux_jcp = reduce_jcp;
|
||||
aux_jcp.src_dt = reduce_jcp.dst_dt;
|
||||
aux_jcp.src_data_size = reduce_jcp.dst_data_size;
|
||||
create_reduce_kernel(reduce_aux_kernel, aux_jcp);
|
||||
}
|
||||
}
|
||||
@@ -2093,7 +2175,7 @@ void Reduce::execute(dnnl::stream strm) {
         if (is_hybrid_layout) {
             dst_data = reinterpret_cast<uint8_t *>(prc_mem.get_data_handle());
         }
-        reduce_type(src_data, dst_data, dst_size);
+        reduce_type(src_data, dst_data);
     } else if (aclExecPtr) {
         std::vector<MemoryCPtr> srcMemory;
         for (size_t i = 0; i < getParentEdges().size(); i++) {
@@ -2114,8 +2196,7 @@ void Reduce::execute(dnnl::stream strm) {
     }
 }
 
-void Reduce::reduce_type(const uint8_t *in_ptr, uint8_t *out_ptr, size_t dst_size) {
-    init_dst_data(out_ptr, dst_size);
+void Reduce::reduce_type(const uint8_t *in_ptr, uint8_t *out_ptr) {
     reduce_stride = IW;
 
     if (layout == ReduceLayoutType::reduce_ncsp || layout == ReduceLayoutType::reduce_nspc) {
@@ -2141,6 +2222,9 @@ void Reduce::reduce_type(const uint8_t *in_ptr, uint8_t *out_ptr, size_t dst_siz
 }
 
 void Reduce::reduce_PLN(const uint8_t *in_ptr, uint8_t *out_ptr) {
+    output_info_reassign(out_ptr);
+    init_dst_data(out_ptr, dst_size);
+
     if (ReduceN && !ReduceC && !ReduceD && !ReduceH && !ReduceW) {
         size_t IA = IC * ID * IH * IW;
         reduce_stride = IA;
@@ -2348,24 +2432,29 @@ void Reduce::reduce_PLN(const uint8_t *in_ptr, uint8_t *out_ptr) {
         }
     }
 
+    output_info_restore(&out_ptr);
     reduce_kernel_post_process(out_ptr);
 }
 
 void Reduce::reduce_BLK(const uint8_t *in_ptr, uint8_t *out_ptr) {
     size_t ICB = div_up(IC, blk_size);
     size_t OCB = div_up(OC, blk_size);
+    output_info_reassign(out_ptr);
+    init_dst_data(out_ptr, dst_size);
 
     for (size_t ib = 0; ib < IB; ib++) {
         size_t ob = ReduceN ? 0 : ib; GET_PTR_N_BLK;
         if (!ReduceC && !ReduceD && ReduceH && ReduceW) {
+            if (!ReduceN || (ReduceN && ib == IB - 1)) {
+                apply_division = getAlgorithm() == Algorithm::ReduceMean && attr.get()->post_ops_.len() == 0;
+                apply_post_kernel = !apply_division;
+            }
             parallel_for2d(ICB, ID, [&](size_t icb, size_t id) {
                 size_t ocb = icb, od = id; GET_PTR_NCD_BASE_PTR_N_BLK;
                 reduce_kernel_process(in_ptr_ncd, out_ptr_ncd, IH * IW * blk_size);
             });
         } else if (ReduceC && ReduceD && ReduceH && ReduceW) {
-            if (!ReduceAll_opt) {
-                reduce_kernel_process(in_ptr_n, out_ptr_n, ICB * ID * IH * IW * blk_size);
-            } else {
+            if (ReduceAll_opt) {
                 // reduce parallelly
                 // step1: !ReduceC && ReduceD && ReduceH && ReduceW
                 size_t prc_size = ICB * blk_size * dst_data_size;
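Editor's note: the flags set above choose, once the last batch iteration is reached, between dividing inside the main kernel (apply_division) and running the separate post kernel (apply_post_kernel); exactly one of the two performs the mean division. An illustrative restatement that ignores the batch-iteration guard (function and names are hypothetical):

// The two flags are complementary for ReduceMean without post ops: divide in
// the main kernel when possible, otherwise leave it to the post kernel.
void choose_mean_path(bool is_reduce_mean, bool has_post_ops,
                      bool &apply_division, bool &apply_post_kernel) {
    apply_division = is_reduce_mean && !has_post_ops;  // in-kernel division
    apply_post_kernel = !apply_division;               // post-kernel division
}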
@@ -2381,6 +2470,8 @@ void Reduce::reduce_BLK(const uint8_t *in_ptr, uint8_t *out_ptr) {
                 reduce_kernel_reassign();
                 reduce_kernel_process(out_ptr_n, out_ptr_n_cp, ICB * blk_size);
                 reduce_kernel_restore();
+            } else {
+                reduce_kernel_process(in_ptr_n, out_ptr_n, ICB * ID * IH * IW * blk_size);
             }
         } else if (ReduceW) {
             for (size_t icb = 0; icb < ICB; icb++) {
@@ -2419,12 +2510,17 @@ void Reduce::reduce_BLK(const uint8_t *in_ptr, uint8_t *out_ptr) {
         }
     }
 
-    reduce_kernel_post_process(out_ptr);
+    output_info_restore(&out_ptr);
+    if (apply_post_kernel) {
+        reduce_kernel_post_process(out_ptr);
+    }
 }
 
 void Reduce::reduce_BLK_concern_padding(const uint8_t *in_ptr, uint8_t *out_ptr) {
     size_t ICB = div_up(IC, blk_size);
     size_t OCB = div_up(OC, blk_size);
+    output_info_reassign(out_ptr);
+    init_dst_data(out_ptr, dst_size);
 
     auto reduceSkipPadding = [&](const uint8_t *in_ptr_ncd, uint8_t *out_ptr_ncd, size_t ic) {
         size_t blk_valid_size = IC - ic;
@@ -2503,11 +2599,13 @@ void Reduce::reduce_BLK_concern_padding(const uint8_t *in_ptr, uint8_t *out_ptr)
         }
     }
 
+    output_info_restore(&out_ptr);
     reduce_kernel_post_process(out_ptr);
 }
 
 inline void Reduce::reduce_kernel_process(const uint8_t *in_p, uint8_t *out_p, size_t work_amount,
                                           size_t reduce_w, size_t work_batch, const int *tab_idx) {
+    const float divisor = apply_division ? static_cast<float>(IB * IC * ID * IH * IW / (OB * OC * OD * OH * OW)) : 1;
     auto arg = jit_reduce_call_args();
     arg.src = static_cast<const void *>(in_p);
     arg.idx = tab_idx;
@@ -2516,17 +2614,22 @@ inline void Reduce::reduce_kernel_process(const uint8_t *in_p, uint8_t *out_p, s
     arg.work_batch = work_batch;
     arg.reduce_w = reduce_w;
     arg.reduce_stride = reduce_stride;
+    arg.can_divide = apply_division ? 1 : 0;
+    arg.divisor = &divisor;
 
     (*reduce_kernel)(&arg);
 }
 
 inline void Reduce::reduce_kernel_post_process(uint8_t *out_ptr) {
+    const uint8_t *in_ptr = fuse_low_precision ? static_cast<uint8_t *>(&intermediate_buf[0]) : nullptr;
     const size_t integerDivisor = IB * IC * ID * IH * IW / (OB * OC * OD * OH * OW);
     const float divisor = static_cast<float>(integerDivisor);
     if (layout == ReduceLayoutType::reduce_ncsp) {
         parallel_for2d(OB, OC, [&](size_t ob, size_t oc) {
+            const uint8_t *in_p = in_ptr + (ob * OC + oc) * OD * OH * OW * intermediate_data_size;
             uint8_t *out_p = out_ptr + (ob * OC + oc) * OD * OH * OW * dst_data_size;
             auto arg = jit_reduce_post_call_args();
+            arg.src = static_cast<const void *>(in_p);
             arg.dst = static_cast<void *>(out_p);
             arg.oc_off = oc * sizeof(float);
             arg.channel_size = OC;
@@ -2542,8 +2645,10 @@ inline void Reduce::reduce_kernel_post_process(uint8_t *out_ptr) {
             OP *= OH;
         size_t work_amount = OB * OC * OD * OH * OW / OP;
         parallel_for(OP, [&](size_t op) {
+            const uint8_t *in_p = in_ptr + op * work_amount * intermediate_data_size;
             uint8_t *out_p = out_ptr + op * work_amount * dst_data_size;
             auto arg = jit_reduce_post_call_args();
+            arg.src = static_cast<const void *>(in_p);
             arg.dst = static_cast<void *>(out_p);
             arg.oc_off = 0;
             arg.channel_size = OW; // OW is related to nspc-ncsp dimension reinterpret
@@ -2555,8 +2660,10 @@ inline void Reduce::reduce_kernel_post_process(uint8_t *out_ptr) {
     } else {
         size_t OCB = div_up(OC, blk_size);
         parallel_for2d(OB, OCB, [&](size_t ob, size_t ocb) {
+            const uint8_t *in_p = in_ptr + (ob * OCB + ocb) * OD * OH * OW * blk_size * intermediate_data_size;
             uint8_t *out_p = out_ptr + (ob * OCB + ocb) * OD * OH * OW * blk_size * dst_data_size;
             auto arg = jit_reduce_post_call_args();
+            arg.src = static_cast<const void *>(in_p);
             arg.dst = static_cast<void *>(out_p);
             arg.reduce_c = ReduceC ? 1 : 0;
             arg.oc_off = ocb * blk_size * sizeof(float);
@@ -2580,6 +2687,27 @@ inline void Reduce::reduce_kernel_restore() {
     }
 }
 
+inline void Reduce::output_info_reassign(uint8_t *out_ptr) {
+    if (fuse_low_precision) {
+        tmp_ptr = out_ptr;
+        out_ptr = static_cast<uint8_t *>(&intermediate_buf[0]);
+        tmp_prec = output_prec;
+        output_prec = intermediate_prec;
+        tmp_data_size = dst_data_size;
+        dst_data_size = intermediate_data_size;
+        tmp_size = dst_size;
+        dst_size = intermediate_size;
+    }
+}
+inline void Reduce::output_info_restore(uint8_t **out_ptr) {
+    if (fuse_low_precision) {
+        *out_ptr = tmp_ptr;
+        output_prec = tmp_prec;
+        dst_data_size = tmp_data_size;
+        dst_size = tmp_size;
+    }
+}
+
 void Reduce::nspc2ncsp(uint8_t *proc_ptr, uint8_t *out_ptr) {
     // dimension reinterpret after nspc reusing routine reduce_PLN
     // demote -- nspc -- ncsp
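Editor's note: the pair of helpers added above saves, swaps, and later restores the output descriptors (pointer, precision, element size, byte size) so that the accumulation stage writes fp32 into intermediate_buf while the post kernel still targets the real output. A minimal standalone sketch of the same save/swap/restore pattern, under assumed names (OutputInfo and ScopedOutputSwap are hypothetical):

#include <cstddef>
#include <cstdint>
#include <vector>

struct OutputInfo {
    std::uint8_t *ptr;
    std::size_t elt_size;
};

// RAII sketch: point the write target at an fp32 scratch buffer for the
// accumulation stage, then put the original destination back on exit.
class ScopedOutputSwap {
public:
    ScopedOutputSwap(OutputInfo &info, std::vector<std::uint8_t> &scratch)
        : info_(info), saved_(info) {
        info_.ptr = scratch.data();       // accumulate into the fp32 buffer
        info_.elt_size = sizeof(float);
    }
    ~ScopedOutputSwap() { info_ = saved_; }  // restore the real output
private:
    OutputInfo &info_;
    OutputInfo saved_;
};

The commit performs the swap manually through the tmp_* members rather than via RAII, since the restore point (just before reduce_kernel_post_process) is not a scope exit.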
@@ -2795,12 +2923,19 @@ inline void Reduce::create_hybrid_working_memory() {
 }
 
 inline void Reduce::create_opt_working_memory() {
+    if (fuse_low_precision) {
+        intermediate_size = dst_size * sizeof(float) / dst_data_size;
+        if (intermediate_size > intermediate_buf.size()) {
+            intermediate_buf.resize(intermediate_size);
+        }
+    }
+
     ReduceDH_opt = layout == ReduceLayoutType::reduce_nspc && support_split &&
                    !ReduceC && ReduceD && ReduceH && !ReduceW && IC == 1 && ID > 1;
     if (ReduceDH_opt) {
         PD = ID;
         PW = IW / blk_size * blk_size;
-        prc_data_size = dst_data_size;
+        prc_data_size = intermediate_data_size;
         prc_size = PD * PW * prc_data_size;
         if (prc_size > vec_reduceDH_prc.size()) {
             vec_reduceDH_prc.resize(prc_size);
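Editor's note: the sizing above converts the byte count of the real output into the byte count of an fp32 buffer holding the same number of elements. For example, with a BF16 destination (dst_data_size = 2) of 1024 elements, dst_size is 2048 bytes and intermediate_size = 2048 * 4 / 2 = 4096 bytes, i.e. the same element count stored as 4-byte floats.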
@@ -2813,7 +2948,7 @@ inline void Reduce::create_opt_working_memory() {
     if (ReduceCDW_opt) {
         PH = IH;
         PW = IW;
-        prc_data_size = dst_data_size;
+        prc_data_size = intermediate_data_size;
         prc_size = PH * PW * prc_data_size;
         if (prc_size > vec_reduceCDW_prc.size()) {
             vec_reduceCDW_prc.resize(prc_size);
@@ -3168,16 +3303,6 @@ bool Reduce::canFuse(const NodePtr& node) const {
         return false;
     }
 
-    // In jit mode we use the output memory as an intermediate accumulator for certain reduce modes.
-    // If the post ops node has a lower precision for such modes, post ops fusing won't be supported, in order to avoid accuracy loss.
-    if (output_prec == Precision::FP32 &&
-        !node->getOriginalOutputPrecisions().empty() && node->getOriginalOutputPrecisionAtPort(0) != Precision::FP32) {
-        if (algorithm != Algorithm::ReduceAnd && algorithm != Algorithm::ReduceOr &&
-            algorithm != Algorithm::ReduceMin && algorithm != Algorithm::ReduceMax) {
-            return false;
-        }
-    }
-
     return canFuseSimpleOperation(node);
 }
 
@@ -24,6 +24,7 @@ enum ReduceLayoutType {
 struct jit_reduce_config_params {
     ReduceLayoutType layout;
     Algorithm reduce_mode;
+    bool fuse_low_precision;
     dnnl::memory::data_type src_dt;
     dnnl::memory::data_type dst_dt;
     int src_data_size;
@@ -38,6 +39,8 @@ struct jit_reduce_call_args {
     size_t work_batch;
     size_t reduce_w = 2;    // only used in planar layout [1: reduce width dimension] [0: reduce other dimension] [other value: N/A]
     size_t reduce_stride;   // only used in planar layout while reducing dimensions except for width
+    size_t can_divide;      // if apply division in reduce_kernel [1: Yes] [0: No]
+    const float *divisor;   // mean = sum / divisor
 };
 
 struct jit_reduce_post_call_args {
@@ -105,7 +108,7 @@ public:
     static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;
 
 private:
-    void reduce_type(const uint8_t *in_ptr, uint8_t *out_ptr, size_t dst_size);
+    void reduce_type(const uint8_t *in_ptr, uint8_t *out_ptr);
     void reduce_PLN(const uint8_t *in_ptr, uint8_t *out_ptr);
     void reduce_BLK(const uint8_t *in_ptr, uint8_t *out_ptr);
     void reduce_BLK_concern_padding(const uint8_t *in_ptr, uint8_t *out_ptr);
@@ -114,6 +117,8 @@ private:
     inline void reduce_kernel_post_process(uint8_t *out_ptr);
     inline void reduce_kernel_reassign();
     inline void reduce_kernel_restore();
+    inline void output_info_reassign(uint8_t *out_ptr);
+    inline void output_info_restore(uint8_t **out_ptr);
     inline void init_dst_data(uint8_t *out_ptr, size_t dst_size);
     inline void create_hybrid_working_memory();
     inline void create_opt_working_memory();
@@ -131,8 +136,6 @@ private:
     bool canApplyJIT(const InferenceEngine::Precision &input_prec, const InferenceEngine::Precision &output_prec) const;
 
     size_t blk_size;
-    size_t dst_size;
-    size_t prc_size;
     static const size_t REDUCE_DATA = 0;
     static const size_t REDUCE_INDEXES = 1;
     bool jit_beyond_5D = false;
@@ -140,6 +143,9 @@ private:
     bool keep_dims = true;
     bool is_hybrid_layout = false;
     bool compile_post_kernel = true;
+    bool apply_post_kernel = true;
+    bool apply_division = false;
+    bool fuse_low_precision = false;
     bool support_split = false;
     bool precision_change = false;
     bool ReduceAll_opt = false;
@@ -151,14 +157,17 @@ private:
     size_t IB, IC, ID, IH, IW;
     size_t OB, OC, OD, OH, OW;
     size_t PD, PH, PW;
-    size_t src_data_size, dst_data_size, prc_data_size;
+    size_t src_data_size, dst_data_size, prc_data_size, intermediate_data_size, tmp_data_size;
+    size_t dst_size, prc_size, intermediate_size, tmp_size;
     size_t reduce_stride;
+    uint8_t *tmp_ptr;
     ReduceLayoutType layout;
-    InferenceEngine::Precision input_prec, output_prec;
+    InferenceEngine::Precision input_prec, output_prec, intermediate_prec, tmp_prec;
     InferenceEngine::SizeVector src_dims;
     InferenceEngine::SizeVector process_dst_dims;
     InferenceEngine::SizeVector axes_for_reduction;
     std::vector<int> raw_axes;
+    std::vector<uint8_t> intermediate_buf;
 
     jit_reduce_config_params jcp;
     jit_reduce_config_params aux_jcp;