[CPU] Enable cache for MVN (#9549)

This commit is contained in:
Mang Guo 2022-01-24 20:07:35 +08:00 committed by GitHub
parent 7114253d17
commit 413fee2a86
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 825 additions and 580 deletions

View File

@ -50,7 +50,7 @@ bool LrnKey::operator==(const LrnKey &rhs) const {
retVal = retVal && inp0 && rhs.inp0 && inp0->getDnnlDesc() == rhs.inp0->getDnnlDesc();
}
retVal = retVal && implType == rhs.implType && alg == rhs.alg && alg == rhs.alg && size == rhs.size && k == rhs.k &&
retVal = retVal && implType == rhs.implType && alg == rhs.alg && size == rhs.size && k == rhs.k &&
alpha == rhs.alpha && beta == rhs.beta;
return retVal;
}

View File

@ -36,6 +36,56 @@ using namespace Xbyak;
#define GET_OFF(field) offsetof(jit_mvn_call_args, field)
namespace {
// Cache key for MVN executors: an executor may be reused only when both the
// MVN attributes and the oneDNN primitive attributes (fused post-ops) match.
struct MVNKey {
// Full set of shape/precision/algorithm attributes of the MVN operation.
MKLDNNMVNNode::MVNAttrs mvnAttrs;
// Fused post-ops configuration attached to the kernel.
mkldnn::primitive_attr attr;
// Hash over all fields; must stay consistent with operator==.
size_t hash() const;
bool operator==(const MVNKey& rhs) const;
};
// Combine every attribute that influences kernel generation into one hash.
// The sequence of hash_combine calls must stay in sync with operator==.
size_t MVNKey::hash() const {
    using namespace dnnl::impl;
    using namespace dnnl::impl::primitive_hashing;

    // Unpack the 5D shape so each dimension feeds the hash individually.
    size_t dimN = 0, dimC = 0, dimD = 0, dimH = 0, dimW = 0;
    std::tie(dimN, dimC, dimD, dimH, dimW) = mvnAttrs.shape5D;

    size_t seed = 0;
    seed = hash_combine(seed, dimN);
    seed = hash_combine(seed, dimC);
    seed = hash_combine(seed, dimD);
    seed = hash_combine(seed, dimH);
    seed = hash_combine(seed, dimW);
    seed = hash_combine(seed, mvnAttrs.initAcrossChannels_);
    seed = hash_combine(seed, mvnAttrs.execAcrossChannels_);
    seed = hash_combine(seed, mvnAttrs.normalizeVariance_);
    seed = hash_combine(seed, mvnAttrs.epsValue_);
    seed = hash_combine(seed, mvnAttrs.epsMode_);
    seed = hash_combine(seed, mvnAttrs.src_prc.getPrecVal());
    seed = hash_combine(seed, mvnAttrs.dst_prc.getPrecVal());
    seed = hash_combine(seed, mvnAttrs.planar_layout);
    seed = hash_combine(seed, mvnAttrs.is_nhwc);
    // Fold in the post-ops configuration carried by the primitive attributes.
    seed = hash_combine(seed, get_attr_hash(*attr.get()));
    return seed;
}
// Keys are equal only when every attribute that influences the generated
// kernel matches, including the attached primitive attributes (post-ops).
bool MVNKey::operator==(const MVNKey& rhs) const {
    bool retVal = true;
    retVal = retVal && mvnAttrs.shape5D == rhs.mvnAttrs.shape5D &&
             mvnAttrs.initAcrossChannels_ == rhs.mvnAttrs.initAcrossChannels_ &&
             mvnAttrs.execAcrossChannels_ == rhs.mvnAttrs.execAcrossChannels_ &&
             mvnAttrs.normalizeVariance_ == rhs.mvnAttrs.normalizeVariance_ &&
             mvnAttrs.epsValue_ == rhs.mvnAttrs.epsValue_ &&
             mvnAttrs.epsMode_ == rhs.mvnAttrs.epsMode_ &&
             mvnAttrs.src_prc == rhs.mvnAttrs.src_prc &&
             mvnAttrs.dst_prc == rhs.mvnAttrs.dst_prc &&
             mvnAttrs.is_nhwc == rhs.mvnAttrs.is_nhwc &&
             // BUG FIX: original compared planar_layout with itself (always
             // true), so keys differing only in layout collided in the cache.
             mvnAttrs.planar_layout == rhs.mvnAttrs.planar_layout;
    retVal = retVal && *attr.get() == *rhs.attr.get();
    return retVal;
}
} // namespace
// some utility functions
static inline bool isFloatCompatible(Precision prc) {
return Precision::FP32 == prc || Precision::BF16 == prc;
@ -389,6 +439,7 @@ struct jit_uni_mvn_kernel_f32 : public jit_uni_mvn_kernel, public jit_generator
this->preamble();
mov(reg_post_ops_data, ptr[reg_params + GET_OFF(post_op_data)]);
mov(reg_src, ptr[reg_params + GET_OFF(src)]);
mov(reg_mean, ptr[reg_params + GET_OFF(mean)]);
if (jcp_.normalize_variance)
@ -505,6 +556,7 @@ private:
Xbyak::Reg64 reg_oc_off = rax;
Xbyak::Reg64 reg_d_weights = rbx;
Xbyak::Reg64 reg_d_bias = rdx;
Xbyak::Reg64 reg_post_ops_data = rsi;
Xbyak::Reg64 reg_load_table = r15;
Xbyak::Reg64 reg_load_store_mask = rbp;
@ -570,16 +622,19 @@ private:
int eltwise_inj_idx = 0;
int depthwise_inj_idx = 0;
int quantization_inj_idx = 0;
int post_ops_data_offset = 0;
for (int i = 0; i < p.len(); i++) {
auto& post_op = p.entry_[i];
if (post_op.is_eltwise()) {
eltwise_injectors[eltwise_inj_idx]->compute_vector_range(vmm_val.getIdx(), vmm_val.getIdx() + 1);
eltwise_inj_idx++;
} else if (post_op.is_depthwise()) {
mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
mov(reg_d_weights, ptr[reg_post_ops_data + post_ops_data_offset]);
add(reg_d_weights, reg_oc_off);
post_ops_data_offset += sizeof(float*);
mov(reg_d_bias, ptr[reg_post_ops_data + post_ops_data_offset]);
add(reg_d_bias, reg_oc_off);
post_ops_data_offset += sizeof(float*);
depthwise_injectors[depthwise_inj_idx]->compute_vector_range(vmm_val.getIdx(), vmm_val.getIdx() + 1, reg_d_weights, reg_d_bias, is_broadcast);
depthwise_inj_idx++;
} else if (post_op.is_quantization()) {
@ -587,15 +642,16 @@ private:
bool do_rounding = do_dequantization || isFloatCompatible(dst_prc) || i != p.len() - 1;
int s_idx = vmm_val.getIdx();
quantization_injectors[quantization_inj_idx]->init_crop_ptrs(reg_oc_off);
quantization_injectors[quantization_inj_idx]->init_crop_ptrs(reg_post_ops_data + post_ops_data_offset, reg_oc_off);
quantization_injectors[quantization_inj_idx]->compute_crop(s_idx, s_idx + 1, 0, 0, is_broadcast);
quantization_injectors[quantization_inj_idx]->init_input_scale_shift_ptrs(reg_oc_off);
quantization_injectors[quantization_inj_idx]->init_input_scale_shift_ptrs(reg_post_ops_data + post_ops_data_offset, reg_oc_off);
quantization_injectors[quantization_inj_idx]->compute_input_scale_shift(s_idx, s_idx + 1, 0, do_rounding, 0, is_broadcast);
quantization_injectors[quantization_inj_idx]->init_output_scale_shift_ptrs(reg_oc_off);
quantization_injectors[quantization_inj_idx]->init_output_scale_shift_ptrs(reg_post_ops_data + post_ops_data_offset, reg_oc_off);
quantization_injectors[quantization_inj_idx]->compute_output_scale_shift(s_idx, s_idx + 1, 0, 0, is_broadcast);
post_ops_data_offset += quantization_injectors[quantization_inj_idx]->memoryStep();
quantization_inj_idx++;
}
}
@ -676,24 +732,24 @@ MKLDNNMVNNode::MKLDNNMVNNode(const std::shared_ptr<ngraph::Node>& op, const mkld
IE_THROW(NotImplemented) << errorMessage;
}
epsMode_ = INSIDE_SQRT;
mvnAttrs.epsMode_ = INSIDE_SQRT;
if (auto mvnOp = ngraph::as_type_ptr<ngraph::op::v6::MVN>(op)) {
normalizeVariance_ = mvnOp->get_normalize_variance();
epsValue_ = mvnOp->get_eps();
mvnAttrs.normalizeVariance_ = mvnOp->get_normalize_variance();
mvnAttrs.epsValue_ = mvnOp->get_eps();
if (mvnOp->get_eps_mode() == ngraph::op::MVNEpsMode::OUTSIDE_SQRT) {
epsMode_ = OUTSIDE_SQRT;
mvnAttrs.epsMode_ = OUTSIDE_SQRT;
}
initAcrossChannels_ = false;
mvnAttrs.initAcrossChannels_ = false;
const auto& inDataShapeSize = getInputShapeAtPort(0).getRank();
if (inDataShapeSize == mvnOp->input_value(1).get_shape()[0] + 1 || inDataShapeSize == 1)
initAcrossChannels_ = true;
mvnAttrs.initAcrossChannels_ = true;
} else if (auto mvnOp = ngraph::as_type_ptr<ngraph::op::v0::MVN>(op)) {
normalizeVariance_ = mvnOp->get_normalize_variance();
epsValue_ = mvnOp->get_eps();
initAcrossChannels_ = mvnOp->get_across_channels();
mvnAttrs.normalizeVariance_ = mvnOp->get_normalize_variance();
mvnAttrs.epsValue_ = mvnOp->get_eps();
mvnAttrs.initAcrossChannels_ = mvnOp->get_across_channels();
}
execAcrossChannels_ = initAcrossChannels_;
mvnAttrs.execAcrossChannels_ = mvnAttrs.initAcrossChannels_;
}
void MKLDNNMVNNode::getSupportedDescriptors() {}
@ -718,11 +774,8 @@ void MKLDNNMVNNode::initSupportedPrimitiveDescriptors() {
inputPrecision = outputPrecision = Precision::FP32;
}
src_data_size = inputPrecision.size();
dst_data_size = outputPrecision.size();
// TODO [DS]: inplace
bool canBeInplace = !isDynamicNode() && (src_data_size == dst_data_size) &&
bool canBeInplace = !isDynamicNode() && (inputPrecision.size() == outputPrecision.size()) &&
(getParentEdgeAt(0)->getParent()->getChildEdges().size() == 1) &&
!getParentEdgeAt(0)->getParent()->isConstant();
@ -781,6 +834,77 @@ void MKLDNNMVNNode::initSupportedPrimitiveDescriptors() {
pushDesc(LayoutType::ncsp, impl_type);
}
// Base executor: caches the attribute set and derives the per-element byte
// sizes of the source/destination precisions used by the concrete executors.
MKLDNNMVNNode::MVNExecutor::MVNExecutor(const MVNAttrs& mvnAttrs)
: mvnAttrs(mvnAttrs),
src_data_size(mvnAttrs.src_prc.size()),
dst_data_size(mvnAttrs.dst_prc.size()) {}
// JIT executor: builds up to three kernels (mean, variance, normalization)
// for the best ISA available at runtime (avx512 > avx2 > sse41).
// The jcp.normalize_variance flag is toggled between kernel creations on
// purpose: the mean kernel is always generated with it cleared, and the
// variance kernel (generated only when normalizeVariance_ is set) with it set.
MKLDNNMVNNode::MVNJitExecutor::MVNJitExecutor(const MVNAttrs& mvnAttrs,
const mkldnn::primitive_attr& attr):
MVNExecutor(mvnAttrs) {
// Fill the kernel configuration from the cached attributes.
auto jcp = jit_mvn_config_params();
jcp.src_prc = mvnAttrs.src_prc;
jcp.dst_prc = mvnAttrs.dst_prc;
jcp.src_data_size = MKLDNNExtensionUtils::sizeOfDataType(MKLDNNExtensionUtils::IEPrecisionToDataType(jcp.src_prc));
jcp.dst_data_size = MKLDNNExtensionUtils::sizeOfDataType(MKLDNNExtensionUtils::IEPrecisionToDataType(jcp.dst_prc));
jcp.planar_layout = mvnAttrs.planar_layout;
jcp.normalize_variance = mvnAttrs.normalizeVariance_;
jcp.across_channels = mvnAttrs.execAcrossChannels_;
// N is unpacked but unused by the kernel config; only C/D/H/W are kept.
int N = 0;
std::tie(N, jcp.C, jcp.D, jcp.H, jcp.W) = mvnAttrs.shape5D;
if (mayiuse(cpu::x64::avx512_common)) {
mvn_kernel.reset(new jit_uni_mvn_kernel_f32<cpu::x64::avx512_common>(jcp, *attr.get()));
jcp.normalize_variance = false;
mvn_mean_kernel.reset(new jit_uni_mvn_mean_variance_kernel_f32<cpu::x64::avx512_common>(jcp));
if (mvnAttrs.normalizeVariance_) {
jcp.normalize_variance = true;
mvn_variance_kernel.reset(new jit_uni_mvn_mean_variance_kernel_f32<cpu::x64::avx512_common>(jcp));
}
} else if (mayiuse(cpu::x64::avx2)) {
mvn_kernel.reset(new jit_uni_mvn_kernel_f32<cpu::x64::avx2>(jcp, *attr.get()));
jcp.normalize_variance = false;
mvn_mean_kernel.reset(new jit_uni_mvn_mean_variance_kernel_f32<cpu::x64::avx2>(jcp));
if (mvnAttrs.normalizeVariance_) {
jcp.normalize_variance = true;
mvn_variance_kernel.reset(new jit_uni_mvn_mean_variance_kernel_f32<cpu::x64::avx2>(jcp));
}
} else if (mayiuse(cpu::x64::sse41)) {
mvn_kernel.reset(new jit_uni_mvn_kernel_f32<cpu::x64::sse41>(jcp, *attr.get()));
jcp.normalize_variance = false;
mvn_mean_kernel.reset(new jit_uni_mvn_mean_variance_kernel_f32<cpu::x64::sse41>(jcp));
if (mvnAttrs.normalizeVariance_) {
jcp.normalize_variance = true;
mvn_variance_kernel.reset(new jit_uni_mvn_mean_variance_kernel_f32<cpu::x64::sse41>(jcp));
}
} else {
// Callers should only construct this executor when sse41+ is available.
IE_THROW() << "Can't create jit MVN kernel";
}
// Generate the machine code for whichever kernels were created above.
if (mvn_kernel)
mvn_kernel->create_ker();
if (mvn_mean_kernel)
mvn_mean_kernel->create_ker();
if (mvn_variance_kernel)
mvn_variance_kernel->create_ker();
}
// Dispatch a JIT execution: validate that all kernels required by the cached
// attributes were generated, then pick the planar or blocked code path.
void MKLDNNMVNNode::MVNJitExecutor::exec(const uint8_t *src_data, uint8_t *dst_data, const void *post_ops_data_) {
    // The variance kernel is only required when variance normalization is on.
    const bool kernelsReady = mvn_mean_kernel && mvn_kernel &&
                              (!mvnAttrs.normalizeVariance_ || mvn_variance_kernel);
    if (!kernelsReady) {
        IE_THROW() << "MVN layer doesn't create kernel to execute on sse41 above platform.";
    }
    if (mvnAttrs.planar_layout) {
        mvn_pln(src_data, dst_data, post_ops_data_);
    } else {
        mvn_blk(src_data, dst_data, post_ops_data_);
    }
}
// Reference (non-JIT) executor: only needs the cached attributes.
MKLDNNMVNNode::MVNRefExecutor::MVNRefExecutor(const MVNAttrs& mvnAttrs):MVNExecutor(mvnAttrs) {}
// Reference execution path. post_ops_data_ is intentionally unused here:
// fused post-ops are applied only by the JIT executor.
void MKLDNNMVNNode::MVNRefExecutor::exec(const uint8_t *src_data, uint8_t *dst_data, const void *post_ops_data_) {
mvn_ref(src_data, dst_data);
}
void MKLDNNMVNNode::prepareParams() {
auto& dstMemPtr = getChildEdgeAt(0)->getMemoryPtr();
auto& srcMemPtr = getParentEdgeAt(0)->getMemoryPtr();
@ -794,59 +918,48 @@ void MKLDNNMVNNode::prepareParams() {
const SizeVector in_dims = srcMemPtr->getStaticDims();
transformTo5DCase(in_dims);
setPostOps(attr, true);
if (mayiuse(cpu::x64::sse41)) {
auto selectedPD = getSelectedPrimitiveDescriptor();
auto jcp = jit_mvn_config_params();
jcp.src_prc = selectedPD->getConfig().inConfs[0].desc->getPrecision();
jcp.dst_prc = selectedPD->getConfig().outConfs[0].desc->getPrecision();
jcp.src_data_size = MKLDNNExtensionUtils::sizeOfDataType(MKLDNNExtensionUtils::IEPrecisionToDataType(jcp.src_prc));
jcp.dst_data_size = MKLDNNExtensionUtils::sizeOfDataType(MKLDNNExtensionUtils::IEPrecisionToDataType(jcp.dst_prc));
jcp.planar_layout = selectedPD->getConfig().inConfs[0].desc->hasLayoutType(LayoutType::ncsp);
jcp.normalize_variance = normalizeVariance_;
jcp.across_channels = execAcrossChannels_;
int N = 0;
std::tie(N, jcp.C, jcp.D, jcp.H, jcp.W) = shape5D;
mvnAttrs.src_prc = selectedPD->getConfig().inConfs[0].desc->getPrecision();
mvnAttrs.dst_prc = selectedPD->getConfig().outConfs[0].desc->getPrecision();
mvnAttrs.planar_layout = selectedPD->getConfig().inConfs[0].desc->hasLayoutType(LayoutType::ncsp);
mvnAttrs.is_nhwc = getParentEdgeAt(0)->getMemory().getDesc().hasLayoutType(LayoutType::nspc);
}
if (mayiuse(cpu::x64::avx512_common)) {
mvn_kernel.reset(new jit_uni_mvn_kernel_f32<cpu::x64::avx512_common>(jcp, *attr.get()));
MVNKey key = {mvnAttrs, mkldnn::primitive_attr()};
setPostOps(key.attr, true);
jcp.normalize_variance = false;
mvn_mean_kernel.reset(new jit_uni_mvn_mean_variance_kernel_f32<cpu::x64::avx512_common>(jcp));
if (normalizeVariance_) {
jcp.normalize_variance = true;
mvn_variance_kernel.reset(new jit_uni_mvn_mean_variance_kernel_f32<cpu::x64::avx512_common>(jcp));
}
} else if (mayiuse(cpu::x64::avx2)) {
mvn_kernel.reset(new jit_uni_mvn_kernel_f32<cpu::x64::avx2>(jcp, *attr.get()));
jcp.normalize_variance = false;
mvn_mean_kernel.reset(new jit_uni_mvn_mean_variance_kernel_f32<cpu::x64::avx2>(jcp));
if (normalizeVariance_) {
jcp.normalize_variance = true;
mvn_variance_kernel.reset(new jit_uni_mvn_mean_variance_kernel_f32<cpu::x64::avx2>(jcp));
}
} else if (mayiuse(cpu::x64::sse41)) {
mvn_kernel.reset(new jit_uni_mvn_kernel_f32<cpu::x64::sse41>(jcp, *attr.get()));
jcp.normalize_variance = false;
mvn_mean_kernel.reset(new jit_uni_mvn_mean_variance_kernel_f32<cpu::x64::sse41>(jcp));
if (normalizeVariance_) {
jcp.normalize_variance = true;
mvn_variance_kernel.reset(new jit_uni_mvn_mean_variance_kernel_f32<cpu::x64::sse41>(jcp));
}
postOpsDataPtrs.clear();
auto &postOps = (*key.attr.get()).post_ops_;
for (int i = 0; i < postOps.len(); ++i) {
auto &postOp = postOps.entry_[i];
if (postOp.is_quantization()) {
auto &data = postOp.quantization.data;
postOpsDataPtrs.insert(postOpsDataPtrs.end(), std::begin(data), std::end(data));
memset(data, 0, sizeof(data));
} else if (postOp.is_depthwise()) {
auto &weights = postOp.depthwise.weights_data;
auto &biases = postOp.depthwise.biases_data;
postOpsDataPtrs.push_back(weights);
postOpsDataPtrs.push_back(biases);
weights = 0;
biases = 0;
}
}
if (mvn_kernel)
mvn_kernel->create_ker();
if (mvn_mean_kernel)
mvn_mean_kernel->create_ker();
if (mvn_variance_kernel)
mvn_variance_kernel->create_ker();
auto builder = [&](const MVNKey& key) -> std::shared_ptr<MVNExecutor> {
std::shared_ptr<MVNExecutor> executor;
if (mayiuse(cpu::x64::sse41)) {
executor = std::make_shared<MVNJitExecutor>(key.mvnAttrs, key.attr);
} else {
executor = std::make_shared<MVNRefExecutor>(key.mvnAttrs);
}
return executor;
};
auto cache = getRuntimeCache();
auto result = cache->getOrCreate(key, builder);
execPtr = result.first;
}
void MKLDNNMVNNode::transformTo5DCase(const SizeVector& shape) {
@ -854,26 +967,26 @@ void MKLDNNMVNNode::transformTo5DCase(const SizeVector& shape) {
// for 1 and 2 rank, if initAcrossChannels_ is true, adjust shape to fully vectorize under unified 5d procedure.
// otherwise there are not enough data in spatial dimension to process in one kernel.
case 1 : // C
if (initAcrossChannels_) {
shape5D = std::make_tuple(1, 1, 1, 1, shape[0]);
execAcrossChannels_ = false;
if (mvnAttrs.initAcrossChannels_) {
mvnAttrs.shape5D = std::make_tuple(1, 1, 1, 1, shape[0]);
mvnAttrs.execAcrossChannels_ = false;
break;
} else {
shape5D = std::make_tuple(1, shape[0], 1, 1, 1);
mvnAttrs.shape5D = std::make_tuple(1, shape[0], 1, 1, 1);
break;
}
case 2 : // NC
if (initAcrossChannels_) {
shape5D = std::make_tuple(1, shape[0], 1, shape[1], 1);
execAcrossChannels_ = false;
if (mvnAttrs.initAcrossChannels_) {
mvnAttrs.shape5D = std::make_tuple(1, shape[0], 1, shape[1], 1);
mvnAttrs.execAcrossChannels_ = false;
break;
} else {
shape5D = std::make_tuple(shape[0], shape[1], 1, 1, 1);
mvnAttrs.shape5D = std::make_tuple(shape[0], shape[1], 1, 1, 1);
break;
}
case 3 : { shape5D = std::make_tuple(shape[0], shape[1], 1, shape[2], 1); break; }
case 4 : { shape5D = std::make_tuple(shape[0], shape[1], 1, shape[2], shape[3]); break; }
case 5 : { shape5D = std::make_tuple(shape[0], shape[1], shape[2], shape[3], shape[4]); break; }
case 3 : { mvnAttrs.shape5D = std::make_tuple(shape[0], shape[1], 1, shape[2], 1); break; }
case 4 : { mvnAttrs.shape5D = std::make_tuple(shape[0], shape[1], 1, shape[2], shape[3]); break; }
case 5 : { mvnAttrs.shape5D = std::make_tuple(shape[0], shape[1], shape[2], shape[3], shape[4]); break; }
default : { IE_THROW() << "MVN layer with name '" << getName() << "' doesn't support planar layout with rank: " << shape.size(); }
}
}
@ -881,7 +994,7 @@ void MKLDNNMVNNode::transformTo5DCase(const SizeVector& shape) {
void MKLDNNMVNNode::setPostOps(mkldnn::primitive_attr &attr, bool initWeights) {
mkldnn::post_ops ops;
VectorDims postOpDims(5);
std::tie(postOpDims[0], postOpDims[1], postOpDims[2], postOpDims[3], postOpDims[4]) = shape5D;
std::tie(postOpDims[0], postOpDims[1], postOpDims[2], postOpDims[3], postOpDims[4]) = mvnAttrs.shape5D;
for (auto &node : fusedWith) {
auto* fakeQuantizeNode = dynamic_cast<MKLDNNFakeQuantizeNode *>(node.get());
if (fakeQuantizeNode) {
@ -904,27 +1017,18 @@ void MKLDNNMVNNode::executeDynamicImpl(mkldnn::stream strm) {
}
void MKLDNNMVNNode::execute(mkldnn::stream strm) {
if (!execPtr) {
IE_THROW() << "Can't execute MVN node. Primitive didn't created";
}
auto &dstMemPtr = getChildEdgeAt(0)->getMemoryPtr();
auto &srcMemPtr = getParentEdgeAt(0)->getMemoryPtr();
uint8_t *dst_data = reinterpret_cast<uint8_t*>(dstMemPtr->GetPtr());
uint8_t *src_data = reinterpret_cast<uint8_t*>(srcMemPtr->GetPtr());
if (mayiuse(cpu::x64::sse41)) {
if (!mvn_mean_kernel || (normalizeVariance_ && !mvn_variance_kernel) || !mvn_kernel) {
IE_THROW() << "MVN layer with name '" << getName() << "' doesn't create kernel to execute on sse41 above platform.";
}
if (getParentEdgeAt(0)->getMemory().getDesc().hasLayoutType(LayoutType::ncsp)) {
mvn_pln(src_data, dst_data);
} else {
mvn_blk(src_data, dst_data);
}
} else {
mvn_ref(src_data, dst_data);
}
execPtr->exec(src_data, dst_data, postOpsDataPtrs.data());
}
void MKLDNNMVNNode::mvn_pln(const uint8_t* src_data, uint8_t* dst_data) {
void MKLDNNMVNNode::MVNJitExecutor::mvn_pln(const uint8_t* src_data, uint8_t* dst_data, const void *post_ops_data_) {
size_t blk_size = 1; // blk size in vmm
if (mayiuse(cpu::x64::avx512_common)) {
blk_size = 16;
@ -935,7 +1039,7 @@ void MKLDNNMVNNode::mvn_pln(const uint8_t* src_data, uint8_t* dst_data) {
}
size_t N = 0; size_t C = 0; size_t D = 0; size_t H = 0; size_t W = 0;
std::tie(N, C, D, H, W) = shape5D;
std::tie(N, C, D, H, W) = mvnAttrs.shape5D;
size_t C1 = H * W;
size_t C2 = C1 * D;
@ -946,7 +1050,7 @@ void MKLDNNMVNNode::mvn_pln(const uint8_t* src_data, uint8_t* dst_data) {
for (size_t b = 0lu; b < N; b++) {
size_t cb = b * C3;
if (execAcrossChannels_) {
if (mvnAttrs.execAcrossChannels_) {
// Calculate mean value for one instance in batch
// Parallel sum for each channel
float C3inv = 1.f / static_cast<float>(C3);
@ -959,6 +1063,7 @@ void MKLDNNMVNNode::mvn_pln(const uint8_t* src_data, uint8_t* dst_data) {
arg.sum = static_cast<float*>(&mean_internal);
arg.src_stride = src_stride_size;
arg.work_amount = static_cast<size_t>(C2 / blk_size); // for vector part
arg.post_op_data = post_ops_data_;
(*mvn_mean_kernel)(&arg);
return mean_internal;
});
@ -967,7 +1072,7 @@ void MKLDNNMVNNode::mvn_pln(const uint8_t* src_data, uint8_t* dst_data) {
// calculate variance value for one instance in batch
// parallel sum for each channel
if (normalizeVariance_) {
if (mvnAttrs.normalizeVariance_) {
float variance_temp = 0.0f;
variance_temp = parallel_sum(C, variance_temp, [&](size_t c)->float {
float variance_internal = 0.0f;
@ -978,15 +1083,16 @@ void MKLDNNMVNNode::mvn_pln(const uint8_t* src_data, uint8_t* dst_data) {
arg.variance = static_cast<float*>(&variance_internal);
arg.src_stride = src_stride_size;
arg.work_amount = static_cast<size_t>(C2 / blk_size); // vector part
arg.post_op_data = post_ops_data_;
(*mvn_variance_kernel)(&arg);
return variance_internal;
});
float variance = 1.f;
if (epsMode_ == INSIDE_SQRT)
variance /= sqrtf(variance_temp * C3inv + epsValue_);
else if (epsMode_ == OUTSIDE_SQRT)
variance /= sqrtf(variance_temp * C3inv) + epsValue_;
if (mvnAttrs.epsMode_ == INSIDE_SQRT)
variance /= sqrtf(variance_temp * C3inv + mvnAttrs.epsValue_);
else if (mvnAttrs.epsMode_ == OUTSIDE_SQRT)
variance /= sqrtf(variance_temp * C3inv) + mvnAttrs.epsValue_;
// mvn for one instance in batch
parallel_for(C, [&](int c) {
size_t cc = cb + c * C2;
@ -999,6 +1105,7 @@ void MKLDNNMVNNode::mvn_pln(const uint8_t* src_data, uint8_t* dst_data) {
arg.dst_stride = dst_stride_size;
arg.work_amount = static_cast<size_t>(C2 / blk_size); // work amount for vector part
arg.oc_off = sizeof(float) * c;
arg.post_op_data = post_ops_data_;
(*mvn_kernel)(&arg);
});
} else {
@ -1013,6 +1120,7 @@ void MKLDNNMVNNode::mvn_pln(const uint8_t* src_data, uint8_t* dst_data) {
arg.dst_stride = dst_stride_size;
arg.work_amount = static_cast<size_t>(C2 / blk_size);
arg.oc_off = sizeof(float) * c;
arg.post_op_data = post_ops_data_;
(*mvn_kernel)(&arg);
});
}
@ -1031,21 +1139,22 @@ void MKLDNNMVNNode::mvn_pln(const uint8_t* src_data, uint8_t* dst_data) {
arg.dst_stride = dst_stride_size;
arg.work_amount = static_cast<size_t>(C2 / blk_size);
arg.oc_off = static_cast<size_t>(c * sizeof(float));
arg.post_op_data = post_ops_data_;
(*mvn_mean_kernel)(&arg);
mean *= C2inv;
if (normalizeVariance_) {
if (mvnAttrs.normalizeVariance_) {
// variance for this channel
float variance = 0.f;
arg.mean = static_cast<float*>(&mean);
arg.variance = static_cast<float*>(&variance);
(*mvn_variance_kernel)(&arg);
if (epsMode_ == INSIDE_SQRT)
variance = 1.f / sqrtf(variance * C2inv + epsValue_);
else if (epsMode_ == OUTSIDE_SQRT)
variance = 1.f / (sqrtf(variance * C2inv) + epsValue_);
if (mvnAttrs.epsMode_ == INSIDE_SQRT)
variance = 1.f / sqrtf(variance * C2inv + mvnAttrs.epsValue_);
else if (mvnAttrs.epsMode_ == OUTSIDE_SQRT)
variance = 1.f / (sqrtf(variance * C2inv) + mvnAttrs.epsValue_);
// mvn for this channel
(*mvn_kernel)(&arg);
@ -1059,11 +1168,11 @@ void MKLDNNMVNNode::mvn_pln(const uint8_t* src_data, uint8_t* dst_data) {
}
}
void MKLDNNMVNNode::mvn_ref(const uint8_t* src_data, uint8_t* dst_data) {
void MKLDNNMVNNode::MVNRefExecutor::mvn_ref(const uint8_t* src_data, uint8_t* dst_data) {
const float *src_data_ptr = reinterpret_cast<const float *>(src_data);
float *dst_data_ptr = reinterpret_cast<float *>(dst_data);
size_t N = 0; size_t C = 0; size_t D = 0; size_t H = 0; size_t W = 0;
std::tie(N, C, D, H, W) = shape5D;
std::tie(N, C, D, H, W) = mvnAttrs.shape5D;
size_t C1 = H * W;
size_t C2 = C1 * D;
@ -1071,7 +1180,7 @@ void MKLDNNMVNNode::mvn_ref(const uint8_t* src_data, uint8_t* dst_data) {
for (size_t b = 0lu; b < N; b++) {
size_t cb = b * C3;
if (execAcrossChannels_) {
if (mvnAttrs.execAcrossChannels_) {
// Parallel sum for each channel for mean
float C3inv = 1.f / static_cast<float>(C3);
float mean_temp = 0.0f;
@ -1087,7 +1196,7 @@ void MKLDNNMVNNode::mvn_ref(const uint8_t* src_data, uint8_t* dst_data) {
float mean = mean_temp * C3inv;
if (normalizeVariance_) {
if (mvnAttrs.normalizeVariance_) {
// parallel sum for each channel for variance
float variance_temp = 0.0f;
variance_temp = parallel_sum(C, variance_temp, [&](size_t c)->float {
@ -1100,10 +1209,10 @@ void MKLDNNMVNNode::mvn_ref(const uint8_t* src_data, uint8_t* dst_data) {
});
float variance = 1.f;
if (epsMode_ == INSIDE_SQRT)
variance = 1.f / sqrtf(variance_temp * C3inv + epsValue_);
else if (epsMode_ == OUTSIDE_SQRT)
variance = 1.f / (sqrtf(variance_temp * C3inv) + epsValue_);
if (mvnAttrs.epsMode_ == INSIDE_SQRT)
variance = 1.f / sqrtf(variance_temp * C3inv + mvnAttrs.epsValue_);
else if (mvnAttrs.epsMode_ == OUTSIDE_SQRT)
variance = 1.f / (sqrtf(variance_temp * C3inv) + mvnAttrs.epsValue_);
parallel_for(C, [&](int c) {
size_t cc = cb + c * C2;
@ -1130,17 +1239,17 @@ void MKLDNNMVNNode::mvn_ref(const uint8_t* src_data, uint8_t* dst_data) {
}
mean *= C2inv;
if (normalizeVariance_) {
if (mvnAttrs.normalizeVariance_) {
// variance for this channel
float variance = 0.f;
for (size_t sp = 0lu; sp < C2; sp++) {
variance += (src_data_ptr[cc + sp] - mean) * (src_data_ptr[cc + sp] - mean);
}
if (epsMode_ == INSIDE_SQRT)
variance = 1.f / sqrtf(variance * C2inv + epsValue_);
else if (epsMode_ == OUTSIDE_SQRT)
variance = 1.f / (sqrtf(variance * C2inv) + epsValue_);
if (mvnAttrs.epsMode_ == INSIDE_SQRT)
variance = 1.f / sqrtf(variance * C2inv + mvnAttrs.epsValue_);
else if (mvnAttrs.epsMode_ == OUTSIDE_SQRT)
variance = 1.f / (sqrtf(variance * C2inv) + mvnAttrs.epsValue_);
// mvn for this channel
for (size_t sp = 0lu; sp < C2; sp++) {
@ -1157,7 +1266,7 @@ void MKLDNNMVNNode::mvn_ref(const uint8_t* src_data, uint8_t* dst_data) {
}
}
void MKLDNNMVNNode::mvn_blk(const uint8_t* src_data, uint8_t* dst_data) {
void MKLDNNMVNNode::MVNJitExecutor::mvn_blk(const uint8_t* src_data, uint8_t* dst_data, const void *post_ops_data_) {
size_t blk_size = 1; // channel blk for memory layout
if (mayiuse(cpu::x64::avx512_common)) {
blk_size = 16;
@ -1166,34 +1275,32 @@ void MKLDNNMVNNode::mvn_blk(const uint8_t* src_data, uint8_t* dst_data) {
}
size_t N = 1; size_t C = 1; size_t D = 1; size_t H = 1; size_t W = 1;
std::tie(N, C, D, H, W) = shape5D;
bool is_nhwc = getParentEdgeAt(0)->getMemory().getDesc().hasLayoutType(LayoutType::nspc);
std::tie(N, C, D, H, W) = mvnAttrs.shape5D;
size_t CB = div_up(C, blk_size);
size_t C0 = is_nhwc ? W * C : W * blk_size;
size_t C0 = mvnAttrs.is_nhwc ? W * C : W * blk_size;
size_t C1 = C0 * H;
size_t C2 = C1 * D;
size_t C3 = C2 * CB;
size_t C5 = C * D * H * W;
size_t threads_num = parallel_get_num_threads();
size_t aux_buffer_size = execAcrossChannels_ ? blk_size : rnd_up(C, blk_size);
size_t aux_buffer_size = mvnAttrs.execAcrossChannels_ ? blk_size : rnd_up(C, blk_size);
std::vector<float> mean_buffer(aux_buffer_size * threads_num);
std::vector<float> variance_buffer(aux_buffer_size * threads_num);
size_t src_stride_size = is_nhwc ? static_cast<size_t>(C * src_data_size) : static_cast<size_t>(blk_size * src_data_size);
size_t dst_stride_size = is_nhwc ? static_cast<size_t>(C * dst_data_size) : static_cast<size_t>(blk_size * dst_data_size);
size_t src_stride_size = mvnAttrs.is_nhwc ? static_cast<size_t>(C * src_data_size) : static_cast<size_t>(blk_size * src_data_size);
size_t dst_stride_size = mvnAttrs.is_nhwc ? static_cast<size_t>(C * dst_data_size) : static_cast<size_t>(blk_size * dst_data_size);
for (size_t b = 0lu; b < N; b++) {
size_t b_offset = is_nhwc ? b * C5 : b * C3;
if (execAcrossChannels_) {
size_t b_offset = mvnAttrs.is_nhwc ? b * C5 : b * C3;
if (mvnAttrs.execAcrossChannels_) {
// mean for this instance in batch
float C5inv = 1.f / static_cast<float>(C5);
float mean_temp = 0.0f;
mean_temp = parallel_sum3d(CB, D, H, mean_temp, [&](size_t cb, size_t d, size_t h)->float {
size_t src_offset = is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
size_t src_offset = mvnAttrs.is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
: b_offset + cb * C2 + d * C1 + h * C0;
float mean_internal = 0.0f;
@ -1225,11 +1332,11 @@ void MKLDNNMVNNode::mvn_blk(const uint8_t* src_data, uint8_t* dst_data) {
});
float mean = mean_temp * C5inv;
if (normalizeVariance_) {
if (mvnAttrs.normalizeVariance_) {
// variance: sum((x-mean)*(x-mean)) for one instance in batch
float variance_temp = 0.0f;
variance_temp = parallel_sum3d(CB, D, H, variance_temp, [&](size_t cb, size_t d, size_t h)->float {
size_t src_offset = is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
size_t src_offset = mvnAttrs.is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
: b_offset + cb * C2 + d * C1 + h * C0;
float variance_internal = 0.0f;
@ -1244,6 +1351,7 @@ void MKLDNNMVNNode::mvn_blk(const uint8_t* src_data, uint8_t* dst_data) {
arg.src_stride = src_stride_size;
arg.work_amount = static_cast<size_t>(W);
arg.oc_off = cb * blk_size * sizeof(float);
arg.post_op_data = post_ops_data_;
(*mvn_variance_kernel)(&arg);
size_t min_cb = (std::min)(blk_size, C - cb * blk_size);
@ -1253,13 +1361,13 @@ void MKLDNNMVNNode::mvn_blk(const uint8_t* src_data, uint8_t* dst_data) {
});
float variance = 1.f;
if (epsMode_ == INSIDE_SQRT)
variance /= sqrtf(variance_temp * C5inv + epsValue_);
else if (epsMode_ == OUTSIDE_SQRT)
variance /= sqrtf(variance_temp * C5inv) + epsValue_;
if (mvnAttrs.epsMode_ == INSIDE_SQRT)
variance /= sqrtf(variance_temp * C5inv + mvnAttrs.epsValue_);
else if (mvnAttrs.epsMode_ == OUTSIDE_SQRT)
variance /= sqrtf(variance_temp * C5inv) + mvnAttrs.epsValue_;
// mvn for one instance in batch
parallel_for3d(CB, D, H, [&](size_t cb, size_t d, size_t h) {
size_t src_offset = is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
size_t src_offset = mvnAttrs.is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
: b_offset + cb * C2 + d * C1 + h * C0;
auto arg = jit_mvn_call_args();
arg.src = src_data + src_offset * src_data_size;
@ -1270,12 +1378,13 @@ void MKLDNNMVNNode::mvn_blk(const uint8_t* src_data, uint8_t* dst_data) {
arg.dst_stride = dst_stride_size;
arg.work_amount = static_cast<size_t>(W);
arg.oc_off = cb * blk_size * sizeof(float);
arg.post_op_data = post_ops_data_;
(*mvn_kernel)(&arg);
});
} else {
// mvn for one instance in batch
parallel_for3d(CB, D, H, [&](size_t cb, size_t d, size_t h) {
size_t src_offset = is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
size_t src_offset = mvnAttrs.is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
: b_offset + cb * C2 + d * C1 + h * C0;
auto arg = jit_mvn_call_args();
arg.src = src_data + src_offset * src_data_size;
@ -1285,6 +1394,7 @@ void MKLDNNMVNNode::mvn_blk(const uint8_t* src_data, uint8_t* dst_data) {
arg.dst_stride = dst_stride_size;
arg.work_amount = static_cast<size_t>(W);
arg.oc_off = cb * blk_size * sizeof(float);
arg.post_op_data = post_ops_data_;
(*mvn_kernel)(&arg);
});
}
@ -1297,7 +1407,7 @@ void MKLDNNMVNNode::mvn_blk(const uint8_t* src_data, uint8_t* dst_data) {
// keep the compute order the same as planar
parallel_for2d(D, H, [&](size_t thr_idx, size_t d, size_t h) {
for (size_t cb = 0; cb < CB; cb++) {
size_t src_offset = is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
size_t src_offset = mvnAttrs.is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
: b_offset + cb * C2 + d * C1 + h * C0;
auto mean_buffer_ptr = &mean_buffer[blk_size * cb + aux_buffer_size * thr_idx];
@ -1307,6 +1417,7 @@ void MKLDNNMVNNode::mvn_blk(const uint8_t* src_data, uint8_t* dst_data) {
arg.src_stride = src_stride_size;
arg.work_amount = static_cast<size_t>(W);
arg.oc_off = cb * blk_size * sizeof(float);
arg.post_op_data = post_ops_data_;
(*mvn_mean_kernel)(&arg);
}
});
@ -1318,13 +1429,13 @@ void MKLDNNMVNNode::mvn_blk(const uint8_t* src_data, uint8_t* dst_data) {
for (size_t c = 0; c < C; c++)
mean_buffer[c] *= size_inv;
if (normalizeVariance_) {
if (mvnAttrs.normalizeVariance_) {
for (int i = 0; i < variance_buffer.size(); i++)
variance_buffer[i] = 0.f;
parallel_for2d(D, H, [&](size_t thr_idx, size_t d, size_t h) {
for (size_t cb = 0; cb < CB; cb++) {
size_t src_offset = is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
size_t src_offset = mvnAttrs.is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
: b_offset + cb * C2 + d * C1 + h * C0;
auto mean_buffer_ptr = &mean_buffer[blk_size * cb];
auto variance_buffer_ptr = &variance_buffer[blk_size * cb + aux_buffer_size * thr_idx];
@ -1336,6 +1447,7 @@ void MKLDNNMVNNode::mvn_blk(const uint8_t* src_data, uint8_t* dst_data) {
arg.src_stride = src_stride_size;
arg.work_amount = static_cast<size_t>(W);
arg.oc_off = cb * blk_size * sizeof(float);
arg.post_op_data = post_ops_data_;
(*mvn_variance_kernel)(&arg);
}
});
@ -1344,15 +1456,15 @@ void MKLDNNMVNNode::mvn_blk(const uint8_t* src_data, uint8_t* dst_data) {
variance_buffer[c] += variance_buffer[c + aux_buffer_size * i];
}
for (size_t c = 0; c < C; c++) {
if (epsMode_ == INSIDE_SQRT)
variance_buffer[c] = 1.f / sqrtf(variance_buffer[c] * size_inv + epsValue_);
else if (epsMode_ == OUTSIDE_SQRT)
variance_buffer[c] = 1.f / (sqrtf(variance_buffer[c] * size_inv) + epsValue_);
if (mvnAttrs.epsMode_ == INSIDE_SQRT)
variance_buffer[c] = 1.f / sqrtf(variance_buffer[c] * size_inv + mvnAttrs.epsValue_);
else if (mvnAttrs.epsMode_ == OUTSIDE_SQRT)
variance_buffer[c] = 1.f / (sqrtf(variance_buffer[c] * size_inv) + mvnAttrs.epsValue_);
}
parallel_for2d(D, H, [&](size_t d, size_t h) {
for (size_t cb = 0; cb < CB; cb++) {
size_t src_offset = is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
size_t src_offset = mvnAttrs.is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
: b_offset + cb * C2 + d * C1 + h * C0;
auto mean_buffer_ptr = &mean_buffer[blk_size * cb];
auto variance_buffer_ptr = &variance_buffer[blk_size * cb];
@ -1366,6 +1478,7 @@ void MKLDNNMVNNode::mvn_blk(const uint8_t* src_data, uint8_t* dst_data) {
arg.dst_stride = dst_stride_size;
arg.work_amount = static_cast<size_t>(W);
arg.oc_off = cb * blk_size * sizeof(float);
arg.post_op_data = post_ops_data_;
(*mvn_kernel)(&arg);
}
});
@ -1373,7 +1486,7 @@ void MKLDNNMVNNode::mvn_blk(const uint8_t* src_data, uint8_t* dst_data) {
// normalizeVariance_ == false
parallel_for2d(D, H, [&](size_t d, size_t h) {
for (size_t cb = 0; cb < CB; cb++) {
size_t src_offset = is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
size_t src_offset = mvnAttrs.is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size
: b_offset + cb * C2 + d * C1 + h * C0;
auto mean_buffer_ptr = &mean_buffer[blk_size * cb];
@ -1385,6 +1498,7 @@ void MKLDNNMVNNode::mvn_blk(const uint8_t* src_data, uint8_t* dst_data) {
arg.dst_stride = dst_stride_size;
arg.work_amount = static_cast<size_t>(W);
arg.oc_off = cb * blk_size * sizeof(float);
arg.post_op_data = post_ops_data_;
(*mvn_kernel)(&arg);
}
});
@ -1404,7 +1518,7 @@ bool MKLDNNMVNNode::canFuse(const MKLDNNNodePtr& node) const {
EltwiseSwish, EltwiseHswish, EltwiseMish, EltwiseHsigmoid, EltwiseRoundHalfToEven,
EltwiseRoundHalfAwayFromZero, EltwiseAbs, EltwiseSqrt, EltwiseSoftRelu);
if ((inputRank == 1 && !unaryEltwise) ||
(inputRank == 2 && !unaryEltwise && initAcrossChannels_)) {
(inputRank == 2 && !unaryEltwise && mvnAttrs.initAcrossChannels_)) {
return false;
}

View File

@ -35,6 +35,7 @@ struct jit_mvn_call_args {
size_t dst_stride;
size_t work_amount;
size_t oc_off;
const void* post_op_data;
};
struct jit_uni_mvn_mean_variance_kernel {
@ -85,50 +86,82 @@ public:
}
// Returns the across-channels setting the node was configured with
// (the value requested by the original operation, kept in MVNAttrs).
inline bool getAcrossChannels() const {
    return mvnAttrs.initAcrossChannels_;
}
// Returns whether this MVN node normalizes variance in addition to
// subtracting the mean (kept in MVNAttrs after the attrs refactor).
inline bool getNormalizeVariance() const {
    return mvnAttrs.normalizeVariance_;
}
bool canFuse(const MKLDNNNodePtr& node) const override;
void prepareParams() override;
private:
void mvn_pln(const uint8_t *src_data, uint8_t *dst_data);
void mvn_blk(const uint8_t *src_data, uint8_t *dst_data);
void mvn_ref(const uint8_t *src_data, uint8_t *dst_data);
void setPostOps(mkldnn::primitive_attr &attr, bool initWeights = false);
void transformTo5DCase(const InferenceEngine::SizeVector& shape);
std::tuple<size_t, size_t, size_t, size_t, size_t> shape5D;
bool initAcrossChannels_ = false;
bool execAcrossChannels_ = false;
bool normalizeVariance_ = true;
float epsValue_ = 1e-9f;
// Defines way to add epsilon: inside sqrt or outside.
enum MVNEpsMode {
INSIDE_SQRT,
OUTSIDE_SQRT
};
MVNEpsMode epsMode_;
// Bundle of every parameter that defines an MVN execution. Instances are
// hashed and compared inside MVNKey (see mvn.cpp) to look up cached executors.
struct MVNAttrs {
// true when the tensor uses a plain/planar memory layout
bool planar_layout;
// shape normalized to five dimensions — presumably (N, C, D, H, W),
// produced by transformTo5DCase; TODO confirm the axis order there
std::tuple<size_t, size_t, size_t, size_t, size_t> shape5D;
// across-channels flag as requested by the original operation
bool initAcrossChannels_;
// across-channels flag actually applied at execution time
// (may differ from initAcrossChannels_ — see node setup logic)
bool execAcrossChannels_;
// when true, divide by the (eps-adjusted) standard deviation as well
bool normalizeVariance_;
// epsilon added for numeric stability of the variance term
float epsValue_;
// whether epsilon is applied inside or outside the sqrt (INSIDE_SQRT/OUTSIDE_SQRT)
MVNEpsMode epsMode_;
// true for the channels-last (nhwc/ndhwc) layout branch in mvn_blk
bool is_nhwc;
// source/destination element precisions (hashed into the cache key)
InferenceEngine::Precision src_prc;
InferenceEngine::Precision dst_prc;
};
InferenceEngine::Precision input_prec, output_prec;
size_t src_data_size = 0;
size_t dst_data_size = 0;
private:
void setPostOps(mkldnn::primitive_attr &attr, bool initWeights = false);
mkldnn::primitive_attr attr;
void transformTo5DCase(const InferenceEngine::SizeVector& shape);
std::shared_ptr<jit_uni_mvn_mean_variance_kernel> mvn_mean_kernel;
std::shared_ptr<jit_uni_mvn_mean_variance_kernel> mvn_variance_kernel;
std::shared_ptr<jit_uni_mvn_kernel> mvn_kernel;
std::vector<const void*> postOpsDataPtrs;
MVNAttrs mvnAttrs;
// Base class for cached MVN executors: stores the resolved attributes and
// the per-element byte sizes shared by the JIT and reference implementations.
class MVNExecutor {
public:
MVNExecutor(const MVNAttrs& mvnAttrs);
// Runs MVN reading from in_ptr_ and writing to out_ptr_; post_ops_data_
// carries pointers to fused post-op arguments forwarded to the kernels.
virtual void exec(const uint8_t *in_ptr_, uint8_t *out_ptr_, const void *post_ops_data_) = 0;
virtual ~MVNExecutor() = default;
protected:
// configuration this executor was built for
MVNAttrs mvnAttrs;
// byte size of one source / destination element (derived from src_prc/dst_prc)
size_t src_data_size = 0;
size_t dst_data_size = 0;
};
std::shared_ptr<MVNExecutor> execPtr = nullptr;
// JIT-backed executor: dispatches between the planar path (mvn_pln) and the
// blocked / channels-last path (mvn_blk), driving the compiled kernels.
class MVNJitExecutor : public MVNExecutor {
public:
MVNJitExecutor(const MVNAttrs& mvnAttrs,
const mkldnn::primitive_attr &attr);
void exec(const uint8_t *in_ptr_, uint8_t *out_ptr_, const void *post_ops_data_) override;
private:
// planar-layout implementation
void mvn_pln(const uint8_t *in_ptr_, uint8_t *out_ptr_, const void *post_ops_data_);
// blocked / nhwc-layout implementation (branches on mvnAttrs.is_nhwc)
void mvn_blk(const uint8_t *in_ptr_, uint8_t *out_ptr_, const void *post_ops_data_);
// JIT kernels: mean accumulation, variance accumulation, and final normalization
std::shared_ptr<jit_uni_mvn_mean_variance_kernel> mvn_mean_kernel;
std::shared_ptr<jit_uni_mvn_mean_variance_kernel> mvn_variance_kernel;
std::shared_ptr<jit_uni_mvn_kernel> mvn_kernel;
};
// Plain C++ executor — presumably the fallback when no JIT kernel can be
// created; TODO confirm the selection logic in mvn.cpp.
class MVNRefExecutor : public MVNExecutor {
public:
MVNRefExecutor(const MVNAttrs& mvnAttrs);
void exec(const uint8_t *in_ptr_, uint8_t *out_ptr_, const void *post_ops_data_) override;
private:
// reference implementation; note it takes only src/dst, so exec's
// post_ops_data_ is not consumed on this path
void mvn_ref(const uint8_t *in_ptr_, uint8_t *out_ptr_);
};
};
} // namespace MKLDNNPlugin

View File

@ -1,406 +1,504 @@
//// Copyright (C) 2018-2022 Intel Corporation
//// SPDX-License-Identifier: Apache-2.0
////
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
//#include <shared_test_classes/single_layer/mvn.hpp>
//#include "ngraph_functions/builders.hpp"
//#include "test_utils/cpu_test_utils.hpp"
//#include "test_utils/fusing_test_utils.hpp"
//
//using namespace InferenceEngine;
//using namespace CPUTestUtils;
//
//namespace CPULayerTestsDefinitions {
//
//using basicCpuMvnParams = std::tuple<
// std::pair<std::vector<ngraph::PartialShape>, std::vector<ngraph::Shape>>, // Input shapes
// InferenceEngine::Precision, // Input precision
// ngraph::AxisSet, // Reduction axes
// bool, // Across channels
// bool, // Normalize variance
// double>; // Epsilon
//
//typedef std::tuple<
// basicCpuMvnParams,
// CPUSpecificParams,
// fusingSpecificParams,
// Precision, // CNNNetwork input precision
// Precision> // CNNNetwork output precision
//MvnLayerCPUTestParamSet;
//
//class MvnLayerCPUTest : public testing::WithParamInterface<MvnLayerCPUTestParamSet>,
// virtual public LayerTestsUtils::LayerTestsCommon, public CpuTestWithFusing {
//public:
// static std::string getTestCaseName(testing::TestParamInfo<MvnLayerCPUTestParamSet> obj) {
// basicCpuMvnParams basicParamsSet;
// CPUSpecificParams cpuParams;
// fusingSpecificParams fusingParams;
// Precision inputPrecision, outputPrecision;
// std::tie(basicParamsSet, cpuParams, fusingParams, inputPrecision, outputPrecision) = obj.param;
//
// std::pair<std::vector<ngraph::PartialShape>, std::vector<ngraph::Shape>> inputShapes;
// InferenceEngine::Precision netPrecision;
// ngraph::AxisSet axes;
// bool acrossChanels, normalizeVariance;
// double eps;
// std::tie(inputShapes, netPrecision, axes, acrossChanels, normalizeVariance, eps) = basicParamsSet;
//
// std::ostringstream result;
// if (!inputShapes.first.empty()) {
// result << "IS=" << CommonTestUtils::partialShape2str(inputShapes.first) << "_";
// }
// result << "TS=";
// for (const auto& shape : inputShapes.second) {
// result << "(" << CommonTestUtils::vec2str(shape) << ")_";
// }
// result << "Precision=" << netPrecision.name() << "_";
// if (!axes.empty()) {
// result << "ReductionAccess=" << CommonTestUtils::vec2str(axes.to_vector()) << "_";
// } else {
// result << "AcrossChannels=" << (acrossChanels ? "TRUE" : "FALSE") << "_";
// }
// result << "NormalizeVariance=" << (normalizeVariance ? "TRUE" : "FALSE") << "_";
// result << "Epsilon=" << eps;
// result << "_" << "CNNInpPrc=" << inputPrecision.name();
// result << "_" << "CNNOutPrc=" << outputPrecision.name();
//
// result << CPUTestsBase::getTestCaseName(cpuParams);
//
// result << CpuTestWithFusing::getTestCaseName(fusingParams);
//
// return result.str();
// }
//protected:
// void SetUp() override {
// targetDevice = CommonTestUtils::DEVICE_CPU;
//
// basicCpuMvnParams basicParamsSet;
// CPUSpecificParams cpuParams;
// fusingSpecificParams fusingParams;
// std::tie(basicParamsSet, cpuParams, fusingParams, inPrc, outPrc) = this->GetParam();
//
// std::tie(inFmts, outFmts, priority, selectedType) = cpuParams;
// std::tie(postOpMgrPtr, fusedOps) = fusingParams;
//
// std::pair<std::vector<ngraph::PartialShape>, std::vector<ngraph::Shape>> inputShapes;
// InferenceEngine::Precision netPrecision;
// ngraph::AxisSet axes;
// bool acrossChanels, normalizeVariance;
// double eps;
// std::tie(inputShapes, netPrecision, axes, acrossChanels, normalizeVariance, eps) = basicParamsSet;
//
// for (size_t i = 0; i < inputShapes.second.size(); i++) {
// targetStaticShapes.push_back({inputShapes.second[i]});
// }
// inputDynamicShapes = inputShapes.first;
//
// auto netPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
// auto param = ngraph::builder::makeParams(netPrc, {targetStaticShapes[0].front()});
// auto paramOuts = ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes<ngraph::op::Parameter>(param));
// auto mvn = ngraph::builder::makeMVN(paramOuts[0], acrossChanels, normalizeVariance, eps);
// if (!axes.empty()) {
// mvn = ngraph::builder::makeMVN(paramOuts[0], axes, normalizeVariance, eps);
// }
//
// selectedType = getPrimitiveType() + "_" + inPrc.name();
//
// threshold = 0.015f;
// function = makeNgraphFunction(netPrc, param, mvn, "mvn");
// }
//};
//
//TEST_P(MvnLayerCPUTest, CompareWithRefs) {
// SKIP_IF_CURRENT_TEST_IS_DISABLED()
//
// Run();
// CheckPluginRelatedResults(executableNetwork, "MVN");
//}
//
//namespace {
//
//const std::vector<std::pair<std::vector<ngraph::PartialShape>, std::vector<ngraph::Shape>>> inputShapes_1D = {
// { {}, {{5}}},
// { {}, {{16}}},
// {
// // dynamic
// {{-1}},
// // target
// {
// {2},
// {16},
// {1}
// }
// },
// {
// // dynamic
// {{{1, 20}}},
// // target
// {
// {1},
// {16},
// {4}
// }
// }
//};
//
//const std::vector<std::pair<std::vector<ngraph::PartialShape>, std::vector<ngraph::Shape>>> inputShapes_2D = {
// { {}, {{1, 32}}},
// { {}, {{16, 64}}},
// {
// // dynamic
// {{-1, -1}},
// // target
// {
// {2, 16},
// {4, 16},
// {1, 16}
// }
// },
// {
// // dynamic
// {{{1, 5}, {1, 20}}},
// // target
// {
// {1, 1},
// {2, 16},
// {4, 16}
// }
// }
//};
//
//const std::vector<std::pair<std::vector<ngraph::PartialShape>, std::vector<ngraph::Shape>>> inputShapes_3D = {
// { {}, {{1, 32, 17}}},
// { {}, {{1, 37, 9}}},
// { {}, {{1, 16, 4}}},
// {
// // dynamic
// {{-1, -1, -1}},
// // target
// {
// {2, 16, 6},
// {4, 16, 2},
// {1, 16, 4}
// }
// },
// {
// // dynamic
// {{{1, 5}, {1, 20}, {1, 7}}},
// // target
// {
// {1, 1, 1},
// {2, 16, 6},
// {4, 16, 2}
// }
// }
//};
//
//const std::vector<std::pair<std::vector<ngraph::PartialShape>, std::vector<ngraph::Shape>>> inputShapes_4D = {
// { {}, {{1, 16, 5, 8}}},
// { {}, {{2, 19, 5, 10}}},
// { {}, {{7, 32, 2, 8}}},
// { {}, {{5, 8, 3, 5}}},
// { {}, {{1, 2, 7, 5}}},
// { {}, {{1, 4, 5, 5}}},
// { {}, {{1, 7, 3, 5}}},
// { {}, {{1, 15, 9, 5}}},
// { {}, {{4, 41, 6, 9}}},
// {
// // dynamic
// {{-1, -1, -1, -1}},
// // target
// {
// {2, 16, 10, 6},
// {4, 16, 2, 2},
// {1, 16, 8, 4}
// }
// },
// {
// // dynamic
// {{{1, 5}, {1, 20}, {1, 10}, {1, 7}}},
// // target
// {
// {1, 1, 1, 1},
// {2, 16, 10, 6},
// {4, 16, 2, 2}
// }
// }
//};
//
//const std::vector<std::pair<std::vector<ngraph::PartialShape>, std::vector<ngraph::Shape>>> inputShapes_5D = {
// { {}, {{1, 32, 8, 1, 6}}},
// { {}, {{1, 9, 1, 15, 9}}},
// { {}, {{6, 64, 6, 1, 18}}},
// { {}, {{2, 31, 2, 9, 1}}},
// { {}, {{10, 16, 5, 10, 6}}},
// {
// // dynamic
// {{-1, -1, -1, -1, -1}},
// // target
// {
// {2, 16, 5, 10, 6},
// {4, 16, 7, 2, 2},
// {1, 16, 11, 8, 4}
// }
// },
// {
// // dynamic
// {{{1, 5}, {1, 20}, {1, 7}, {1, 10}, {1, 7}}},
// // target
// {
// {1, 1, 1, 1, 1},
// {2, 16, 5, 10, 6},
// {4, 16, 7, 2, 2}
// }
// }
//};
//
//const std::vector<bool> acrossChannels = {
// true,
// false
//};
//
//const std::vector<bool> normalizeVariance = {
// true,
// false
//};
//
//const std::vector<double> epsilon = {
// 0.000000001
//};
//
//const std::vector<ngraph::AxisSet> emptyReductionAxes = {{}};
//
//std::vector<Precision> inpPrc = {Precision::I8, Precision::BF16, Precision::FP32};
//std::vector<Precision> outPrc = {Precision::BF16, Precision::FP32};
//
//std::vector<CPUSpecificParams> cpuParams_4D = {
// CPUSpecificParams({nhwc}, {nhwc}, {}, {}),
// CPUSpecificParams({nChw16c}, {nChw16c}, {}, {}),
// CPUSpecificParams({nchw}, {nchw}, {}, {})
//};
//
//std::vector<CPUSpecificParams> cpuParams_5D = {
// CPUSpecificParams({ndhwc}, {ndhwc}, {}, {}),
// CPUSpecificParams({nCdhw16c}, {nCdhw16c}, {}, {}),
// CPUSpecificParams({ncdhw}, {ncdhw}, {}, {})
//};
//
//std::vector<fusingSpecificParams> fusingParamsSet {
// emptyFusingSpec,
// /* activations */
// fusingRelu,
// fusingElu,
// fusingTanh,
// fusingSwish,
// /* FQ */
// fusingFakeQuantizePerChannel,
// fusingFakeQuantizePerChannelRelu,
// fusingFakeQuantizePerTensorRelu,
// /* another patterns */
// fusingScaleShift,
// fusingAddPerTensor
//};
//
//const auto Mvn3D = ::testing::Combine(
// ::testing::Combine(
// ::testing::ValuesIn(inputShapes_3D),
// ::testing::Values(InferenceEngine::Precision::FP32),
// ::testing::ValuesIn(emptyReductionAxes),
// ::testing::ValuesIn(acrossChannels),
// ::testing::ValuesIn(normalizeVariance),
// ::testing::ValuesIn(epsilon)),
// ::testing::Values(emptyCPUSpec),
// ::testing::ValuesIn(fusingParamsSet),
// ::testing::ValuesIn(inpPrc),
// ::testing::ValuesIn(outPrc));
//
//INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Mvn3D, MvnLayerCPUTest, Mvn3D, MvnLayerCPUTest::getTestCaseName);
//
//const auto Mvn4D = ::testing::Combine(
// ::testing::Combine(
// ::testing::ValuesIn(inputShapes_4D),
// ::testing::Values(InferenceEngine::Precision::FP32),
// ::testing::ValuesIn(emptyReductionAxes),
// ::testing::ValuesIn(acrossChannels),
// ::testing::ValuesIn(normalizeVariance),
// ::testing::ValuesIn(epsilon)),
// ::testing::ValuesIn(filterCPUSpecificParams(cpuParams_4D)),
// ::testing::ValuesIn(fusingParamsSet),
// ::testing::ValuesIn(inpPrc),
// ::testing::ValuesIn(outPrc));
//
//INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Mvn4D, MvnLayerCPUTest, Mvn4D, MvnLayerCPUTest::getTestCaseName);
//
//const auto Mvn5D = ::testing::Combine(
// ::testing::Combine(
// ::testing::ValuesIn(inputShapes_5D),
// ::testing::Values(InferenceEngine::Precision::FP32),
// ::testing::ValuesIn(emptyReductionAxes),
// ::testing::ValuesIn(acrossChannels),
// ::testing::ValuesIn(normalizeVariance),
// ::testing::ValuesIn(epsilon)),
// ::testing::ValuesIn(filterCPUSpecificParams(cpuParams_5D)),
// ::testing::ValuesIn(fusingParamsSet),
// ::testing::ValuesIn(inpPrc),
// ::testing::ValuesIn(outPrc));
//
//INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Mvn5D, MvnLayerCPUTest, Mvn5D, MvnLayerCPUTest::getTestCaseName);
//
//// 1D 2D case
//std::vector<fusingSpecificParams> fusingUnaryEltwiseParamsSet {
// /* activations */
// fusingRelu,
// fusingElu,
// fusingTanh,
// fusingSwish,
//};
//
//const auto Mvn1D = ::testing::Combine(
// ::testing::Combine(
// ::testing::ValuesIn(inputShapes_1D),
// ::testing::Values(InferenceEngine::Precision::FP32),
// ::testing::ValuesIn(emptyReductionAxes),
// ::testing::ValuesIn(acrossChannels),
// ::testing::ValuesIn(normalizeVariance),
// ::testing::ValuesIn(epsilon)),
// ::testing::Values(emptyCPUSpec),
// ::testing::ValuesIn(fusingUnaryEltwiseParamsSet),
// ::testing::ValuesIn(inpPrc),
// ::testing::ValuesIn(outPrc));
//
//INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Mvn1D, MvnLayerCPUTest, Mvn1D, MvnLayerCPUTest::getTestCaseName);
//
//// 2D no transformed
//const auto Mvn2D = ::testing::Combine(
// ::testing::Combine(
// ::testing::ValuesIn(inputShapes_2D),
// ::testing::Values(InferenceEngine::Precision::FP32),
// ::testing::ValuesIn(emptyReductionAxes),
// ::testing::Values(false),
// ::testing::ValuesIn(normalizeVariance),
// ::testing::ValuesIn(epsilon)),
// ::testing::Values(emptyCPUSpec),
// ::testing::ValuesIn(fusingParamsSet),
// ::testing::ValuesIn(inpPrc),
// ::testing::ValuesIn(outPrc));
//
//INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Mvn2D, MvnLayerCPUTest, Mvn2D, MvnLayerCPUTest::getTestCaseName);
//
//// 2d transformed
//const auto Mvn2DTrans = ::testing::Combine(
// ::testing::Combine(
// ::testing::ValuesIn(inputShapes_2D),
// ::testing::Values(InferenceEngine::Precision::FP32),
// ::testing::ValuesIn(emptyReductionAxes),
// ::testing::Values(true),
// ::testing::ValuesIn(normalizeVariance),
// ::testing::ValuesIn(epsilon)),
// ::testing::Values(emptyCPUSpec),
// ::testing::ValuesIn(fusingUnaryEltwiseParamsSet),
// ::testing::ValuesIn(inpPrc),
// ::testing::ValuesIn(outPrc));
//
//INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_MVN2DTrans, MvnLayerCPUTest, Mvn2DTrans, MvnLayerCPUTest::getTestCaseName);
//
//} // namespace
//} // namespace CPULayerTestsDefinitions
#include <shared_test_classes/single_layer/mvn.hpp>
#include "ngraph_functions/builders.hpp"
#include "test_utils/cpu_test_utils.hpp"
#include "test_utils/fusing_test_utils.hpp"
#include "shared_test_classes/base/ov_subgraph.hpp"
using namespace InferenceEngine;
using namespace CPUTestUtils;
using namespace ov::test;
namespace CPULayerTestsDefinitions {
using basicCpuMvnParams = std::tuple<
InputShape, // Input shapes
ElementType, // Input precision
ngraph::AxisSet, // Reduction axes
bool, // Across channels
bool, // Normalize variance
double>; // Epsilon
using MvnLayerCPUTestParamSet = std::tuple<
basicCpuMvnParams,
CPUSpecificParams,
fusingSpecificParams,
ElementType, // CNNNetwork input precision
ElementType>; // CNNNetwork output precision
// Parameterized CPU test for the MVN layer: builds an ngraph function with an
// MVN node (optionally with fused post-ops) and compares it against the
// reference implementation over static and dynamic shapes.
class MvnLayerCPUTest : public testing::WithParamInterface<MvnLayerCPUTestParamSet>,
                        virtual public SubgraphBaseTest, public CpuTestWithFusing {
public:
    // Renders a human-readable test name from the parameter tuple.
    static std::string getTestCaseName(testing::TestParamInfo<MvnLayerCPUTestParamSet> obj) {
        basicCpuMvnParams mvnParams;
        CPUSpecificParams cpuSpecific;
        fusingSpecificParams fusing;
        ElementType inPrecision, outPrecision;
        std::tie(mvnParams, cpuSpecific, fusing, inPrecision, outPrecision) = obj.param;

        InputShape shapes;
        ElementType netPrc;
        ngraph::AxisSet reductionAxes;
        bool across, normVariance;
        double epsilonValue;
        std::tie(shapes, netPrc, reductionAxes, across, normVariance, epsilonValue) = mvnParams;

        std::ostringstream name;
        name << "IS=" << CommonTestUtils::partialShape2str({shapes.first}) << "_";
        name << "TS=";
        for (const auto& staticShape : shapes.second) {
            name << "(" << CommonTestUtils::vec2str(staticShape) << ")_";
        }
        name << "Precision=" << netPrc << "_";
        // Explicit reduction axes take precedence over the across-channels flag.
        if (reductionAxes.empty()) {
            name << "AcrossChannels=" << (across ? "TRUE" : "FALSE") << "_";
        } else {
            name << "ReductionAccess=" << CommonTestUtils::vec2str(reductionAxes.to_vector()) << "_";
        }
        name << "NormalizeVariance=" << (normVariance ? "TRUE" : "FALSE") << "_";
        name << "Epsilon=" << epsilonValue;
        name << "_" << "CNNInpPrc=" << inPrecision;
        name << "_" << "CNNOutPrc=" << outPrecision;
        name << CPUTestsBase::getTestCaseName(cpuSpecific);
        name << CpuTestWithFusing::getTestCaseName(fusing);
        return name.str();
    }

protected:
    // Unpacks the parameters and constructs the MVN subgraph under test.
    void SetUp() override {
        targetDevice = CommonTestUtils::DEVICE_CPU;

        basicCpuMvnParams mvnParams;
        CPUSpecificParams cpuSpecific;
        fusingSpecificParams fusing;
        ElementType inPrecision, outPrecision;
        std::tie(mvnParams, cpuSpecific, fusing, inPrecision, outPrecision) = this->GetParam();
        std::tie(inFmts, outFmts, priority, selectedType) = cpuSpecific;
        std::tie(postOpMgrPtr, fusedOps) = fusing;

        InputShape shapes;
        ElementType netPrc;
        ngraph::AxisSet reductionAxes;
        bool across, normVariance;
        double epsilonValue;
        std::tie(shapes, netPrc, reductionAxes, across, normVariance, epsilonValue) = mvnParams;

        init_input_shapes({shapes});

        auto params = ngraph::builder::makeDynamicParams(netPrc, inputDynamicShapes);
        auto paramOuts = ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes<ngraph::op::Parameter>(params));
        // Flag-driven MVN by default; rebuilt with explicit axes when provided.
        auto mvn = ngraph::builder::makeMVN(paramOuts[0], across, normVariance, epsilonValue);
        if (!reductionAxes.empty()) {
            mvn = ngraph::builder::makeMVN(paramOuts[0], reductionAxes, normVariance, epsilonValue);
        }

        selectedType = makeSelectedTypeStr(getPrimitiveType(), netPrc);
        rel_threshold = 0.015f;
        function = makeNgraphFunction(netPrc, params, mvn, "mvn");
    }
};
// Runs the parameterized MVN subgraph, compares against the reference
// implementation, then checks that the expected CPU primitive was selected.
TEST_P(MvnLayerCPUTest, CompareWithRefs) {
SKIP_IF_CURRENT_TEST_IS_DISABLED()
run();
CheckPluginRelatedResults(executableNetwork, "MVN");
}
namespace {
// 1D inputs: static shapes plus dynamic (unbounded and bounded) shapes
// with their per-iteration target static shapes.
const std::vector<InputShape> inputShapes_1D = {
{ {}, {{5}}},
{ {}, {{16}}},
{
// dynamic
{-1},
// target
{
{2},
{16},
{1},
{2}
}
},
{
// dynamic
{{1, 20}},
// target
{
{1},
{16},
{4},
{16}
}
}
};
// 2D inputs.
const std::vector<InputShape> inputShapes_2D = {
{ {}, {{1, 32}}},
{ {}, {{16, 64}}},
{
// dynamic
{-1, -1},
// target
{
{2, 16},
{4, 16},
{1, 16},
{4, 16}
}
},
{
// dynamic
{{1, 5}, {1, 20}},
// target
{
{1, 1},
{2, 16},
{4, 16},
{2, 16}
}
}
};
// 3D inputs.
const std::vector<InputShape> inputShapes_3D = {
{ {}, {{1, 32, 17}}},
{ {}, {{1, 37, 9}}},
{ {}, {{1, 16, 4}}},
{
// dynamic
{-1, -1, -1},
// target
{
{2, 16, 6},
{4, 16, 2},
{2, 16, 6},
{4, 16, 2}
}
},
{
// dynamic
{{1, 5}, {1, 20}, {1, 7}},
// target
{
{1, 1, 1},
{2, 16, 6},
{4, 16, 2},
{2, 16, 6}
}
}
};
// 4D inputs.
const std::vector<InputShape> inputShapes_4D = {
{ {}, {{1, 16, 5, 8}}},
{ {}, {{2, 19, 5, 10}}},
{ {}, {{7, 32, 2, 8}}},
{ {}, {{5, 8, 3, 5}}},
{ {}, {{1, 2, 7, 5}}},
{ {}, {{1, 4, 5, 5}}},
{ {}, {{1, 7, 3, 5}}},
{ {}, {{1, 15, 9, 5}}},
{ {}, {{4, 41, 6, 9}}},
{
// dynamic
{-1, -1, -1, -1},
// target
{
{2, 16, 10, 6},
{4, 16, 2, 2},
{2, 16, 10, 6},
{4, 16, 2, 2}
}
},
{
// dynamic
{{1, 5}, {1, 20}, {1, 10}, {1, 7}},
// target
{
{1, 1, 1, 1},
{2, 16, 10, 6},
{4, 16, 2, 2},
{2, 16, 10, 6}
}
}
};
// 5D inputs.
const std::vector<InputShape> inputShapes_5D = {
{ {}, {{1, 32, 8, 1, 6}}},
{ {}, {{1, 9, 1, 15, 9}}},
{ {}, {{6, 64, 6, 1, 18}}},
{ {}, {{2, 31, 2, 9, 1}}},
{ {}, {{10, 16, 5, 10, 6}}},
{
// dynamic
{-1, -1, -1, -1, -1},
// target
{
{2, 16, 5, 10, 6},
{4, 16, 7, 2, 2},
{2, 16, 5, 10, 6},
{4, 16, 7, 2, 2}
}
},
{
// dynamic
{{1, 5}, {1, 20}, {1, 7}, {1, 10}, {1, 7}},
// target
{
{1, 1, 1, 1, 1},
{2, 16, 5, 10, 6},
{4, 16, 7, 2, 2},
{2, 16, 5, 10, 6}
}
}
};
// Exercise both across-channels modes.
const std::vector<bool> acrossChannels = {
true,
false
};
// Exercise MVN with and without variance normalization.
const std::vector<bool> normalizeVariance = {
true,
false
};
// Epsilon for numeric stability of the variance term.
const std::vector<double> epsilon = {
0.000000001
};
// Empty axis set: the MVN node is built from the across-channels flag
// rather than explicit reduction axes.
const std::vector<ngraph::AxisSet> emptyReductionAxes = {{}};
// Network input/output precisions to cover.
std::vector<ElementType> inpPrc = {ElementType::i8, ElementType::bf16, ElementType::f32};
std::vector<ElementType> outPrc = {ElementType::bf16, ElementType::f32};
// CPU layout configurations for 4D cases: channels-last, blocked, planar.
std::vector<CPUSpecificParams> cpuParams_4D = {
CPUSpecificParams({nhwc}, {nhwc}, {}, {}),
CPUSpecificParams({nChw16c}, {nChw16c}, {}, {}),
CPUSpecificParams({nchw}, {nchw}, {}, {})
};
// CPU layout configurations for 5D cases.
std::vector<CPUSpecificParams> cpuParams_5D = {
CPUSpecificParams({ndhwc}, {ndhwc}, {}, {}),
CPUSpecificParams({nCdhw16c}, {nCdhw16c}, {}, {}),
CPUSpecificParams({ncdhw}, {ncdhw}, {}, {})
};
// Fusing patterns applied in the dynamic-shape suites.
std::vector<fusingSpecificParams> fusingParamsSet {
emptyFusingSpec,
/* activations */
fusingRelu,
fusingElu,
fusingTanh,
fusingSwish,
/* FQ */
fusingFakeQuantizePerTensorRelu,
/* another patterns */
fusingAddPerTensor
};
// Channel-dependent fusing patterns, run only in the static-shape suites.
std::vector<fusingSpecificParams> fusingParamsSetStaticShape {
/* FQ */
fusingFakeQuantizePerChannel,
fusingFakeQuantizePerChannelRelu,
/* another patterns */
fusingScaleShift,
};
// 3D cases over the generic (emptyCPUSpec) implementation.
const auto Mvn3D = ::testing::Combine(
::testing::Combine(
::testing::ValuesIn(inputShapes_3D),
::testing::Values(ElementType::f32),
::testing::ValuesIn(emptyReductionAxes),
::testing::ValuesIn(acrossChannels),
::testing::ValuesIn(normalizeVariance),
::testing::ValuesIn(epsilon)),
::testing::Values(emptyCPUSpec),
::testing::ValuesIn(fusingParamsSet),
::testing::ValuesIn(inpPrc),
::testing::ValuesIn(outPrc));
INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Mvn3D, MvnLayerCPUTest, Mvn3D, MvnLayerCPUTest::getTestCaseName);
// 4D cases across the layout-specific CPU configurations.
const auto Mvn4D = ::testing::Combine(
::testing::Combine(
::testing::ValuesIn(inputShapes_4D),
::testing::Values(ElementType::f32),
::testing::ValuesIn(emptyReductionAxes),
::testing::ValuesIn(acrossChannels),
::testing::ValuesIn(normalizeVariance),
::testing::ValuesIn(epsilon)),
::testing::ValuesIn(filterCPUSpecificParams(cpuParams_4D)),
::testing::ValuesIn(fusingParamsSet),
::testing::ValuesIn(inpPrc),
::testing::ValuesIn(outPrc));
INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Mvn4D, MvnLayerCPUTest, Mvn4D, MvnLayerCPUTest::getTestCaseName);
// 5D cases across the layout-specific CPU configurations.
const auto Mvn5D = ::testing::Combine(
::testing::Combine(
::testing::ValuesIn(inputShapes_5D),
::testing::Values(ElementType::f32),
::testing::ValuesIn(emptyReductionAxes),
::testing::ValuesIn(acrossChannels),
::testing::ValuesIn(normalizeVariance),
::testing::ValuesIn(epsilon)),
::testing::ValuesIn(filterCPUSpecificParams(cpuParams_5D)),
::testing::ValuesIn(fusingParamsSet),
::testing::ValuesIn(inpPrc),
::testing::ValuesIn(outPrc));
INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Mvn5D, MvnLayerCPUTest, Mvn5D, MvnLayerCPUTest::getTestCaseName);
// 1D 2D case
// Only unary eltwise post-ops are exercised for 1D and across-channels 2D
// inputs (matches the restriction in MKLDNNMVNNode::canFuse).
std::vector<fusingSpecificParams> fusingUnaryEltwiseParamsSet {
/* activations */
fusingRelu,
fusingElu,
fusingTanh,
fusingSwish,
};
const auto Mvn1D = ::testing::Combine(
::testing::Combine(
::testing::ValuesIn(inputShapes_1D),
::testing::Values(ElementType::f32),
::testing::ValuesIn(emptyReductionAxes),
::testing::ValuesIn(acrossChannels),
::testing::ValuesIn(normalizeVariance),
::testing::ValuesIn(epsilon)),
::testing::Values(emptyCPUSpec),
::testing::ValuesIn(fusingUnaryEltwiseParamsSet),
::testing::ValuesIn(inpPrc),
::testing::ValuesIn(outPrc));
INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Mvn1D, MvnLayerCPUTest, Mvn1D, MvnLayerCPUTest::getTestCaseName);
// 2D no transformed
// acrossChannels fixed to false, so the full fusing set is allowed.
const auto Mvn2D = ::testing::Combine(
::testing::Combine(
::testing::ValuesIn(inputShapes_2D),
::testing::Values(ElementType::f32),
::testing::ValuesIn(emptyReductionAxes),
::testing::Values(false),
::testing::ValuesIn(normalizeVariance),
::testing::ValuesIn(epsilon)),
::testing::Values(emptyCPUSpec),
::testing::ValuesIn(fusingParamsSet),
::testing::ValuesIn(inpPrc),
::testing::ValuesIn(outPrc));
INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Mvn2D, MvnLayerCPUTest, Mvn2D, MvnLayerCPUTest::getTestCaseName);
// 2d transformed
// acrossChannels fixed to true; restricted to unary eltwise post-ops.
const auto Mvn2DTrans = ::testing::Combine(
::testing::Combine(
::testing::ValuesIn(inputShapes_2D),
::testing::Values(ElementType::f32),
::testing::ValuesIn(emptyReductionAxes),
::testing::Values(true),
::testing::ValuesIn(normalizeVariance),
::testing::ValuesIn(epsilon)),
::testing::Values(emptyCPUSpec),
::testing::ValuesIn(fusingUnaryEltwiseParamsSet),
::testing::ValuesIn(inpPrc),
::testing::ValuesIn(outPrc));
INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Mvn2DTrans, MvnLayerCPUTest, Mvn2DTrans, MvnLayerCPUTest::getTestCaseName);
// Static shape test for some specific fusing parameters in fusingParamsSetStaticShape
// NOTE(review): the entries below are rank-1 shapes although the vector is
// named inputShapesStatic_2D — confirm whether 2D shapes were intended.
const std::vector<ov::Shape> inputShapesStatic_2D = {
{1},
{16},
{4}
};
const std::vector<ov::Shape> inputShapesStatic_3D = {
{2, 16, 6},
{4, 16, 2},
{1, 16, 4}
};
const std::vector<ov::Shape> inputShapesStatic_4D = {
{1, 7, 3, 5},
{1, 15, 9, 5},
{4, 41, 6, 9}
};
const std::vector<ov::Shape> inputShapesStatic_5D = {
{1, 32, 8, 1, 6},
{1, 9, 1, 15, 9},
{6, 64, 6, 1, 18}
};
const auto Mvn2DStatic = ::testing::Combine(
::testing::Combine(
::testing::ValuesIn(inputShapesStatic_2D),
::testing::Values(ElementType::f32),
::testing::ValuesIn(emptyReductionAxes),
::testing::Values(false),
::testing::ValuesIn(normalizeVariance),
::testing::ValuesIn(epsilon)),
::testing::Values(emptyCPUSpec),
::testing::ValuesIn(fusingParamsSetStaticShape),
::testing::ValuesIn(inpPrc),
::testing::ValuesIn(outPrc));
// Static-shape 3D cases with the channel-dependent fusings.
const auto Mvn3DStatic = ::testing::Combine(
::testing::Combine(
::testing::ValuesIn(static_shapes_to_test_representation(inputShapesStatic_3D)),
::testing::Values(ElementType::f32),
::testing::ValuesIn(emptyReductionAxes),
::testing::ValuesIn(acrossChannels),
::testing::ValuesIn(normalizeVariance),
::testing::ValuesIn(epsilon)),
::testing::Values(emptyCPUSpec),
::testing::ValuesIn(fusingParamsSetStaticShape),
::testing::ValuesIn(inpPrc),
::testing::ValuesIn(outPrc));
INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Mvn3D_Static, MvnLayerCPUTest, Mvn3DStatic, MvnLayerCPUTest::getTestCaseName);
// Static-shape 4D cases across the layout-specific CPU configurations.
const auto Mvn4DStatic = ::testing::Combine(
::testing::Combine(
::testing::ValuesIn(static_shapes_to_test_representation(inputShapesStatic_4D)),
::testing::Values(ElementType::f32),
::testing::ValuesIn(emptyReductionAxes),
::testing::ValuesIn(acrossChannels),
::testing::ValuesIn(normalizeVariance),
::testing::ValuesIn(epsilon)),
::testing::ValuesIn(filterCPUSpecificParams(cpuParams_4D)),
::testing::ValuesIn(fusingParamsSetStaticShape),
::testing::ValuesIn(inpPrc),
::testing::ValuesIn(outPrc));
INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Mvn4D_Static, MvnLayerCPUTest, Mvn4DStatic, MvnLayerCPUTest::getTestCaseName);
// Static-shape 5D cases across the layout-specific CPU configurations.
const auto Mvn5DStatic = ::testing::Combine(
::testing::Combine(
::testing::ValuesIn(static_shapes_to_test_representation(inputShapesStatic_5D)),
::testing::Values(ElementType::f32),
::testing::ValuesIn(emptyReductionAxes),
::testing::ValuesIn(acrossChannels),
::testing::ValuesIn(normalizeVariance),
::testing::ValuesIn(epsilon)),
::testing::ValuesIn(filterCPUSpecificParams(cpuParams_5D)),
::testing::ValuesIn(fusingParamsSetStaticShape),
::testing::ValuesIn(inpPrc),
::testing::ValuesIn(outPrc));
INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Mvn5D_Static, MvnLayerCPUTest, Mvn5DStatic, MvnLayerCPUTest::getTestCaseName);
} // namespace
} // namespace CPULayerTestsDefinitions

View File

@ -16,7 +16,7 @@ std::shared_ptr<ngraph::Node> makeMVN(const ngraph::Output<Node> &in,
// Ngraph MVN implementation implicitly adds 0th dimension to reduction axes set which is not valid behavior
ngraph::AxisSet axes;
const size_t startAxis = acrossChannels ? 1 : 2;
const size_t numOfDims = in.get_shape().size();
const size_t numOfDims = in.get_partial_shape().size();
for (size_t i = startAxis; i < numOfDims; i++)
axes.insert(i);
mvnNode->set_reduction_axes(axes);