[CPU] Add MVNLayoutType enum (#9889)
This commit is contained in:
parent
1be838576c
commit
f6162ed657
@ -63,8 +63,7 @@ size_t MVNKey::hash() const {
|
||||
seed = hash_combine(seed, mvnAttrs.epsMode_);
|
||||
seed = hash_combine(seed, mvnAttrs.src_prc.getPrecVal());
|
||||
seed = hash_combine(seed, mvnAttrs.dst_prc.getPrecVal());
|
||||
seed = hash_combine(seed, mvnAttrs.planar_layout);
|
||||
seed = hash_combine(seed, mvnAttrs.is_nhwc);
|
||||
seed = hash_combine(seed, mvnAttrs.layout);
|
||||
seed = hash_combine(seed, get_attr_hash(*attr.get()));
|
||||
return seed;
|
||||
}
|
||||
@ -79,8 +78,7 @@ bool MVNKey::operator==(const MVNKey& rhs) const {
|
||||
mvnAttrs.epsMode_ == rhs.mvnAttrs.epsMode_ &&
|
||||
mvnAttrs.src_prc == rhs.mvnAttrs.src_prc &&
|
||||
mvnAttrs.dst_prc == rhs.mvnAttrs.dst_prc &&
|
||||
mvnAttrs.is_nhwc == rhs.mvnAttrs.is_nhwc &&
|
||||
mvnAttrs.planar_layout == rhs.mvnAttrs.planar_layout;
|
||||
mvnAttrs.layout == rhs.mvnAttrs.layout;
|
||||
retVal = retVal && *attr.get() == *rhs.attr.get();
|
||||
return retVal;
|
||||
}
|
||||
@ -845,9 +843,9 @@ MKLDNNMVNNode::MVNJitExecutor::MVNJitExecutor(const MVNAttrs& mvnAttrs,
|
||||
auto jcp = jit_mvn_config_params();
|
||||
jcp.src_prc = mvnAttrs.src_prc;
|
||||
jcp.dst_prc = mvnAttrs.dst_prc;
|
||||
jcp.src_data_size = MKLDNNExtensionUtils::sizeOfDataType(MKLDNNExtensionUtils::IEPrecisionToDataType(jcp.src_prc));
|
||||
jcp.dst_data_size = MKLDNNExtensionUtils::sizeOfDataType(MKLDNNExtensionUtils::IEPrecisionToDataType(jcp.dst_prc));
|
||||
jcp.planar_layout = mvnAttrs.planar_layout;
|
||||
jcp.src_data_size = src_data_size;
|
||||
jcp.dst_data_size = dst_data_size;
|
||||
jcp.planar_layout = mvnAttrs.layout == MVNLayoutType::planar;
|
||||
jcp.normalize_variance = mvnAttrs.normalizeVariance_;
|
||||
jcp.across_channels = mvnAttrs.execAcrossChannels_;
|
||||
int N = 0;
|
||||
@ -892,7 +890,7 @@ void MKLDNNMVNNode::MVNJitExecutor::exec(const uint8_t *src_data, uint8_t *dst_d
|
||||
if (!mvn_mean_kernel || (mvnAttrs.normalizeVariance_ && !mvn_variance_kernel) || !mvn_kernel) {
|
||||
IE_THROW() << "MVN layer doesn't create kernel to execute on sse41 above platform.";
|
||||
}
|
||||
if (mvnAttrs.planar_layout) {
|
||||
if (mvnAttrs.layout == MVNLayoutType::planar) {
|
||||
mvn_pln(src_data, dst_data, post_ops_data_);
|
||||
} else {
|
||||
mvn_blk(src_data, dst_data, post_ops_data_);
|
||||
@ -922,8 +920,13 @@ void MKLDNNMVNNode::prepareParams() {
|
||||
auto selectedPD = getSelectedPrimitiveDescriptor();
|
||||
mvnAttrs.src_prc = selectedPD->getConfig().inConfs[0].desc->getPrecision();
|
||||
mvnAttrs.dst_prc = selectedPD->getConfig().outConfs[0].desc->getPrecision();
|
||||
mvnAttrs.planar_layout = selectedPD->getConfig().inConfs[0].desc->hasLayoutType(LayoutType::ncsp);
|
||||
mvnAttrs.is_nhwc = getParentEdgeAt(0)->getMemory().getDesc().hasLayoutType(LayoutType::nspc);
|
||||
if (getParentEdgeAt(0)->getMemory().getDesc().hasLayoutType(LayoutType::ncsp)) {
|
||||
mvnAttrs.layout = MVNLayoutType::planar;
|
||||
} else if (getParentEdgeAt(0)->getMemory().getDesc().hasLayoutType(LayoutType::nspc)) {
|
||||
mvnAttrs.layout = MVNLayoutType::by_channel;
|
||||
} else {
|
||||
mvnAttrs.layout = MVNLayoutType::block;
|
||||
}
|
||||
}
|
||||
|
||||
MVNKey key = {mvnAttrs, mkldnn::primitive_attr()};
|
||||
@ -1279,7 +1282,8 @@ void MKLDNNMVNNode::MVNJitExecutor::mvn_blk(const uint8_t* src_data, uint8_t* ds
|
||||
|
||||
size_t CB = div_up(C, blk_size);
|
||||
|
||||
size_t C0 = mvnAttrs.is_nhwc ? W * C : W * blk_size;
|
||||
bool is_nhwc = mvnAttrs.layout == MVNLayoutType::by_channel;
|
||||
size_t C0 = is_nhwc ? W * C : W * blk_size;
|
||||
size_t C1 = C0 * H;
|
||||
size_t C2 = C1 * D;
|
||||
size_t C3 = C2 * CB;
|
||||
@ -1290,17 +1294,17 @@ void MKLDNNMVNNode::MVNJitExecutor::mvn_blk(const uint8_t* src_data, uint8_t* ds
|
||||
std::vector<float> mean_buffer(aux_buffer_size * threads_num);
|
||||
std::vector<float> variance_buffer(aux_buffer_size * threads_num);
|
||||
|
||||
size_t src_stride_size = mvnAttrs.is_nhwc ? static_cast<size_t>(C * src_data_size) : static_cast<size_t>(blk_size * src_data_size);
|
||||
size_t dst_stride_size = mvnAttrs.is_nhwc ? static_cast<size_t>(C * dst_data_size) : static_cast<size_t>(blk_size * dst_data_size);
|
||||
size_t src_stride_size = is_nhwc ? static_cast<size_t>(C * src_data_size) : static_cast<size_t>(blk_size * src_data_size);
|
||||
size_t dst_stride_size = is_nhwc ? static_cast<size_t>(C * dst_data_size) : static_cast<size_t>(blk_size * dst_data_size);
|
||||
|
||||
for (size_t b = 0lu; b < N; b++) {
|
||||
size_t b_offset = mvnAttrs.is_nhwc ? b * C5 : b * C3;
|
||||
size_t b_offset = is_nhwc ? b * C5 : b * C3;
|
||||
if (mvnAttrs.execAcrossChannels_) {
|
||||
// mean for this instance in batch
|
||||
float C5inv = 1.f / static_cast<float>(C5);
|
||||
float mean_temp = 0.0f;
|
||||
mean_temp = parallel_sum3d(CB, D, H, mean_temp, [&](size_t cb, size_t d, size_t h)->float {
|
||||
size_t src_offset = mvnAttrs.is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
|
||||
size_t src_offset = is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
|
||||
: b_offset + cb * C2 + d * C1 + h * C0;
|
||||
|
||||
float mean_internal = 0.0f;
|
||||
@ -1336,7 +1340,7 @@ void MKLDNNMVNNode::MVNJitExecutor::mvn_blk(const uint8_t* src_data, uint8_t* ds
|
||||
// variance: sum((x-mean)*(x-mean)) for one instance in batch
|
||||
float variance_temp = 0.0f;
|
||||
variance_temp = parallel_sum3d(CB, D, H, variance_temp, [&](size_t cb, size_t d, size_t h)->float {
|
||||
size_t src_offset = mvnAttrs.is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
|
||||
size_t src_offset = is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
|
||||
: b_offset + cb * C2 + d * C1 + h * C0;
|
||||
|
||||
float variance_internal = 0.0f;
|
||||
@ -1367,7 +1371,7 @@ void MKLDNNMVNNode::MVNJitExecutor::mvn_blk(const uint8_t* src_data, uint8_t* ds
|
||||
variance /= sqrtf(variance_temp * C5inv) + mvnAttrs.epsValue_;
|
||||
// mvn for one instance in batch
|
||||
parallel_for3d(CB, D, H, [&](size_t cb, size_t d, size_t h) {
|
||||
size_t src_offset = mvnAttrs.is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
|
||||
size_t src_offset = is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
|
||||
: b_offset + cb * C2 + d * C1 + h * C0;
|
||||
auto arg = jit_mvn_call_args();
|
||||
arg.src = src_data + src_offset * src_data_size;
|
||||
@ -1384,7 +1388,7 @@ void MKLDNNMVNNode::MVNJitExecutor::mvn_blk(const uint8_t* src_data, uint8_t* ds
|
||||
} else {
|
||||
// mvn for one instance in batch
|
||||
parallel_for3d(CB, D, H, [&](size_t cb, size_t d, size_t h) {
|
||||
size_t src_offset = mvnAttrs.is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
|
||||
size_t src_offset = is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
|
||||
: b_offset + cb * C2 + d * C1 + h * C0;
|
||||
auto arg = jit_mvn_call_args();
|
||||
arg.src = src_data + src_offset * src_data_size;
|
||||
@ -1407,7 +1411,7 @@ void MKLDNNMVNNode::MVNJitExecutor::mvn_blk(const uint8_t* src_data, uint8_t* ds
|
||||
// keep the compute order the same as planar
|
||||
parallel_for2d(D, H, [&](size_t thr_idx, size_t d, size_t h) {
|
||||
for (size_t cb = 0; cb < CB; cb++) {
|
||||
size_t src_offset = mvnAttrs.is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
|
||||
size_t src_offset = is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
|
||||
: b_offset + cb * C2 + d * C1 + h * C0;
|
||||
auto mean_buffer_ptr = &mean_buffer[blk_size * cb + aux_buffer_size * thr_idx];
|
||||
|
||||
@ -1435,7 +1439,7 @@ void MKLDNNMVNNode::MVNJitExecutor::mvn_blk(const uint8_t* src_data, uint8_t* ds
|
||||
|
||||
parallel_for2d(D, H, [&](size_t thr_idx, size_t d, size_t h) {
|
||||
for (size_t cb = 0; cb < CB; cb++) {
|
||||
size_t src_offset = mvnAttrs.is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
|
||||
size_t src_offset = is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
|
||||
: b_offset + cb * C2 + d * C1 + h * C0;
|
||||
auto mean_buffer_ptr = &mean_buffer[blk_size * cb];
|
||||
auto variance_buffer_ptr = &variance_buffer[blk_size * cb + aux_buffer_size * thr_idx];
|
||||
@ -1464,7 +1468,7 @@ void MKLDNNMVNNode::MVNJitExecutor::mvn_blk(const uint8_t* src_data, uint8_t* ds
|
||||
|
||||
parallel_for2d(D, H, [&](size_t d, size_t h) {
|
||||
for (size_t cb = 0; cb < CB; cb++) {
|
||||
size_t src_offset = mvnAttrs.is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
|
||||
size_t src_offset = is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
|
||||
: b_offset + cb * C2 + d * C1 + h * C0;
|
||||
auto mean_buffer_ptr = &mean_buffer[blk_size * cb];
|
||||
auto variance_buffer_ptr = &variance_buffer[blk_size * cb];
|
||||
@ -1486,7 +1490,7 @@ void MKLDNNMVNNode::MVNJitExecutor::mvn_blk(const uint8_t* src_data, uint8_t* ds
|
||||
// normalizeVariance_ == false
|
||||
parallel_for2d(D, H, [&](size_t d, size_t h) {
|
||||
for (size_t cb = 0; cb < CB; cb++) {
|
||||
size_t src_offset = mvnAttrs.is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
|
||||
size_t src_offset = is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
|
||||
: b_offset + cb * C2 + d * C1 + h * C0;
|
||||
auto mean_buffer_ptr = &mean_buffer[blk_size * cb];
|
||||
|
||||
|
@ -101,15 +101,20 @@ public:
|
||||
INSIDE_SQRT,
|
||||
OUTSIDE_SQRT
|
||||
};
|
||||
|
||||
enum MVNLayoutType {
|
||||
planar,
|
||||
block,
|
||||
by_channel
|
||||
};
|
||||
struct MVNAttrs {
|
||||
bool planar_layout;
|
||||
MVNLayoutType layout;
|
||||
std::tuple<size_t, size_t, size_t, size_t, size_t> shape5D;
|
||||
bool initAcrossChannels_;
|
||||
bool execAcrossChannels_;
|
||||
bool normalizeVariance_;
|
||||
float epsValue_;
|
||||
MVNEpsMode epsMode_;
|
||||
bool is_nhwc;
|
||||
InferenceEngine::Precision src_prc;
|
||||
InferenceEngine::Precision dst_prc;
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user