[CPU] Cache for Reduce node (#9519)
This commit is contained in:
parent
456347e1b6
commit
e329baef04
@ -24,6 +24,7 @@
|
||||
#include <cpu/x64/injectors/jit_uni_eltwise_injector.hpp>
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <ngraph/opsets/opset4.hpp>
|
||||
#include <common/primitive_hashing_utils.hpp>
|
||||
|
||||
using namespace mkldnn;
|
||||
using namespace MKLDNNPlugin;
|
||||
@ -70,6 +71,35 @@ using namespace Xbyak;
|
||||
#define GET_PTR_NCD_BASE_PTR_N_BLK const uint8_t *in_ptr_ncd = in_ptr_n + src_data_size * (icb * ID + id) * IH * IW * blk_size; \
|
||||
uint8_t *out_ptr_ncd = out_ptr_n + dst_data_size * (ocb * OD + od) * OH * OW * blk_size;
|
||||
|
||||
namespace {
|
||||
struct ReduceKey {
|
||||
jit_reduce_config_params jcp;
|
||||
mkldnn::post_ops postOps;
|
||||
|
||||
size_t hash() const;
|
||||
bool operator==(const ReduceKey& rhs) const;
|
||||
};
|
||||
|
||||
size_t ReduceKey::hash() const {
|
||||
using namespace dnnl::impl;
|
||||
using namespace dnnl::impl::primitive_hashing;
|
||||
|
||||
size_t seed = 0;
|
||||
seed = hash_combine(seed, jcp.layout);
|
||||
seed = hash_combine(seed, jcp.reduce_mode);
|
||||
seed = hash_combine(seed, jcp.src_dt);
|
||||
seed = hash_combine(seed, jcp.dst_dt);
|
||||
seed = get_post_op_hash(seed, *postOps.get());
|
||||
|
||||
return seed;
|
||||
}
|
||||
|
||||
// Compares two keys field by field, in the same order hash() consumes them:
// layout, reduce mode, source data type, destination data type, then the
// post-ops. Evaluation stops at the first mismatch.
bool ReduceKey::operator==(const ReduceKey &rhs) const {
    const auto &other = rhs.jcp;

    if (!(jcp.layout == other.layout))
        return false;
    if (!(jcp.reduce_mode == other.reduce_mode))
        return false;
    if (!(jcp.src_dt == other.src_dt))
        return false;
    if (!(jcp.dst_dt == other.dst_dt))
        return false;

    // Post-ops are compared through their dereferenced underlying objects.
    return *postOps.get() == *rhs.postOps.get();
}
|
||||
} // namespace
|
||||
|
||||
// some utility functions
|
||||
static inline bool isFloatCompatible(memory::data_type type) {
|
||||
return memory::data_type::f32 == type || memory::data_type::bf16 == type;
|
||||
@ -972,10 +1002,10 @@ private:
|
||||
}
|
||||
|
||||
inline void horiz_store(Xbyak::Xmm xmm_dst, memory::data_type dst_dt, bool load_embedded) {
|
||||
uni_movshdup(xmm_aux3, xmm_dst); // dst:1,2,3,4; aux3:2,2,4,4
|
||||
horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2),f(2,2),f(3,4),f(4,4)
|
||||
uni_movhlps(xmm_aux3, xmm_dst); // aux3:f(3,4),f(4,4),4,4
|
||||
horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2,3,4),...
|
||||
uni_vmovshdup(xmm_aux3, xmm_dst); // dst:1,2,3,4; aux3:2,2,4,4
|
||||
horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2),f(2,2),f(3,4),f(4,4)
|
||||
uni_vmovhlps(xmm_aux3, xmm_aux3, xmm_dst); // aux3:f(3,4),f(4,4),4,4
|
||||
horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2,3,4),...
|
||||
if (load_embedded) {
|
||||
load_scalar(xmm_aux3, ptr[reg_dst], dst_dt);
|
||||
horiz_ps(xmm_dst, xmm_aux3);
|
||||
@ -1087,8 +1117,10 @@ struct jit_uni_reduce_post_kernel_f32 : public jit_uni_reduce_post_kernel, publi
|
||||
mov(reg_divisor, ptr[reg_params + GET_OFF_POST(divisor)]);
|
||||
if (!planar_layout)
|
||||
mov(reg_reduce_c, ptr[reg_params + GET_OFF_POST(reduce_c)]);
|
||||
if (attr_.post_ops_.len() != 0)
|
||||
if (attr_.post_ops_.len() != 0) {
|
||||
mov(reg_post_ops_data, ptr[reg_params + GET_OFF_POST(post_op_data)]);
|
||||
mov(reg_oc_off, ptr[reg_params + GET_OFF_POST(oc_off)]);
|
||||
}
|
||||
|
||||
if (isa == cpu::x64::avx512_common)
|
||||
uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
|
||||
@ -1153,6 +1185,7 @@ private:
|
||||
Xbyak::Reg64 reg_oc_off = rax;
|
||||
Xbyak::Reg64 reg_d_weights = rbx;
|
||||
Xbyak::Reg64 reg_d_bias = rdx;
|
||||
Xbyak::Reg64 reg_post_ops_data = r15;
|
||||
|
||||
Vmm vmm_aux = Vmm(0);
|
||||
Xmm xmm_aux = Xmm(0);
|
||||
@ -1373,16 +1406,19 @@ private:
|
||||
int eltwise_inj_idx = 0;
|
||||
int depthwise_inj_idx = 0;
|
||||
int quantization_inj_idx = 0;
|
||||
int post_ops_data_offset = 0;
|
||||
for (int i = 0; i < p.len(); i++) {
|
||||
auto& post_op = p.entry_[i];
|
||||
if (post_op.is_eltwise()) {
|
||||
eltwise_injectors[eltwise_inj_idx]->compute_vector_range(vmm_dst.getIdx(), vmm_dst.getIdx() + 1);
|
||||
eltwise_inj_idx++;
|
||||
} else if (post_op.is_depthwise()) {
|
||||
mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
|
||||
mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
|
||||
mov(reg_d_weights, ptr[reg_post_ops_data + post_ops_data_offset]);
|
||||
add(reg_d_weights, reg_oc_off);
|
||||
post_ops_data_offset += sizeof(float*);
|
||||
mov(reg_d_bias, ptr[reg_post_ops_data + post_ops_data_offset]);
|
||||
add(reg_d_bias, reg_oc_off);
|
||||
post_ops_data_offset += sizeof(float*);
|
||||
depthwise_injectors[depthwise_inj_idx]->compute_vector_range(vmm_dst.getIdx(), vmm_dst.getIdx() + 1, reg_d_weights, reg_d_bias, is_broadcast);
|
||||
depthwise_inj_idx++;
|
||||
} else if (post_op.is_quantization()) {
|
||||
@ -1391,17 +1427,18 @@ private:
|
||||
|
||||
int s_idx = vmm_dst.getIdx();
|
||||
|
||||
quantization_injectors[quantization_inj_idx]->init_crop_ptrs(reg_oc_off);
|
||||
quantization_injectors[quantization_inj_idx]->init_crop_ptrs(reg_post_ops_data + post_ops_data_offset, reg_oc_off);
|
||||
quantization_injectors[quantization_inj_idx]->compute_crop(s_idx, s_idx + 1, 0, 0, is_broadcast);
|
||||
|
||||
quantization_injectors[quantization_inj_idx]->init_input_scale_shift_ptrs(reg_oc_off);
|
||||
quantization_injectors[quantization_inj_idx]->init_input_scale_shift_ptrs(reg_post_ops_data + post_ops_data_offset, reg_oc_off);
|
||||
quantization_injectors[quantization_inj_idx]->compute_input_scale_shift(s_idx, s_idx + 1, 0, do_rounding, 0, is_broadcast);
|
||||
|
||||
if (do_dequantization) {
|
||||
quantization_injectors[quantization_inj_idx]->init_output_scale_shift_ptrs(reg_oc_off);
|
||||
quantization_injectors[quantization_inj_idx]->init_output_scale_shift_ptrs(reg_post_ops_data + post_ops_data_offset, reg_oc_off);
|
||||
quantization_injectors[quantization_inj_idx]->compute_output_scale_shift(s_idx, s_idx + 1, 0, 0, is_broadcast);
|
||||
}
|
||||
|
||||
post_ops_data_offset += quantization_injectors[quantization_inj_idx]->memoryStep();
|
||||
quantization_inj_idx++;
|
||||
}
|
||||
}
|
||||
@ -1584,10 +1621,10 @@ private:
|
||||
}
|
||||
|
||||
inline void horiz_store(Xbyak::Xmm xmm_dst, memory::data_type dst_dt, bool load_embedded) {
|
||||
uni_movshdup(xmm_aux3, xmm_dst); // dst:1,2,3,4; aux3:2,2,4,4
|
||||
horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2),f(2,2),f(3,4),f(4,4)
|
||||
uni_movhlps(xmm_aux3, xmm_dst); // aux3:f(3,4),f(4,4),4,4
|
||||
horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2,3,4),...
|
||||
uni_vmovshdup(xmm_aux3, xmm_dst); // dst:1,2,3,4; aux3:2,2,4,4
|
||||
horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2),f(2,2),f(3,4),f(4,4)
|
||||
uni_vmovhlps(xmm_aux3, xmm_aux3, xmm_dst); // aux3:f(3,4),f(4,4),4,4
|
||||
horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2,3,4),...
|
||||
if (load_embedded) {
|
||||
load_scalar(xmm_aux3, ptr[reg_dst], dst_dt);
|
||||
horiz_ps(xmm_dst, xmm_aux3);
|
||||
@ -1848,20 +1885,36 @@ void MKLDNNReduceNode::prepareParams() {
|
||||
set_reduce_dim_flags();
|
||||
}
|
||||
|
||||
auto builder = [&](const ReduceKey& key) -> std::shared_ptr<jit_uni_reduce_post_kernel> {
|
||||
std::shared_ptr<jit_uni_reduce_post_kernel> post_kernel;
|
||||
|
||||
if (mayiuse(cpu::x64::avx512_common)) {
|
||||
post_kernel.reset(new jit_uni_reduce_post_kernel_f32<cpu::x64::avx512_common>(key.jcp, *attr.get()));
|
||||
} else if (mayiuse(cpu::x64::avx2)) {
|
||||
post_kernel.reset(new jit_uni_reduce_post_kernel_f32<cpu::x64::avx2>(key.jcp, *attr.get()));
|
||||
} else if (mayiuse(cpu::x64::sse41)) {
|
||||
post_kernel.reset(new jit_uni_reduce_post_kernel_f32<cpu::x64::sse41>(key.jcp, *attr.get()));
|
||||
}
|
||||
if (post_kernel)
|
||||
post_kernel->create_ker();
|
||||
|
||||
return post_kernel;
|
||||
};
|
||||
|
||||
if (compile_post_kernel) {
|
||||
setPostOps(attr, dst_dims, true);
|
||||
if (mayiuse(cpu::x64::avx512_common)) {
|
||||
reduce_post_kernel.reset(new jit_uni_reduce_post_kernel_f32<cpu::x64::avx512_common>(jcp, *attr.get()));
|
||||
} else if (mayiuse(cpu::x64::avx2)) {
|
||||
reduce_post_kernel.reset(new jit_uni_reduce_post_kernel_f32<cpu::x64::avx2>(jcp, *attr.get()));
|
||||
} else if (mayiuse(cpu::x64::sse41)) {
|
||||
reduce_post_kernel.reset(new jit_uni_reduce_post_kernel_f32<cpu::x64::sse41>(jcp, *attr.get()));
|
||||
|
||||
ReduceKey key = {jcp, attr.get_post_ops()};
|
||||
auto cache = getRuntimeCache();
|
||||
auto result = cache->getOrCreate(key, builder);
|
||||
if (!result.first) {
|
||||
IE_THROW() << errorPrefix << " has not found jit_uni_reduce_post_kernel_f32.";
|
||||
}
|
||||
if (reduce_post_kernel)
|
||||
reduce_post_kernel->create_ker();
|
||||
|
||||
reduce_post_kernel = result.first;
|
||||
jit_mode = jit_mode && reduce_post_kernel;
|
||||
|
||||
if (!isDynamicNode() || (isDynamicNode() && attr.get()->post_ops_.len() == 0)) {
|
||||
if (!isDynamicNode()) {
|
||||
compile_post_kernel = false;
|
||||
}
|
||||
}
|
||||
@ -2313,6 +2366,7 @@ inline void MKLDNNReduceNode::reduce_kernel_post_process(uint8_t *out_ptr) {
|
||||
arg.channel_size = layout == ReduceLayoutType::reduce_nspc ? OW : OC; // OW is related to nspc-ncsp dimension reinterpret
|
||||
arg.work_amount = OD * OH * OW;
|
||||
arg.divisor = &divisor;
|
||||
arg.post_op_data = static_cast<const void **>(&postOpsDataPtrs[0]);
|
||||
(*reduce_post_kernel)(&arg);
|
||||
});
|
||||
} else {
|
||||
@ -2325,6 +2379,7 @@ inline void MKLDNNReduceNode::reduce_kernel_post_process(uint8_t *out_ptr) {
|
||||
arg.oc_off = ocb * blk_size * sizeof(float);
|
||||
arg.work_amount = OD * OH * OW * blk_size;
|
||||
arg.divisor = &divisor;
|
||||
arg.post_op_data = static_cast<const void **>(&postOpsDataPtrs[0]);
|
||||
(*reduce_post_kernel)(&arg);
|
||||
});
|
||||
}
|
||||
@ -2784,6 +2839,23 @@ void MKLDNNReduceNode::setPostOps(mkldnn::primitive_attr &attr, const VectorDims
|
||||
}
|
||||
IE_THROW() << "Fusing of " << NameFromType(node->getType()) << " operation to " << NameFromType(this->getType()) << " node is not implemented";
|
||||
}
|
||||
|
||||
postOpsDataPtrs.clear();
|
||||
for (int i = 0; i < ops.len(); ++i) {
|
||||
auto &post_op = ops.get()->entry_[i];
|
||||
if (post_op.is_quantization()) {
|
||||
auto &data = post_op.quantization.data;
|
||||
postOpsDataPtrs.insert(postOpsDataPtrs.end(), std::begin(data), std::end(data));
|
||||
memset(data, 0, sizeof(data));
|
||||
} else if (post_op.is_depthwise()) {
|
||||
auto &weights = post_op.depthwise.weights_data;
|
||||
auto &biases = post_op.depthwise.biases_data;
|
||||
postOpsDataPtrs.push_back(weights);
|
||||
postOpsDataPtrs.push_back(biases);
|
||||
weights = 0;
|
||||
biases = 0;
|
||||
}
|
||||
}
|
||||
attr.set_post_ops(ops);
|
||||
}
|
||||
|
||||
|
@ -45,6 +45,7 @@ struct jit_reduce_post_call_args {
|
||||
size_t oc_off; // offset in byte along channel on output tensor
|
||||
size_t channel_size; // only for post ops fusion of nspc layout
|
||||
const float *divisor; // mean = sum / divisor
|
||||
const void** post_op_data;
|
||||
};
|
||||
|
||||
struct jit_uni_reduce_kernel {
|
||||
@ -146,6 +147,8 @@ private:
|
||||
|
||||
mkldnn::primitive_attr attr;
|
||||
|
||||
std::vector<const void*> postOpsDataPtrs;
|
||||
|
||||
std::shared_ptr<mkldnn::memory> prc_mem;
|
||||
|
||||
std::shared_ptr<jit_uni_reduce_kernel> reduce_kernel;
|
||||
|
2
src/plugins/intel_cpu/thirdparty/mkl-dnn
vendored
2
src/plugins/intel_cpu/thirdparty/mkl-dnn
vendored
@ -1 +1 @@
|
||||
Subproject commit f06708e9cf6c3973efee9d2a1a4df086050e1fcd
|
||||
Subproject commit 936ce259179bbd5c88adb492bfe8c5a99fe9df61
|
Loading…
Reference in New Issue
Block a user