[CPU] Add MVNLayoutType enum (#9889)

This commit is contained in:
Mang Guo 2022-01-27 19:27:40 +08:00 committed by GitHub
parent 1be838576c
commit f6162ed657
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 24 deletions

View File

@ -63,8 +63,7 @@ size_t MVNKey::hash() const {
seed = hash_combine(seed, mvnAttrs.epsMode_);
seed = hash_combine(seed, mvnAttrs.src_prc.getPrecVal());
seed = hash_combine(seed, mvnAttrs.dst_prc.getPrecVal());
seed = hash_combine(seed, mvnAttrs.planar_layout);
seed = hash_combine(seed, mvnAttrs.is_nhwc);
seed = hash_combine(seed, mvnAttrs.layout);
seed = hash_combine(seed, get_attr_hash(*attr.get()));
return seed;
}
@ -79,8 +78,7 @@ bool MVNKey::operator==(const MVNKey& rhs) const {
mvnAttrs.epsMode_ == rhs.mvnAttrs.epsMode_ &&
mvnAttrs.src_prc == rhs.mvnAttrs.src_prc &&
mvnAttrs.dst_prc == rhs.mvnAttrs.dst_prc &&
mvnAttrs.is_nhwc == rhs.mvnAttrs.is_nhwc &&
mvnAttrs.planar_layout == rhs.mvnAttrs.planar_layout;
mvnAttrs.layout == rhs.mvnAttrs.layout;
retVal = retVal && *attr.get() == *rhs.attr.get();
return retVal;
}
@ -845,9 +843,9 @@ MKLDNNMVNNode::MVNJitExecutor::MVNJitExecutor(const MVNAttrs& mvnAttrs,
auto jcp = jit_mvn_config_params();
jcp.src_prc = mvnAttrs.src_prc;
jcp.dst_prc = mvnAttrs.dst_prc;
jcp.src_data_size = MKLDNNExtensionUtils::sizeOfDataType(MKLDNNExtensionUtils::IEPrecisionToDataType(jcp.src_prc));
jcp.dst_data_size = MKLDNNExtensionUtils::sizeOfDataType(MKLDNNExtensionUtils::IEPrecisionToDataType(jcp.dst_prc));
jcp.planar_layout = mvnAttrs.planar_layout;
jcp.src_data_size = src_data_size;
jcp.dst_data_size = dst_data_size;
jcp.planar_layout = mvnAttrs.layout == MVNLayoutType::planar;
jcp.normalize_variance = mvnAttrs.normalizeVariance_;
jcp.across_channels = mvnAttrs.execAcrossChannels_;
int N = 0;
@ -892,7 +890,7 @@ void MKLDNNMVNNode::MVNJitExecutor::exec(const uint8_t *src_data, uint8_t *dst_d
if (!mvn_mean_kernel || (mvnAttrs.normalizeVariance_ && !mvn_variance_kernel) || !mvn_kernel) {
IE_THROW() << "MVN layer doesn't create kernel to execute on sse41 above platform.";
}
if (mvnAttrs.planar_layout) {
if (mvnAttrs.layout == MVNLayoutType::planar) {
mvn_pln(src_data, dst_data, post_ops_data_);
} else {
mvn_blk(src_data, dst_data, post_ops_data_);
@ -922,8 +920,13 @@ void MKLDNNMVNNode::prepareParams() {
auto selectedPD = getSelectedPrimitiveDescriptor();
mvnAttrs.src_prc = selectedPD->getConfig().inConfs[0].desc->getPrecision();
mvnAttrs.dst_prc = selectedPD->getConfig().outConfs[0].desc->getPrecision();
mvnAttrs.planar_layout = selectedPD->getConfig().inConfs[0].desc->hasLayoutType(LayoutType::ncsp);
mvnAttrs.is_nhwc = getParentEdgeAt(0)->getMemory().getDesc().hasLayoutType(LayoutType::nspc);
if (getParentEdgeAt(0)->getMemory().getDesc().hasLayoutType(LayoutType::ncsp)) {
mvnAttrs.layout = MVNLayoutType::planar;
} else if (getParentEdgeAt(0)->getMemory().getDesc().hasLayoutType(LayoutType::nspc)) {
mvnAttrs.layout = MVNLayoutType::by_channel;
} else {
mvnAttrs.layout = MVNLayoutType::block;
}
}
MVNKey key = {mvnAttrs, mkldnn::primitive_attr()};
@ -1279,7 +1282,8 @@ void MKLDNNMVNNode::MVNJitExecutor::mvn_blk(const uint8_t* src_data, uint8_t* ds
size_t CB = div_up(C, blk_size);
size_t C0 = mvnAttrs.is_nhwc ? W * C : W * blk_size;
bool is_nhwc = mvnAttrs.layout == MVNLayoutType::by_channel;
size_t C0 = is_nhwc ? W * C : W * blk_size;
size_t C1 = C0 * H;
size_t C2 = C1 * D;
size_t C3 = C2 * CB;
@ -1290,17 +1294,17 @@ void MKLDNNMVNNode::MVNJitExecutor::mvn_blk(const uint8_t* src_data, uint8_t* ds
std::vector<float> mean_buffer(aux_buffer_size * threads_num);
std::vector<float> variance_buffer(aux_buffer_size * threads_num);
size_t src_stride_size = mvnAttrs.is_nhwc ? static_cast<size_t>(C * src_data_size) : static_cast<size_t>(blk_size * src_data_size);
size_t dst_stride_size = mvnAttrs.is_nhwc ? static_cast<size_t>(C * dst_data_size) : static_cast<size_t>(blk_size * dst_data_size);
size_t src_stride_size = is_nhwc ? static_cast<size_t>(C * src_data_size) : static_cast<size_t>(blk_size * src_data_size);
size_t dst_stride_size = is_nhwc ? static_cast<size_t>(C * dst_data_size) : static_cast<size_t>(blk_size * dst_data_size);
for (size_t b = 0lu; b < N; b++) {
size_t b_offset = mvnAttrs.is_nhwc ? b * C5 : b * C3;
size_t b_offset = is_nhwc ? b * C5 : b * C3;
if (mvnAttrs.execAcrossChannels_) {
// mean for this instance in batch
float C5inv = 1.f / static_cast<float>(C5);
float mean_temp = 0.0f;
mean_temp = parallel_sum3d(CB, D, H, mean_temp, [&](size_t cb, size_t d, size_t h)->float {
size_t src_offset = mvnAttrs.is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
size_t src_offset = is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
: b_offset + cb * C2 + d * C1 + h * C0;
float mean_internal = 0.0f;
@ -1336,7 +1340,7 @@ void MKLDNNMVNNode::MVNJitExecutor::mvn_blk(const uint8_t* src_data, uint8_t* ds
// variance: sum((x-mean)*(x-mean)) for one instance in batch
float variance_temp = 0.0f;
variance_temp = parallel_sum3d(CB, D, H, variance_temp, [&](size_t cb, size_t d, size_t h)->float {
size_t src_offset = mvnAttrs.is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
size_t src_offset = is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
: b_offset + cb * C2 + d * C1 + h * C0;
float variance_internal = 0.0f;
@ -1367,7 +1371,7 @@ void MKLDNNMVNNode::MVNJitExecutor::mvn_blk(const uint8_t* src_data, uint8_t* ds
variance /= sqrtf(variance_temp * C5inv) + mvnAttrs.epsValue_;
// mvn for one instance in batch
parallel_for3d(CB, D, H, [&](size_t cb, size_t d, size_t h) {
size_t src_offset = mvnAttrs.is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
size_t src_offset = is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
: b_offset + cb * C2 + d * C1 + h * C0;
auto arg = jit_mvn_call_args();
arg.src = src_data + src_offset * src_data_size;
@ -1384,7 +1388,7 @@ void MKLDNNMVNNode::MVNJitExecutor::mvn_blk(const uint8_t* src_data, uint8_t* ds
} else {
// mvn for one instance in batch
parallel_for3d(CB, D, H, [&](size_t cb, size_t d, size_t h) {
size_t src_offset = mvnAttrs.is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
size_t src_offset = is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
: b_offset + cb * C2 + d * C1 + h * C0;
auto arg = jit_mvn_call_args();
arg.src = src_data + src_offset * src_data_size;
@ -1407,7 +1411,7 @@ void MKLDNNMVNNode::MVNJitExecutor::mvn_blk(const uint8_t* src_data, uint8_t* ds
// keep the compute order the same as planar
parallel_for2d(D, H, [&](size_t thr_idx, size_t d, size_t h) {
for (size_t cb = 0; cb < CB; cb++) {
size_t src_offset = mvnAttrs.is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
size_t src_offset = is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
: b_offset + cb * C2 + d * C1 + h * C0;
auto mean_buffer_ptr = &mean_buffer[blk_size * cb + aux_buffer_size * thr_idx];
@ -1435,7 +1439,7 @@ void MKLDNNMVNNode::MVNJitExecutor::mvn_blk(const uint8_t* src_data, uint8_t* ds
parallel_for2d(D, H, [&](size_t thr_idx, size_t d, size_t h) {
for (size_t cb = 0; cb < CB; cb++) {
size_t src_offset = mvnAttrs.is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
size_t src_offset = is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
: b_offset + cb * C2 + d * C1 + h * C0;
auto mean_buffer_ptr = &mean_buffer[blk_size * cb];
auto variance_buffer_ptr = &variance_buffer[blk_size * cb + aux_buffer_size * thr_idx];
@ -1464,7 +1468,7 @@ void MKLDNNMVNNode::MVNJitExecutor::mvn_blk(const uint8_t* src_data, uint8_t* ds
parallel_for2d(D, H, [&](size_t d, size_t h) {
for (size_t cb = 0; cb < CB; cb++) {
size_t src_offset = mvnAttrs.is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
size_t src_offset = is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
: b_offset + cb * C2 + d * C1 + h * C0;
auto mean_buffer_ptr = &mean_buffer[blk_size * cb];
auto variance_buffer_ptr = &variance_buffer[blk_size * cb];
@ -1486,7 +1490,7 @@ void MKLDNNMVNNode::MVNJitExecutor::mvn_blk(const uint8_t* src_data, uint8_t* ds
// normalizeVariance_ == false
parallel_for2d(D, H, [&](size_t d, size_t h) {
for (size_t cb = 0; cb < CB; cb++) {
size_t src_offset = mvnAttrs.is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
size_t src_offset = is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
: b_offset + cb * C2 + d * C1 + h * C0;
auto mean_buffer_ptr = &mean_buffer[blk_size * cb];

View File

@ -101,15 +101,20 @@ public:
INSIDE_SQRT,
OUTSIDE_SQRT
};
// Memory layout category of the tensor handled by the MVN node.
// The numeric values are spelled out explicitly; they match the implicit
// numbering, so any code that hashes or compares the raw value is unaffected.
enum MVNLayoutType {
    // Chosen when the input memory descriptor reports LayoutType::ncsp.
    planar = 0,
    // Fallback used when the descriptor reports neither ncsp nor nspc.
    block = 1,
    // Chosen when the input memory descriptor reports LayoutType::nspc.
    by_channel = 2
};
// Attribute bundle describing one MVN configuration. The executor cache key
// (MVNKey) hashes and compares members of this struct, so any member that
// affects kernel selection must take part in both MVNKey::hash() and
// MVNKey::operator==.
struct MVNAttrs {
// NOTE(review): this view shows both `planar_layout` and `layout`; the hunks
// elsewhere in this change replace reads of `planar_layout` with
// `layout == MVNLayoutType::planar`, so this member is presumably the
// superseded one — confirm against the real header.
bool planar_layout;
// Layout category of the input; assigned in prepareParams() from the layout
// type reported by the parent edge's memory descriptor.
MVNLayoutType layout;
// Five extents of the processed tensor.
// presumably ordered (N, C, D, H, W) — TODO confirm against the code that fills it.
std::tuple<size_t, size_t, size_t, size_t, size_t> shape5D;
bool initAcrossChannels_;
// Forwarded to the JIT config as `across_channels`; also selects the
// across-channels branch in mvn_blk().
bool execAcrossChannels_;
// Forwarded to the JIT config as `normalize_variance`; when true, exec()
// requires the variance kernel to exist.
bool normalizeVariance_;
// Epsilon term used when normalising by the variance.
float epsValue_;
// Epsilon placement mode (see MVNEpsMode).
MVNEpsMode epsMode_;
// NOTE(review): the hunks elsewhere in this change replace reads of `is_nhwc`
// with `layout == MVNLayoutType::by_channel`, so this member is presumably
// superseded as well — confirm against the real header.
bool is_nhwc;
// Input precision, taken from the selected primitive descriptor's first input config.
InferenceEngine::Precision src_prc;
// Output precision, taken from the selected primitive descriptor's first output config.
InferenceEngine::Precision dst_prc;
};