[CPU] Add cache support for normalize (#9508)
@@ -23,6 +23,7 @@
#include <ngraph/opsets/opset1.hpp>
#include "memory_desc/dnnl_blocked_memory_desc.h"
#include "utils/cpu_utils.hpp"
#include <common/primitive_hashing_utils.hpp>

using namespace mkldnn;
using namespace MKLDNNPlugin;
@@ -36,6 +37,44 @@ using namespace Xbyak;

#define THROW_ERROR IE_THROW() << "NormalizeL2 layer with name '" << getName() << "' "

namespace {
struct NormalizeKey {
    MKLDNNNormalizeL2Node::NormalizeL2Attrs attrs;
    mkldnn::primitive_attr kernel_attrs;
    VectorDims dims;

    size_t hash() const;
    bool operator==(const NormalizeKey& rhs) const;
};

size_t NormalizeKey::hash() const {
    using namespace dnnl::impl;
    using namespace dnnl::impl::primitive_hashing;

    size_t seed = 0;
    seed = hash_combine(seed, attrs.epsMode);
    seed = hash_combine(seed, attrs.across_spatial);
    seed = hash_combine(seed, attrs.cornerCase);
    seed = hash_combine(seed, attrs.eps);
    seed = hash_combine(seed, attrs.layout);
    seed = hash_combine(seed, attrs.input_prec.getPrecVal());
    seed = hash_combine(seed, attrs.output_prec.getPrecVal());

    seed = hash_combine(seed, get_attr_hash(*kernel_attrs.get()));
    seed = get_vector_hash(seed, dims);
    return seed;
}
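
hash_combine and get_vector_hash above come from dnnl's primitive_hashing utilities; functionally, hash_combine is the familiar boost-style seed mixer. A minimal sketch of that idiom (an illustration of the technique, not the dnnl source):

#include <cstddef>
#include <functional>

template <typename T>
std::size_t hash_combine(std::size_t seed, const T& v) {
    // fold the value's hash into the running seed so that field order matters
    return seed ^ (std::hash<T>{}(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2));
}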

bool NormalizeKey::operator==(const NormalizeKey& rhs) const {
    return (attrs.epsMode == rhs.attrs.epsMode) && (attrs.across_spatial == rhs.attrs.across_spatial) &&
           (attrs.cornerCase == rhs.attrs.cornerCase) && (attrs.eps == rhs.attrs.eps) &&
           (attrs.layout == rhs.attrs.layout) && (attrs.input_prec == rhs.attrs.input_prec) &&
           (attrs.output_prec == rhs.attrs.output_prec) && (*kernel_attrs.get() == *(rhs.kernel_attrs.get())) &&
           (dims == rhs.dims);
}

} // namespace
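
NormalizeKey only has to satisfy the two-method contract a hash-based cache needs: hash() and operator==. A minimal sketch of how such a key typically drives a getOrCreate-style cache (a simplified stand-in for illustration, not the plugin's actual runtime cache class):

#include <cstddef>
#include <functional>
#include <memory>
#include <unordered_map>

template <typename KeyT, typename ValueT>
class ExecutorCache {
public:
    using Builder = std::function<std::shared_ptr<ValueT>(const KeyT&)>;

    std::shared_ptr<ValueT> getOrCreate(const KeyT& key, const Builder& builder) {
        auto it = map_.find(key);
        if (it != map_.end())
            return it->second;      // hit: reuse the previously built executor
        auto value = builder(key);  // miss: build once and remember it
        map_.emplace(key, value);
        return value;
    }

private:
    struct KeyHash {
        std::size_t operator()(const KeyT& k) const { return k.hash(); }
    };
    std::unordered_map<KeyT, std::shared_ptr<ValueT>, KeyHash> map_;
};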

static inline bool isFloatCompatible(memory::data_type type) {
    return memory::data_type::f32 == type || memory::data_type::bf16 == type;
}
@@ -197,8 +236,10 @@ struct jit_uni_normalize_kernel_f32 : public jit_uni_normalize_kernel, public ji
        mov(reg_dst, ptr[reg_params + GET_OFF(dst)]);
        mov(reg_fused_factor, ptr[reg_params + GET_OFF(fused_factor)]);
        mov(reg_work_amount, ptr[reg_params + GET_OFF(work_amount)]);
        if (attr_.post_ops_.len() != 0)
        if (attr_.post_ops_.len() != 0) {
            mov(reg_post_ops_data, ptr[reg_params + GET_OFF(post_op_data)]);
            mov(reg_oc_off, ptr[reg_params + GET_OFF(oc_off)]);
        }
        if (isa == avx512_common)
            uni_vpxor(vmm_zero, vmm_zero, vmm_zero);

@@ -234,7 +275,8 @@ private:
    Reg64 reg_tmp_64 = r14;

    Xbyak::Reg64 reg_oc_off = rax;
    Xbyak::Reg64 reg_d_weights = rbx;
    Xbyak::Reg64 reg_post_ops_data = rbx;
    Xbyak::Reg64 reg_d_weights = reg_tmp_64;
    Xbyak::Reg64 reg_d_bias = rdx;

    Vmm vmm_val = Vmm(0);
@@ -602,6 +644,7 @@ private:
        int eltwise_inj_idx = 0;
        int depthwise_inj_idx = 0;
        int quantization_inj_idx = 0;
        int post_ops_data_offset = 0;
        for (int i = 0; i < p.len(); i++) {
            auto& post_op = p.entry_[i];
            if (post_op.is_eltwise()) {
@@ -614,8 +657,12 @@ private:
                if (depthwise_injectors.size() <= depthwise_inj_idx
                        || depthwise_injectors[depthwise_inj_idx] == nullptr)
                    assert(!"Invalid depthwise injectors.");
                mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
                mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
                mov(reg_d_weights, ptr[reg_post_ops_data + post_ops_data_offset]);
                post_ops_data_offset += sizeof(void*);

                mov(reg_d_bias, ptr[reg_post_ops_data + post_ops_data_offset]);
                post_ops_data_offset += sizeof(void*);

                add(reg_d_weights, reg_oc_off);
                add(reg_d_bias, reg_oc_off);
                // weights and biases are padded; a scalar is broadcast as a vector.
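
In plain C++ terms, the register sequence above advances through the flat pointer list two entries per depthwise post-op; a scalar analogue with hypothetical names (reg_post_ops_data corresponds to `cursor`, post_ops_data_offset to how far it has advanced):

#include <cstddef>

void walkDepthwiseEntries(const void** post_ops_data, std::size_t num_depthwise) {
    const void** cursor = post_ops_data;
    for (std::size_t i = 0; i < num_depthwise; ++i) {
        const float* weights = static_cast<const float*>(*cursor++);
        const float* biases = static_cast<const float*>(*cursor++);
        (void)weights;  // ...apply the per-channel scale...
        (void)biases;   // ...and shift here...
    }
}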
@@ -630,17 +677,19 @@ private:

                int s_idx = vmm_val.getIdx();

                quantization_injectors[quantization_inj_idx]->init_crop_ptrs(reg_oc_off);
                const Xbyak::RegExp quant_arg_base = reg_post_ops_data + post_ops_data_offset;
                quantization_injectors[quantization_inj_idx]->init_crop_ptrs(quant_arg_base, reg_oc_off);
                quantization_injectors[quantization_inj_idx]->compute_crop(s_idx, s_idx + 1, 0, 0, is_broadcast);

                quantization_injectors[quantization_inj_idx]->init_input_scale_shift_ptrs(reg_oc_off);
                quantization_injectors[quantization_inj_idx]->init_input_scale_shift_ptrs(quant_arg_base, reg_oc_off);
                quantization_injectors[quantization_inj_idx]->compute_input_scale_shift(s_idx, s_idx + 1, 0, do_rounding, 0, is_broadcast);

                if (do_dequantization) {
                    quantization_injectors[quantization_inj_idx]->init_output_scale_shift_ptrs(reg_oc_off);
                    quantization_injectors[quantization_inj_idx]->init_output_scale_shift_ptrs(quant_arg_base, reg_oc_off);
                    quantization_injectors[quantization_inj_idx]->compute_output_scale_shift(s_idx, s_idx + 1, 0, 0, is_broadcast);
                }

                post_ops_data_offset += quantization_injectors[quantization_inj_idx]->memoryStep();
                quantization_inj_idx++;
            }
        }
@@ -835,12 +884,13 @@ void MKLDNNNormalizeL2Node::createPrimitive() {

    if (!attrs.cornerCase) {
        if (srcMemPtr->getDesc().hasLayoutType(LayoutType::ncsp)) {
            attrs.is_nchw = true;
        } else if (srcMemPtr->getDesc().hasLayoutType(LayoutType::nCsp16c) ||
                   srcMemPtr->getDesc().hasLayoutType(LayoutType::nCsp8c)) {
            attrs.is_blk = true;
            attrs.layout = LayoutType::ncsp;
        } else if (srcMemPtr->getDesc().hasLayoutType(LayoutType::nCsp8c)) {
            attrs.layout = LayoutType::nCsp8c;
        } else if (srcMemPtr->getDesc().hasLayoutType(LayoutType::nCsp16c)) {
            attrs.layout = LayoutType::nCsp16c;
        } else if (srcMemPtr->getDesc().hasLayoutType(LayoutType::nspc)) {
            attrs.is_nhwc = true;
            attrs.layout = LayoutType::nspc;
        } else {
            THROW_ERROR << "has selected a layout which is not supported";
        }
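
The new branch chain reduces to a memory-desc-to-enum mapping. A condensed sketch (selectLayout is a hypothetical helper, not part of the node's API):

LayoutType selectLayout(const MemoryDesc& desc) {
    if (desc.hasLayoutType(LayoutType::ncsp))    return LayoutType::ncsp;
    if (desc.hasLayoutType(LayoutType::nCsp8c))  return LayoutType::nCsp8c;
    if (desc.hasLayoutType(LayoutType::nCsp16c)) return LayoutType::nCsp16c;
    if (desc.hasLayoutType(LayoutType::nspc))    return LayoutType::nspc;
    IE_THROW() << "NormalizeL2 has selected a layout which is not supported";
}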
@@ -859,8 +909,44 @@ bool MKLDNNNormalizeL2Node::isExecutable() const {

void MKLDNNNormalizeL2Node::prepareParams() {
    const auto& dims = getParentEdgeAt(DATA)->getMemoryPtr()->getStaticDims();

    setPostOps(kernel_attrs, dims, true);
    execPtr = NormalizeL2Executor::getNormalizeL2Executor(attrs, kernel_attrs, dims);

    // Move the pointer addresses from the compile-time kernel_attrs into the runtime kernel args,
    // and clear them in kernel_attrs so they do not become part of the cache key.
    auto &postOps = (*kernel_attrs.get()).post_ops_;
    postOpsDataPtrs.clear();
    for (int i = 0; i < postOps.len(); ++i) {
        auto &post_op = postOps.entry_[i];
        if (post_op.is_quantization()) {
            auto &data = post_op.quantization.data;
            postOpsDataPtrs.insert(postOpsDataPtrs.end(), std::begin(data), std::end(data));
            memset(data, 0, sizeof(data));
        } else if (post_op.is_depthwise()) {
            auto &weights = post_op.depthwise.weights_data;
            auto &biases = post_op.depthwise.biases_data;
            postOpsDataPtrs.push_back(weights);
            postOpsDataPtrs.push_back(biases);
            weights = 0;
            biases = 0;
        }
    }

    NormalizeKey key = {attrs, kernel_attrs, dims};

    auto engine = getEngine();
    auto builder = [&engine](const NormalizeKey& key) -> std::shared_ptr<MKLDNNNormalizeL2Node::NormalizeL2Executor> {
        return NormalizeL2Executor::getNormalizeL2Executor(key.attrs, key.kernel_attrs, key.dims);
    };

    auto cache = getRuntimeCache();
    auto result = cache->getOrCreate(key, builder);

    if (!result.first) {
        IE_THROW() << "Primitive descriptor was not found for node " << getName() << ".";
    }

    execPtr = result.first;
}
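
The pointer-stripping step above is the heart of the cache support: two nodes with identical post-ops but different weight buffers must still map to the same key. The same trick in isolation (PostOpEntry and extractRuntimePointers are hypothetical names, a sketch of the pattern rather than the plugin code):

#include <vector>

struct PostOpEntry {               // hypothetical simplified post-op record
    const void* weights;
    const void* biases;
};

std::vector<const void*> extractRuntimePointers(std::vector<PostOpEntry>& ops) {
    std::vector<const void*> ptrs;
    for (auto& op : ops) {
        ptrs.push_back(op.weights);  // keep the address for execute() time
        ptrs.push_back(op.biases);
        op.weights = nullptr;        // and erase it from the cache key
        op.biases = nullptr;
    }
    return ptrs;
}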

void MKLDNNNormalizeL2Node::executeDynamicImpl(mkldnn::stream strm) {
@@ -873,7 +959,7 @@ void MKLDNNNormalizeL2Node::execute(mkldnn::stream strm) {

    const uint8_t *src_ptr = reinterpret_cast<const uint8_t *>(getParentEdgeAt(DATA)->getMemoryPtr()->GetPtr());
    uint8_t *dst_ptr = reinterpret_cast<uint8_t *>(getChildEdgeAt(DATA)->getMemoryPtr()->GetPtr());
    execPtr->exec(src_ptr, dst_ptr);
    execPtr->exec(src_ptr, dst_ptr, postOpsDataPtrs.data());
}

std::vector<VectorDims> MKLDNNNormalizeL2Node::shapeInfer() const {
@@ -889,7 +975,7 @@ public:
        workAmount = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>());
    }

    void exec(const uint8_t *src_ptr, uint8_t *dst_ptr) override {
    void exec(const uint8_t *src_ptr, uint8_t *dst_ptr, const void **post_ops_data) override {
        normalize(reinterpret_cast<const in_data_t*>(src_ptr), reinterpret_cast<out_data_t*>(dst_ptr));
    }
private:
@@ -909,8 +995,12 @@ private:
template <typename in_data_t, typename out_data_t>
class MKLDNNNormalizeL2Node::NormalizeL2JitExecutor : public MKLDNNNormalizeL2Node::NormalizeL2Executor {
public:
    NormalizeL2JitExecutor(const NormalizeL2Attrs& attrs_, const mkldnn::primitive_attr& kernel_attrs, const VectorDims& dims) : attrs(attrs_) {
        if (!attrs.is_nchw && !attrs.is_nhwc && !attrs.is_blk) {
    NormalizeL2JitExecutor(const NormalizeL2Attrs& attrs_,
                           const mkldnn::primitive_attr& kernel_attrs,
                           const VectorDims& dims)
        : attrs(attrs_) {
        if (attrs.layout != LayoutType::ncsp && attrs.layout != LayoutType::nspc &&
            attrs.layout != LayoutType::nCsp8c && attrs.layout != LayoutType::nCsp16c) {
            IE_THROW() << "NormalizeL2 executor has selected a layout which is not supported";
        }

@@ -920,9 +1010,9 @@ public:
        jcp.dst_data_size = attrs.output_prec.size();
        jcp.across_spatial = attrs.across_spatial;

        jcp.is_nchw = attrs.is_nchw;
        jcp.is_nhwc = attrs.is_nhwc;
        jcp.is_blk = attrs.is_blk;
        jcp.is_nchw = (attrs.layout == LayoutType::ncsp);
        jcp.is_nhwc = (attrs.layout == LayoutType::nspc);
        jcp.is_blk = (attrs.layout == LayoutType::nCsp8c || attrs.layout == LayoutType::nCsp16c);

        size_t dims_size = dims.size();
        jcp.n = dims[0];
@@ -956,18 +1046,18 @@ public:
        normalize_modulo_kernel->create_ker();
    }

    void exec(const uint8_t *src_ptr, uint8_t *dst_ptr) override {
    void exec(const uint8_t *src_ptr, uint8_t *dst_ptr, const void **post_ops_data) override {
        if (jcp.is_nchw) {
            normalize_nchw(reinterpret_cast<const in_data_t*>(src_ptr), reinterpret_cast<out_data_t*>(dst_ptr));
            normalize_nchw(reinterpret_cast<const in_data_t*>(src_ptr), reinterpret_cast<out_data_t*>(dst_ptr), post_ops_data);
        } else if (jcp.is_nhwc) {
            normalize_nhwc(reinterpret_cast<const in_data_t*>(src_ptr), reinterpret_cast<out_data_t*>(dst_ptr));
            normalize_nhwc(reinterpret_cast<const in_data_t*>(src_ptr), reinterpret_cast<out_data_t*>(dst_ptr), post_ops_data);
        } else if (jcp.is_blk) {
            normalize_blk(reinterpret_cast<const in_data_t*>(src_ptr), reinterpret_cast<out_data_t*>(dst_ptr));
            normalize_blk(reinterpret_cast<const in_data_t*>(src_ptr), reinterpret_cast<out_data_t*>(dst_ptr), post_ops_data);
        }
    }

private:
    void normalize_nchw(const in_data_t* src_data, out_data_t* dst_data) {
    void normalize_nchw(const in_data_t* src_data, out_data_t* dst_data, const void **post_ops_data) {
        const size_t spatial_dims = jcp.h * jcp.w;
        for (size_t b = 0lu; b < jcp.n; b++) {
            const in_data_t *src_data_b = src_data + b * jcp.c * spatial_dims;
@@ -1011,6 +1101,7 @@ private:
                arg.fused_factor = static_cast<float*>(&modulo_inv); // broadcast once
                arg.oc_off = ic * sizeof(float);
                arg.work_amount = static_cast<size_t>(spatial_dims);
                arg.post_op_data = post_ops_data;
                (*normalize_kernel)(&arg);
            });
        } else { // across_spatial: false
@@ -1051,13 +1142,14 @@ private:
                arg.fused_factor = static_cast<float*>(&moduloM[0]); // ld dynamic
                arg.oc_off = ic * sizeof(float);
                arg.work_amount = static_cast<size_t>(spatial_dims);
                arg.post_op_data = post_ops_data;
                (*normalize_kernel)(&arg);
            });
            }
        }
    }

    void normalize_nhwc(const in_data_t* src_data, out_data_t* dst_data) {
    void normalize_nhwc(const in_data_t* src_data, out_data_t* dst_data, const void **post_ops_data) {
        const size_t spatial_dims = jcp.h * jcp.w;
        const size_t c_w_dims = jcp.c * jcp.w;
        for (size_t b = 0lu; b < jcp.n; b++) {
@@ -1101,6 +1193,7 @@ private:
                arg.fused_factor = static_cast<float*>(&modulo_inv); // bc static
                arg.oc_off = 0;
                arg.work_amount = static_cast<size_t>(jcp.c);
                arg.post_op_data = post_ops_data;
                (*normalize_kernel)(&arg);
            });
        } else { // for across_spatial=false
@@ -1131,13 +1224,14 @@ private:
                arg.fused_factor = static_cast<float*>(&modulo_inv); // bc static
                arg.work_amount = jcp.c;
                arg.oc_off = 0;
                arg.post_op_data = post_ops_data;
                (*normalize_kernel)(&arg);
            });
            }
        }
    }

    void normalize_blk(const in_data_t* src_data, out_data_t* dst_data) {
    void normalize_blk(const in_data_t* src_data, out_data_t* dst_data, const void **post_ops_data) {
        const size_t CB = div_up(jcp.c, blk_size);
        const size_t spatial_dims = jcp.h * jcp.w;
        const size_t w_blk_dims = jcp.w * blk_size;
@@ -1184,6 +1278,7 @@ private:
                arg.fused_factor = static_cast<float*>(&modulo_inv); // broadcast once
                arg.work_amount = static_cast<size_t>(jcp.w);
                arg.oc_off = cb * blk_size * sizeof(float);
                arg.post_op_data = post_ops_data;
                (*normalize_kernel)(&arg);
            });
        } else { // across_spatial: false
@@ -1216,6 +1311,7 @@ private:
                arg.fused_factor = static_cast<float*>(&modulo_inv); // broadcast
                arg.work_amount = CB;
                arg.oc_off = 0;
                arg.post_op_data = post_ops_data;
                (*normalize_kernel)(&arg);
            });
        }
@@ -1239,7 +1335,7 @@ class MKLDNNNormalizeL2Node::NormalizeL2ReferenceExecutor : public MKLDNNNormali
public:
    NormalizeL2ReferenceExecutor(const NormalizeL2Attrs& attrs, const mkldnn::primitive_attr& kernel_attrs, const VectorDims& dims) :
            attrs(attrs), kernel_attrs(kernel_attrs), dims(dims) {
        if (!attrs.is_nchw) {
        if (attrs.layout != LayoutType::ncsp) {
            IE_THROW() << "Reference Executor of 'NormalizeL2' supports only ncsp layout!";
        }

@@ -1255,12 +1351,12 @@ public:
        }
    }

    void exec(const uint8_t *src_ptr, uint8_t *dst_ptr) override {
        normalize_nchw_ref(reinterpret_cast<const in_data_t*>(src_ptr), reinterpret_cast<out_data_t*>(dst_ptr));
    void exec(const uint8_t *src_ptr, uint8_t *dst_ptr, const void **post_ops_data) override {
        normalize_nchw_ref(reinterpret_cast<const in_data_t*>(src_ptr), reinterpret_cast<out_data_t*>(dst_ptr), post_ops_data);
    }

private:
    void normalize_nchw_ref(const in_data_t* src_data, out_data_t* dst_data) {
    void normalize_nchw_ref(const in_data_t* src_data, out_data_t* dst_data, const void **post_ops_data) {
        size_t dims_size = dims.size();
        const size_t N = dims[0];
        const size_t C = dims[1];
@@ -1292,7 +1388,7 @@ private:
                out_data_t *dst_data_bc = dst_data_b + ic * spatial_dims;
                for (size_t m = 0; m < spatial_dims; m++) {
                    float dst_value = src_data_bc[m] * modulo_inv;
                    apply_post_ops_scalar(dst_value, ic);
                    apply_post_ops_scalar(dst_value, ic, post_ops_data);
                    if (attrs.output_prec == Precision::U8) {
                        dst_data_bc[m] = (dst_value >= 0) ? dst_value : 0;
                    } else {
@@ -1324,7 +1420,7 @@ private:
                out_data_t *dst_data_bc = dst_data_b + ic * spatial_dims;
                for (size_t m = 0; m < spatial_dims; m++) {
                    float dst_value = src_data_bc[m] * moduloM[m];
                    apply_post_ops_scalar(dst_value, ic);
                    apply_post_ops_scalar(dst_value, ic, post_ops_data);
                    if (attrs.output_prec == Precision::U8) {
                        dst_data_bc[m] = (dst_value >= 0) ? dst_value : 0;
                    } else {
@@ -1336,18 +1432,22 @@ private:
        }
    }

    inline void apply_post_ops_scalar(float &dst_value, int index_c) {
    inline void apply_post_ops_scalar(float &dst_value, int index_c, const void **post_ops_data_) {
        const auto &p = (*kernel_attrs.get()).post_ops_;
        int eltwise_inj_idx = 0;
        int depthwise_inj_idx = 0;
        // reinterpret cast from (pointer to pointer to const void) to (pointer to pointer to const float)
        const float** post_ops_data = reinterpret_cast<const float**>(post_ops_data_);
        for (int i = 0; i < p.len(); i++) {
            auto &post_op = p.entry_[i];
            if (post_op.is_eltwise()) {
                dst_value = eltwise_injectors_ref[eltwise_inj_idx]->compute_scalar(dst_value);
                eltwise_inj_idx++;
            } else if (post_op.is_depthwise()) {
                auto depthwise_weights = post_op.depthwise.weights_data + index_c;
                auto depthwise_bias = post_op.depthwise.biases_data + index_c;
                auto depthwise_weights = post_ops_data[0] + index_c;
                auto depthwise_bias = post_ops_data[1] + index_c;
                post_ops_data += 2;

                dst_value = depthwise_injectors_ref[depthwise_inj_idx]->compute_scalar(dst_value, depthwise_weights, depthwise_bias);
                depthwise_inj_idx++;
            } else if (post_op.is_quantization()) {
@@ -1357,9 +1457,9 @@ private:
                auto quant = post_op.quantization;

                using quantization_fields = post_ops_t::entry_t::quantization_t::quantization_fields;
                auto dataVal = [&](const quantization_fields& field) {
                auto dataVal = [&](const quantization_fields& field) -> float {
                    const int channelIdx = quant.per_channel[field] ? index_c : 0;
                    return quant.data[field][channelIdx];
                    return post_ops_data[field][channelIdx];
                };

                float crop_low = dataVal(quant.crop_low);
@@ -1379,6 +1479,8 @@ private:
                float output_shift = dataVal(quant.output_shift);
                dst_value = dst_value * output_scale + output_shift;
                }

                post_ops_data += quant.fields_count;
            }
        }
    }
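
The indexing rule dataVal applies can be stated compactly: `field` selects one of the quantization post-op's consecutive pointers in the flat list, and per_channel chooses between a per-channel table and a broadcast scalar stored at index 0. A stand-alone restatement (readQuantField is a hypothetical name for illustration):

float readQuantField(const float** post_ops_data, int field, bool per_channel, int index_c) {
    const int channelIdx = per_channel ? index_c : 0;  // scalar entries live at index 0
    return post_ops_data[field][channelIdx];
}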
@@ -1419,7 +1521,7 @@ std::shared_ptr<MKLDNNNormalizeL2Node::NormalizeL2Executor> MKLDNNNormalizeL2Nod
        return std::make_shared<NormalizeL2CornerCaseExecutor<in_data_t, out_data_t>>(dims);
    else if (mayiuse(cpu::x64::sse41))
        return std::make_shared<NormalizeL2JitExecutor<in_data_t, out_data_t>>(attrs, kernel_attrs, dims);
    else if (attrs.is_nchw)
    else if (attrs.layout == LayoutType::ncsp)
        return std::make_shared<NormalizeL2ReferenceExecutor<in_data_t, out_data_t>>(attrs, kernel_attrs, dims);
    else
        IE_THROW() << "'NormalizeL2' cannot create Executor";

@@ -39,6 +39,8 @@ struct jit_normalize_call_args {
    size_t dst_stride;
    size_t work_amount;
    size_t oc_off;
    // pointer to a flat array of post-op input pointers
    const void** post_op_data;
};
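
An illustrative call-site for the extended struct, mirroring the arg.post_op_data assignments in the executors above (local variable names here are hypothetical, and fields not shown in this hunk are omitted):

jit_normalize_call_args arg{};
arg.work_amount = work_amount;              // elements this kernel invocation covers
arg.oc_off = ic * sizeof(float);            // channel offset for per-channel post-ops
arg.post_op_data = postOpsDataPtrs.data();  // flat list collected in prepareParams()
(*normalize_kernel)(&arg);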

struct jit_uni_normalize_modulo_kernel {
@@ -95,32 +97,31 @@ public:

    bool isExecutable() const override;

private:
    enum class NormEpsMode {
        ADD,
        MAX
    };

    struct NormalizeL2Attrs {
        LayoutType layout = LayoutType::ncsp;
        NormEpsMode epsMode = NormEpsMode::ADD;
        bool across_spatial = true;
        bool cornerCase = false;
        float eps = 1e-10f;

        bool is_nchw = false;
        bool is_nhwc = false;
        bool is_blk = false;

        InferenceEngine::Precision input_prec = Precision::UNSPECIFIED;
        InferenceEngine::Precision output_prec = Precision::UNSPECIFIED;
        size_t src_data_size = 0lu;
        size_t dst_data_size = 0lu;
    } attrs;
    };

private:
    NormalizeL2Attrs attrs;

    class NormalizeL2Executor {
    public:
        NormalizeL2Executor() = default;
        virtual void exec(const uint8_t *src_ptr, uint8_t *dst_ptr) = 0;
        virtual void exec(const uint8_t *src_ptr, uint8_t *dst_ptr, const void **post_ops_data) = 0;
        virtual ~NormalizeL2Executor() = default;

        static std::shared_ptr<NormalizeL2Executor> getNormalizeL2Executor(const NormalizeL2Attrs& attrs,
@@ -162,6 +163,8 @@ private:

    mkldnn::primitive_attr kernel_attrs;

    std::vector<const void*> postOpsDataPtrs;

    void setPostOps(mkldnn::primitive_attr& kernel_attrs, const VectorDims& dims, bool initWeights = false);

    static constexpr size_t DATA = 0;

@@ -117,7 +117,13 @@ std::vector<fusingSpecificParams> fusingParamsSet {
std::vector<fusingSpecificParams> fusingParamsSetDynamic {
    emptyFusingSpec,
    fusingMultiplyPerTensor,
    fusingRelu
    fusingRelu,
    fusingFakeQuantizePerTensor
};

std::vector<fusingSpecificParams> fusingParamsSetPerChannel {
    fusingPReluPerChannel,
    fusingFakeQuantizePerChannel
};

const float epsilon = 1e-4f;
@@ -136,10 +142,10 @@ const std::vector<ov::Shape> inputShapeStatic_2D = {

const std::vector<InputShape> inputShapeDynamic_2D = {
    {{-1, -1},
     {{2, 3}, {2, 3}, {5, 5}}},
     {{2, 3}, {2, 3}, {5, 5}, {2, 3}}},

    {{-1, 5},
     {{5, 5}, {5, 5}, {12, 5}}},
     {{5, 5}, {5, 5}, {12, 5}, {5, 5}}},

    {{{1, 5}, {8, 16}},
     {{3, 8}, {5, 16}, {3, 10}}}
@@ -179,7 +185,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_Dynamic_2D_FusingPerChannel, NormalizeL2LayerCPUT
                ::testing::Values(epsilon),
                ::testing::Values(epsMode),
                ::testing::Values(CPUSpecificParams{}),
                ::testing::Values(fusingPReluPerChannel)),
                ::testing::ValuesIn(fusingParamsSetPerChannel)),
            NormalizeL2LayerCPUTest::getTestCaseName);

/* ============= 3D ============= */
@@ -191,10 +197,10 @@ const std::vector<ov::Shape> inputShapeStatic_3D = {

const std::vector<InputShape> inputShapeDynamic_3D = {
    {{-1, -1, -1},
     {{2, 3, 4}, {2, 5, 5}, {1, 10, 2}}},
     {{2, 3, 4}, {2, 5, 5}, {1, 10, 2}, {2, 3, 4}}},

    {{-1, 5, -1},
     {{1, 5, 5}, {2, 5, 3}, {5, 5, 5}}},
     {{1, 5, 5}, {2, 5, 3}, {5, 5, 5}, {1, 5, 5}}},

    {{{1, 5}, {5, 10}, {5, 10}},
     {{3, 8, 8}, {5, 5, 10}, {5, 5, 10}, {5, 10, 10}}}
@@ -236,7 +242,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_Dynamic_3D_FusingPerChannel, NormalizeL2LayerCPUT
                ::testing::Values(epsilon),
                ::testing::Values(epsMode),
                ::testing::Values(CPUSpecificParams{}),
                ::testing::Values(fusingPReluPerChannel)),
                ::testing::ValuesIn(fusingParamsSetPerChannel)),
            NormalizeL2LayerCPUTest::getTestCaseName);

/* ============= 4D ============= */
@@ -248,10 +254,10 @@ const std::vector<ov::Shape> inputShapeStatic_4D = {

const std::vector<InputShape> inputShapeDynamic_4D = {
    {{-1, -1, -1, -1},
     {{2, 3, 4, 5}, {2, 5, 5, 5}, {1, 16, 2, 4}}},
     {{2, 3, 4, 5}, {2, 5, 5, 5}, {1, 16, 2, 4}, {2, 3, 4, 5}}},

    {{-1, 5, -1, -1},
     {{1, 5, 5, 8}, {1, 5, 5, 8}, {3, 5, 8, 8}}},
     {{1, 5, 5, 8}, {1, 5, 5, 8}, {3, 5, 8, 8}, {1, 5, 5, 8}}},

    {{{1, 5}, {5, 16}, {5, 10}, {5, 10}},
     {{3, 8, 8, 8}, {5, 7, 10, 10}, {1, 16, 7, 9}, {5, 9, 10, 5}}}
@@ -307,7 +313,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_Dynamic_4D_FusingPerChannel, NormalizeL2LayerCPUT
                ::testing::Values(epsilon),
                ::testing::Values(epsMode),
                ::testing::ValuesIn(getCPUSpecificParams()),
                ::testing::Values(fusingPReluPerChannel)),
                ::testing::ValuesIn(fusingParamsSetPerChannel)),
            NormalizeL2LayerCPUTest::getTestCaseName);

} // namespace