[CPU] Cache for Reduce node (#9519)

This commit is contained in:
Chen Xu 2022-01-13 15:42:24 +08:00 committed by GitHub
parent 456347e1b6
commit e329baef04
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 99 additions and 24 deletions

View File

@ -24,6 +24,7 @@
#include <cpu/x64/injectors/jit_uni_eltwise_injector.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset4.hpp>
#include <common/primitive_hashing_utils.hpp>
using namespace mkldnn;
using namespace MKLDNNPlugin;
@ -70,6 +71,35 @@ using namespace Xbyak;
#define GET_PTR_NCD_BASE_PTR_N_BLK const uint8_t *in_ptr_ncd = in_ptr_n + src_data_size * (icb * ID + id) * IH * IW * blk_size; \
uint8_t *out_ptr_ncd = out_ptr_n + dst_data_size * (ocb * OD + od) * OH * OW * blk_size;
namespace {
struct ReduceKey {
jit_reduce_config_params jcp;
mkldnn::post_ops postOps;
size_t hash() const;
bool operator==(const ReduceKey& rhs) const;
};
size_t ReduceKey::hash() const {
using namespace dnnl::impl;
using namespace dnnl::impl::primitive_hashing;
size_t seed = 0;
seed = hash_combine(seed, jcp.layout);
seed = hash_combine(seed, jcp.reduce_mode);
seed = hash_combine(seed, jcp.src_dt);
seed = hash_combine(seed, jcp.dst_dt);
seed = get_post_op_hash(seed, *postOps.get());
return seed;
}
bool ReduceKey::operator==(const ReduceKey &rhs) const {
return jcp.layout == rhs.jcp.layout && jcp.reduce_mode == rhs.jcp.reduce_mode &&
jcp.src_dt == rhs.jcp.src_dt && jcp.dst_dt == rhs.jcp.dst_dt && *postOps.get() == *rhs.postOps.get();
}
} // namespace
// some utility functions
static inline bool isFloatCompatible(memory::data_type type) {
return memory::data_type::f32 == type || memory::data_type::bf16 == type;
@ -972,10 +1002,10 @@ private:
}
inline void horiz_store(Xbyak::Xmm xmm_dst, memory::data_type dst_dt, bool load_embedded) {
uni_movshdup(xmm_aux3, xmm_dst); // dst:1,2,3,4; aux3:2,2,4,4
horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2),f(2,2),f(3,4),f(4,4)
uni_movhlps(xmm_aux3, xmm_dst); // aux3:f(3,4),f(4,4),4,4
horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2,3,4),...
uni_vmovshdup(xmm_aux3, xmm_dst); // dst:1,2,3,4; aux3:2,2,4,4
horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2),f(2,2),f(3,4),f(4,4)
uni_vmovhlps(xmm_aux3, xmm_aux3, xmm_dst); // aux3:f(3,4),f(4,4),4,4
horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2,3,4),...
if (load_embedded) {
load_scalar(xmm_aux3, ptr[reg_dst], dst_dt);
horiz_ps(xmm_dst, xmm_aux3);
@ -1087,8 +1117,10 @@ struct jit_uni_reduce_post_kernel_f32 : public jit_uni_reduce_post_kernel, publi
mov(reg_divisor, ptr[reg_params + GET_OFF_POST(divisor)]);
if (!planar_layout)
mov(reg_reduce_c, ptr[reg_params + GET_OFF_POST(reduce_c)]);
if (attr_.post_ops_.len() != 0)
if (attr_.post_ops_.len() != 0) {
mov(reg_post_ops_data, ptr[reg_params + GET_OFF_POST(post_op_data)]);
mov(reg_oc_off, ptr[reg_params + GET_OFF_POST(oc_off)]);
}
if (isa == cpu::x64::avx512_common)
uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
@ -1153,6 +1185,7 @@ private:
Xbyak::Reg64 reg_oc_off = rax;
Xbyak::Reg64 reg_d_weights = rbx;
Xbyak::Reg64 reg_d_bias = rdx;
Xbyak::Reg64 reg_post_ops_data = r15;
Vmm vmm_aux = Vmm(0);
Xmm xmm_aux = Xmm(0);
@ -1373,16 +1406,19 @@ private:
int eltwise_inj_idx = 0;
int depthwise_inj_idx = 0;
int quantization_inj_idx = 0;
int post_ops_data_offset = 0;
for (int i = 0; i < p.len(); i++) {
auto& post_op = p.entry_[i];
if (post_op.is_eltwise()) {
eltwise_injectors[eltwise_inj_idx]->compute_vector_range(vmm_dst.getIdx(), vmm_dst.getIdx() + 1);
eltwise_inj_idx++;
} else if (post_op.is_depthwise()) {
mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
mov(reg_d_weights, ptr[reg_post_ops_data + post_ops_data_offset]);
add(reg_d_weights, reg_oc_off);
post_ops_data_offset += sizeof(float*);
mov(reg_d_bias, ptr[reg_post_ops_data + post_ops_data_offset]);
add(reg_d_bias, reg_oc_off);
post_ops_data_offset += sizeof(float*);
depthwise_injectors[depthwise_inj_idx]->compute_vector_range(vmm_dst.getIdx(), vmm_dst.getIdx() + 1, reg_d_weights, reg_d_bias, is_broadcast);
depthwise_inj_idx++;
} else if (post_op.is_quantization()) {
@ -1391,17 +1427,18 @@ private:
int s_idx = vmm_dst.getIdx();
quantization_injectors[quantization_inj_idx]->init_crop_ptrs(reg_oc_off);
quantization_injectors[quantization_inj_idx]->init_crop_ptrs(reg_post_ops_data + post_ops_data_offset, reg_oc_off);
quantization_injectors[quantization_inj_idx]->compute_crop(s_idx, s_idx + 1, 0, 0, is_broadcast);
quantization_injectors[quantization_inj_idx]->init_input_scale_shift_ptrs(reg_oc_off);
quantization_injectors[quantization_inj_idx]->init_input_scale_shift_ptrs(reg_post_ops_data + post_ops_data_offset, reg_oc_off);
quantization_injectors[quantization_inj_idx]->compute_input_scale_shift(s_idx, s_idx + 1, 0, do_rounding, 0, is_broadcast);
if (do_dequantization) {
quantization_injectors[quantization_inj_idx]->init_output_scale_shift_ptrs(reg_oc_off);
quantization_injectors[quantization_inj_idx]->init_output_scale_shift_ptrs(reg_post_ops_data + post_ops_data_offset, reg_oc_off);
quantization_injectors[quantization_inj_idx]->compute_output_scale_shift(s_idx, s_idx + 1, 0, 0, is_broadcast);
}
post_ops_data_offset += quantization_injectors[quantization_inj_idx]->memoryStep();
quantization_inj_idx++;
}
}
@ -1584,10 +1621,10 @@ private:
}
inline void horiz_store(Xbyak::Xmm xmm_dst, memory::data_type dst_dt, bool load_embedded) {
uni_movshdup(xmm_aux3, xmm_dst); // dst:1,2,3,4; aux3:2,2,4,4
horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2),f(2,2),f(3,4),f(4,4)
uni_movhlps(xmm_aux3, xmm_dst); // aux3:f(3,4),f(4,4),4,4
horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2,3,4),...
uni_vmovshdup(xmm_aux3, xmm_dst); // dst:1,2,3,4; aux3:2,2,4,4
horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2),f(2,2),f(3,4),f(4,4)
uni_vmovhlps(xmm_aux3, xmm_aux3, xmm_dst); // aux3:f(3,4),f(4,4),4,4
horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2,3,4),...
if (load_embedded) {
load_scalar(xmm_aux3, ptr[reg_dst], dst_dt);
horiz_ps(xmm_dst, xmm_aux3);
@ -1848,20 +1885,36 @@ void MKLDNNReduceNode::prepareParams() {
set_reduce_dim_flags();
}
auto builder = [&](const ReduceKey& key) -> std::shared_ptr<jit_uni_reduce_post_kernel> {
std::shared_ptr<jit_uni_reduce_post_kernel> post_kernel;
if (mayiuse(cpu::x64::avx512_common)) {
post_kernel.reset(new jit_uni_reduce_post_kernel_f32<cpu::x64::avx512_common>(key.jcp, *attr.get()));
} else if (mayiuse(cpu::x64::avx2)) {
post_kernel.reset(new jit_uni_reduce_post_kernel_f32<cpu::x64::avx2>(key.jcp, *attr.get()));
} else if (mayiuse(cpu::x64::sse41)) {
post_kernel.reset(new jit_uni_reduce_post_kernel_f32<cpu::x64::sse41>(key.jcp, *attr.get()));
}
if (post_kernel)
post_kernel->create_ker();
return post_kernel;
};
if (compile_post_kernel) {
setPostOps(attr, dst_dims, true);
if (mayiuse(cpu::x64::avx512_common)) {
reduce_post_kernel.reset(new jit_uni_reduce_post_kernel_f32<cpu::x64::avx512_common>(jcp, *attr.get()));
} else if (mayiuse(cpu::x64::avx2)) {
reduce_post_kernel.reset(new jit_uni_reduce_post_kernel_f32<cpu::x64::avx2>(jcp, *attr.get()));
} else if (mayiuse(cpu::x64::sse41)) {
reduce_post_kernel.reset(new jit_uni_reduce_post_kernel_f32<cpu::x64::sse41>(jcp, *attr.get()));
ReduceKey key = {jcp, attr.get_post_ops()};
auto cache = getRuntimeCache();
auto result = cache->getOrCreate(key, builder);
if (!result.first) {
IE_THROW() << errorPrefix << " has not found jit_uni_reduce_post_kernel_f32.";
}
if (reduce_post_kernel)
reduce_post_kernel->create_ker();
reduce_post_kernel = result.first;
jit_mode = jit_mode && reduce_post_kernel;
if (!isDynamicNode() || (isDynamicNode() && attr.get()->post_ops_.len() == 0)) {
if (!isDynamicNode()) {
compile_post_kernel = false;
}
}
@ -2313,6 +2366,7 @@ inline void MKLDNNReduceNode::reduce_kernel_post_process(uint8_t *out_ptr) {
arg.channel_size = layout == ReduceLayoutType::reduce_nspc ? OW : OC; // OW is related to nspc-ncsp dimension reinterpret
arg.work_amount = OD * OH * OW;
arg.divisor = &divisor;
arg.post_op_data = static_cast<const void **>(&postOpsDataPtrs[0]);
(*reduce_post_kernel)(&arg);
});
} else {
@ -2325,6 +2379,7 @@ inline void MKLDNNReduceNode::reduce_kernel_post_process(uint8_t *out_ptr) {
arg.oc_off = ocb * blk_size * sizeof(float);
arg.work_amount = OD * OH * OW * blk_size;
arg.divisor = &divisor;
arg.post_op_data = static_cast<const void **>(&postOpsDataPtrs[0]);
(*reduce_post_kernel)(&arg);
});
}
@ -2784,6 +2839,23 @@ void MKLDNNReduceNode::setPostOps(mkldnn::primitive_attr &attr, const VectorDims
}
IE_THROW() << "Fusing of " << NameFromType(node->getType()) << " operation to " << NameFromType(this->getType()) << " node is not implemented";
}
postOpsDataPtrs.clear();
for (int i = 0; i < ops.len(); ++i) {
auto &post_op = ops.get()->entry_[i];
if (post_op.is_quantization()) {
auto &data = post_op.quantization.data;
postOpsDataPtrs.insert(postOpsDataPtrs.end(), std::begin(data), std::end(data));
memset(data, 0, sizeof(data));
} else if (post_op.is_depthwise()) {
auto &weights = post_op.depthwise.weights_data;
auto &biases = post_op.depthwise.biases_data;
postOpsDataPtrs.push_back(weights);
postOpsDataPtrs.push_back(biases);
weights = 0;
biases = 0;
}
}
attr.set_post_ops(ops);
}

View File

@ -45,6 +45,7 @@ struct jit_reduce_post_call_args {
size_t oc_off; // offset in byte along channel on output tensor
size_t channel_size; // only for post ops fusion of nspc layout
const float *divisor; // mean = sum / divisor
const void** post_op_data;
};
struct jit_uni_reduce_kernel {
@ -146,6 +147,8 @@ private:
mkldnn::primitive_attr attr;
std::vector<const void*> postOpsDataPtrs;
std::shared_ptr<mkldnn::memory> prc_mem;
std::shared_ptr<jit_uni_reduce_kernel> reduce_kernel;

@ -1 +1 @@
Subproject commit f06708e9cf6c3973efee9d2a1a4df086050e1fcd
Subproject commit 936ce259179bbd5c88adb492bfe8c5a99fe9df61