[CPU]disable and cleanup interp and resample that are covered by interpolate (#3164)
* [BF16] Interpolate layer and test were updated for support BF16 Co-authored-by: alexey-varyzgin <alexey.varyzgin@intel.com>
This commit is contained in:
parent
a7ede592c3
commit
d35e3e806b
@ -36,7 +36,6 @@ set(LAYERS
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_tensoriterator_node.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_tile_node.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_mvn_node.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_resample_node.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_normalize_node.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_scatter_update_node.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_interpolate_node.cpp
|
||||
@ -93,7 +92,6 @@ set(LAYERS
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/unsqueeze.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/common/softmax.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/common/emitter.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/interp.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/jit_eltwise_emitters.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/jit_mkldnn_emitters.cpp
|
||||
|
||||
|
@ -14,7 +14,7 @@ namespace MKLDNNPlugin {
|
||||
|
||||
class BF16Transformer {
|
||||
const InferenceEngine::details::caseless_set<std::string> _initbf16 =
|
||||
{ "convolution", "fullyconnected", "innerproduct", "gemm", "RegionYolo" };
|
||||
{ "convolution", "fullyconnected", "innerproduct", "gemm", "RegionYolo", "Interpolate" };
|
||||
const InferenceEngine::details::caseless_set<std::string> _complementbf16 =
|
||||
{ "relu", "tanh", "elu", "square", "abs", "sqrt", "linear", "bounded_relu", "soft_relu", "normalize",
|
||||
"sigmoid", "ReLU6", "not", "activation", "HSwish", "mish", "logistic", "mod", "resample",
|
||||
|
@ -15,7 +15,6 @@
|
||||
#include "nodes/mkldnn_quantize_node.h"
|
||||
#include "nodes/mkldnn_mvn_node.h"
|
||||
#include <nodes/mkldnn_permute_node.h>
|
||||
#include "nodes/mkldnn_resample_node.h"
|
||||
#include "nodes/mkldnn_interpolate_node.h"
|
||||
#include "nodes/mkldnn_input_node.h"
|
||||
|
||||
@ -123,9 +122,6 @@ void MKLDNNGraphOptimizer::ApplyCommonGraphOptimizations(MKLDNNGraph &graph) {
|
||||
FuseMVNAndSimpleOperation(graph);
|
||||
graph.RemoveDroppedNodes();
|
||||
|
||||
FuseResampleAndSimpleOperation(graph);
|
||||
graph.RemoveDroppedNodes();
|
||||
|
||||
FuseInterpolateAndSimpleOperation(graph);
|
||||
graph.RemoveDroppedNodes();
|
||||
|
||||
@ -1491,74 +1487,6 @@ void MKLDNNGraphOptimizer::FuseMVNAndSimpleOperation(MKLDNNGraph &graph) {
|
||||
}
|
||||
}
|
||||
|
||||
void MKLDNNGraphOptimizer::FuseResampleAndSimpleOperation(MKLDNNGraph &graph) {
|
||||
auto& graphNodes = graph.GetNodes();
|
||||
|
||||
auto isSutableParentNode = [](MKLDNNNodePtr node) {
|
||||
bool isSutableResample = (node->getType() == Resample) && (node->inDims[0].ndims() == 4 || node->inDims[0].ndims() == 5);
|
||||
|
||||
if (isSutableResample) {
|
||||
auto *resampleLayer = node->getCnnLayer().get();
|
||||
if (resampleLayer == nullptr)
|
||||
THROW_IE_EXCEPTION << "Cannot get Resample layer " << node->getName();
|
||||
|
||||
return node->getChildEdges().size() == 1 && resampleLayer->GetParamAsString("type") == "caffe.ResampleParameter.NEAREST";
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
auto isSutableChildNode = [](MKLDNNNodePtr node) {
|
||||
if (!node->getCnnLayer())
|
||||
return false;
|
||||
|
||||
if (node->getType() == Quantize) {
|
||||
auto* quantizeNode = dynamic_cast<MKLDNNQuantizeNode*>(node.get());
|
||||
if (quantizeNode == nullptr)
|
||||
THROW_IE_EXCEPTION << "Cannot get quantize layer " << node->getName();
|
||||
return !quantizeNode->isBinarization();
|
||||
} else if (node->getType() == Eltwise) {
|
||||
auto *eltwiseNode = dynamic_cast<MKLDNNEltwiseNode *>(node.get());
|
||||
if (eltwiseNode == nullptr)
|
||||
THROW_IE_EXCEPTION << "Cannot get Eltwise node " << node->getName();
|
||||
return eltwiseNode->getOpType() == Relu ||
|
||||
eltwiseNode->getOpType() == MulAdd;
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
auto parent = graphNodes.begin();
|
||||
while (parent != graphNodes.end()) {
|
||||
auto parentNode = *parent;
|
||||
if (!isSutableParentNode(parentNode)) {
|
||||
parent++;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto childNode = parentNode->getChildEdgeAt(0)->getChild();
|
||||
if (!isSutableChildNode(childNode)) {
|
||||
parent++;
|
||||
continue;
|
||||
}
|
||||
|
||||
parentNode->fuseWith(childNode);
|
||||
|
||||
if (childNode->getType() == Quantize || childNode->getType() == Eltwise) {
|
||||
auto parentEdges = childNode->parentEdges;
|
||||
for (auto &parentEdge : parentEdges) {
|
||||
auto p_edge = parentEdge.lock();
|
||||
if (p_edge->getParent()->getType() == Resample)
|
||||
continue;
|
||||
|
||||
removeEdge(graph, p_edge);
|
||||
}
|
||||
}
|
||||
|
||||
graph.DropNode(childNode);
|
||||
}
|
||||
}
|
||||
|
||||
void MKLDNNGraphOptimizer::FuseInterpolateAndSimpleOperation(MKLDNNGraph &graph) {
|
||||
auto& graphNodes = graph.GetNodes();
|
||||
|
||||
|
@ -37,7 +37,6 @@ private:
|
||||
void FuseConvolutionSumAndConvolutionSumActivation(MKLDNNGraph &graph);
|
||||
#endif
|
||||
void FuseMVNAndSimpleOperation(MKLDNNGraph &graph);
|
||||
void FuseResampleAndSimpleOperation(MKLDNNGraph &graph);
|
||||
void FuseInterpolateAndSimpleOperation(MKLDNNGraph &graph);
|
||||
void FuseNormalizeAndSimpleOperation(MKLDNNGraph &graph);
|
||||
void RemoveIdentityOperator(MKLDNNGraph& graph);
|
||||
|
@ -39,7 +39,6 @@
|
||||
#include <nodes/mkldnn_bin_conv_node.h>
|
||||
#include <nodes/mkldnn_def_conv_node.h>
|
||||
#include <nodes/mkldnn_mvn_node.h>
|
||||
#include <nodes/mkldnn_resample_node.h>
|
||||
#include <nodes/mkldnn_normalize_node.h>
|
||||
#include <nodes/mkldnn_reduce_node.h>
|
||||
#include <nodes/mkldnn_tensoriterator_node.h>
|
||||
@ -123,7 +122,6 @@ static const InferenceEngine::details::caseless_unordered_map<std::string, Type>
|
||||
{ "Memory", MemoryOutput }, // for construction from layer ctor
|
||||
{ "Convert", Convert },
|
||||
{ "MVN", MVN},
|
||||
{ "Resample", Resample},
|
||||
{ "Normalize", Normalize},
|
||||
{ "ScatterUpdate", ScatterUpdate},
|
||||
{ "ScatterElementsUpdate", ScatterElementsUpdate},
|
||||
|
@ -66,7 +66,6 @@ enum Type {
|
||||
TensorIterator,
|
||||
Convert,
|
||||
MVN,
|
||||
Resample,
|
||||
Normalize,
|
||||
ScatterUpdate,
|
||||
ScatterElementsUpdate,
|
||||
@ -162,8 +161,6 @@ static std::string NameFromType(Type type) {
|
||||
return "TensorIterator";
|
||||
case Convert:
|
||||
return "Convert";
|
||||
case Resample:
|
||||
return "Resample";
|
||||
case Normalize:
|
||||
return "Normalize";
|
||||
case ScatterUpdate:
|
||||
|
@ -26,6 +26,7 @@
|
||||
#include <legacy/transformations/convert_opset1_to_legacy/convert_prior_to_ie_prior.hpp>
|
||||
#include <legacy/transformations/convert_opset1_to_legacy/reshape_fully_connected.hpp>
|
||||
#include <legacy/transformations/convert_opset1_to_legacy/convert_nms_5_to_legacy.hpp>
|
||||
#include <legacy/transformations/convert_opset1_to_legacy/convert_interpolate_to_interp_or_resample.hpp>
|
||||
#include <legacy/ngraph_ops/fully_connected.hpp>
|
||||
|
||||
#include <transformations/opset_conversions/convert_opset3_to_opset2.hpp>
|
||||
@ -52,6 +53,7 @@
|
||||
#include <transformations/op_conversions/rnn_cell_decomposition.hpp>
|
||||
#include <transformations/op_conversions/gru_cell_decomposition.hpp>
|
||||
#include <transformations/op_conversions/log_softmax_decomposition.hpp>
|
||||
#include <transformations/op_conversions/convert_interpolate1_to_interpolate4.hpp>
|
||||
#include <transformations/convert_precision.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <transformations/rt_info/fused_names_attribute.hpp>
|
||||
@ -200,8 +202,10 @@ static void Transformation(ICNNNetwork::Ptr& clonedNetwork, const Config& conf)
|
||||
pass_config->disable<ngraph::pass::HSigmoidDecomposition>();
|
||||
pass_config->disable<ngraph::pass::ConvertMod>();
|
||||
pass_config->disable<ngraph::pass::LogSoftmaxDecomposition>();
|
||||
pass_config->disable<ngraph::pass::ConvertInterpolateToInterpOrResampleMatcher>();
|
||||
|
||||
pass_config->enable<ngraph::pass::ConvertPadToGroupConvolution>();
|
||||
pass_config->enable<ngraph::pass::ConvertInterpolate1ToInterpolate4>();
|
||||
|
||||
manager.run_passes(nGraphFunc);
|
||||
|
||||
|
@ -1,432 +0,0 @@
|
||||
// Copyright (C) 2018-2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "base.hpp"
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include "ie_parallel.hpp"
|
||||
#include "jit_generator.hpp"
|
||||
|
||||
using namespace mkldnn::impl::cpu;
|
||||
using namespace mkldnn::impl::utils;
|
||||
|
||||
namespace InferenceEngine {
|
||||
namespace Extensions {
|
||||
namespace Cpu {
|
||||
|
||||
#define GET_OFF(field) offsetof(jit_args_interp, field)
|
||||
|
||||
struct jit_args_interp {
|
||||
const float *src00;
|
||||
const float *src01;
|
||||
const float *src10;
|
||||
const float *src11;
|
||||
float *dst;
|
||||
float *h_lambda0;
|
||||
float *h_lambda1;
|
||||
float *w_lambda0;
|
||||
float *w_lambda1;
|
||||
};
|
||||
|
||||
struct jit_uni_interp_kernel {
|
||||
void (*ker_)(const jit_args_interp *);
|
||||
|
||||
void operator()(const jit_args_interp *args) { assert(ker_); ker_(args); }
|
||||
|
||||
jit_uni_interp_kernel() : ker_(nullptr) {}
|
||||
virtual ~jit_uni_interp_kernel() {}
|
||||
};
|
||||
|
||||
template <cpu_isa_t isa>
|
||||
struct jit_uni_interp_kernel_f32 : public jit_uni_interp_kernel, public jit_generator {
|
||||
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_interp_kernel_f32)
|
||||
|
||||
jit_uni_interp_kernel_f32() : jit_uni_interp_kernel(), jit_generator() {
|
||||
this->preamble();
|
||||
|
||||
mov(reg_src00, ptr[reg_params + GET_OFF(src00)]);
|
||||
mov(reg_src01, ptr[reg_params + GET_OFF(src01)]);
|
||||
mov(reg_src10, ptr[reg_params + GET_OFF(src10)]);
|
||||
mov(reg_src11, ptr[reg_params + GET_OFF(src11)]);
|
||||
mov(reg_dst, ptr[reg_params + GET_OFF(dst)]);
|
||||
mov(reg_h_lambda0, ptr[reg_params + GET_OFF(h_lambda0)]);
|
||||
mov(reg_h_lambda1, ptr[reg_params + GET_OFF(h_lambda1)]);
|
||||
mov(reg_w_lambda0, ptr[reg_params + GET_OFF(w_lambda0)]);
|
||||
mov(reg_w_lambda1, ptr[reg_params + GET_OFF(w_lambda1)]);
|
||||
|
||||
uni_vmovups(vmm_src00, ptr[reg_src00]);
|
||||
uni_vmovups(vmm_src01, ptr[reg_src01]);
|
||||
uni_vmovups(vmm_src10, ptr[reg_src10]);
|
||||
uni_vmovups(vmm_src11, ptr[reg_src11]);
|
||||
|
||||
uni_vbroadcastss(vmm_h_lambda0, ptr[reg_h_lambda0]);
|
||||
uni_vbroadcastss(vmm_h_lambda1, ptr[reg_h_lambda1]);
|
||||
uni_vbroadcastss(vmm_w_lambda0, ptr[reg_w_lambda0]);
|
||||
uni_vbroadcastss(vmm_w_lambda1, ptr[reg_w_lambda1]);
|
||||
|
||||
if (isa != sse42) {
|
||||
uni_vmulps(vmm_src01, vmm_src01, vmm_w_lambda0);
|
||||
uni_vmulps(vmm_src11, vmm_src11, vmm_w_lambda0);
|
||||
uni_vfmadd231ps(vmm_src01, vmm_w_lambda1, vmm_src00);
|
||||
uni_vfmadd231ps(vmm_src11, vmm_w_lambda1, vmm_src10);
|
||||
uni_vmulps(vmm_src01, vmm_src01, vmm_h_lambda1);
|
||||
uni_vfmadd231ps(vmm_src01, vmm_h_lambda0, vmm_src11);
|
||||
uni_vmovups(ptr[reg_dst], vmm_src01);
|
||||
} else {
|
||||
uni_vmulps(vmm_src01, vmm_src01, vmm_w_lambda0);
|
||||
uni_vmulps(vmm_src11, vmm_src11, vmm_w_lambda0);
|
||||
uni_vfmadd231ps(vmm_src01, vmm_w_lambda1, vmm_src00);
|
||||
// uni_vfmadd231ps affects XMM (vmm_w_lambda1) register. Need to initialize again.
|
||||
uni_vbroadcastss(vmm_w_lambda1, ptr[reg_w_lambda1]);
|
||||
uni_vfmadd231ps(vmm_src11, vmm_w_lambda1, vmm_src10);
|
||||
uni_vmulps(vmm_src01, vmm_src01, vmm_h_lambda1);
|
||||
uni_vfmadd231ps(vmm_src01, vmm_h_lambda0, vmm_src11);
|
||||
uni_vmovups(ptr[reg_dst], vmm_src01);
|
||||
|
||||
// Next 4 elements
|
||||
size_t stride = 4 * sizeof(float);
|
||||
|
||||
add(reg_src00, stride);
|
||||
add(reg_src01, stride);
|
||||
add(reg_src10, stride);
|
||||
add(reg_src11, stride);
|
||||
add(reg_dst, stride);
|
||||
|
||||
uni_vmovups(vmm_src00, ptr[reg_src00]);
|
||||
uni_vmovups(vmm_src01, ptr[reg_src01]);
|
||||
uni_vmovups(vmm_src10, ptr[reg_src10]);
|
||||
uni_vmovups(vmm_src11, ptr[reg_src11]);
|
||||
|
||||
uni_vbroadcastss(vmm_h_lambda0, ptr[reg_h_lambda0]);
|
||||
uni_vbroadcastss(vmm_w_lambda1, ptr[reg_w_lambda1]);
|
||||
|
||||
uni_vmulps(vmm_src01, vmm_src01, vmm_w_lambda0);
|
||||
uni_vmulps(vmm_src11, vmm_src11, vmm_w_lambda0);
|
||||
uni_vfmadd231ps(vmm_src01, vmm_w_lambda1, vmm_src00);
|
||||
uni_vbroadcastss(vmm_w_lambda1, ptr[reg_w_lambda1]);
|
||||
uni_vfmadd231ps(vmm_src11, vmm_w_lambda1, vmm_src10);
|
||||
uni_vmulps(vmm_src01, vmm_src01, vmm_h_lambda1);
|
||||
uni_vfmadd231ps(vmm_src01, vmm_h_lambda0, vmm_src11);
|
||||
uni_vmovups(ptr[reg_dst], vmm_src01);
|
||||
}
|
||||
|
||||
this->postamble();
|
||||
ker_ = (decltype(ker_))this->getCode();
|
||||
}
|
||||
|
||||
private:
|
||||
using Vmm = typename conditional3<isa == sse42, Xbyak::Xmm, isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
|
||||
size_t vlen = cpu_isa_traits<isa>::vlen;
|
||||
|
||||
Xbyak::Reg64 reg_src00 = r8;
|
||||
Xbyak::Reg64 reg_src01 = r9;
|
||||
Xbyak::Reg64 reg_src10 = r10;
|
||||
Xbyak::Reg64 reg_src11 = r11;
|
||||
Xbyak::Reg64 reg_dst = rbp;
|
||||
Xbyak::Reg64 reg_h_lambda0 = r12;
|
||||
Xbyak::Reg64 reg_h_lambda1 = r13;
|
||||
Xbyak::Reg64 reg_w_lambda0 = r14;
|
||||
Xbyak::Reg64 reg_w_lambda1 = r15;
|
||||
Xbyak::Reg64 reg_params = abi_param1;
|
||||
|
||||
Vmm vmm_src00 = Vmm(0);
|
||||
Vmm vmm_src01 = Vmm(1);
|
||||
Vmm vmm_src10 = Vmm(2);
|
||||
Vmm vmm_src11 = Vmm(3);
|
||||
Vmm vmm_h_lambda0 = Vmm(4);
|
||||
Vmm vmm_h_lambda1 = Vmm(5);
|
||||
Vmm vmm_w_lambda0 = Vmm(6);
|
||||
Vmm vmm_w_lambda1 = Vmm(7);
|
||||
Vmm vmm_dst = Vmm(8);
|
||||
};
|
||||
|
||||
class InterpImpl: public ExtLayerBase {
|
||||
public:
|
||||
explicit InterpImpl(const CNNLayer* layer) {
|
||||
try {
|
||||
if (layer->insData.size() != 1 || layer->outData.empty())
|
||||
THROW_IE_EXCEPTION << "Incorrect number of input/output edges!";
|
||||
|
||||
auto inData = layer->insData[0].lock();
|
||||
if (inData == nullptr) {
|
||||
THROW_IE_EXCEPTION << "Layer '" << layer->name << "' has nullable input data.";
|
||||
}
|
||||
if (inData->getTensorDesc().getDims().size() != 4)
|
||||
THROW_IE_EXCEPTION << "Interp supports only 4d blobs!";
|
||||
|
||||
// We don't read other parameters since they are needed only for dst reshape in caffe
|
||||
pad_beg = layer->GetParamAsInt("pad_beg");
|
||||
pad_end = layer->GetParamAsInt("pad_end");
|
||||
align_corners = layer->GetParamAsBool("align_corners", true);
|
||||
|
||||
ConfLayout blk_layout;
|
||||
if (inData->getTensorDesc().getPrecision() == Precision::U8) {
|
||||
LayerConfig config;
|
||||
DataConfig dataConfigDct;
|
||||
dataConfigDct.desc = TensorDesc(Precision::U8, inData->getTensorDesc().getDims(), Layout::NCHW);
|
||||
config.inConfs.push_back(dataConfigDct);
|
||||
|
||||
DataConfig dataConfigOut;
|
||||
const SizeVector& out_dims = layer->outData[0]->getTensorDesc().getDims();
|
||||
SizeVector blocks = out_dims;
|
||||
SizeVector order(blocks.size());
|
||||
SizeVector dimOffsets(blocks.size());
|
||||
SizeVector strides(blocks.size());
|
||||
size_t offset((std::numeric_limits<size_t>::max)());
|
||||
for (size_t i = 0; i < order.size(); i++) {
|
||||
strides[i] = (std::numeric_limits<size_t>::max)();
|
||||
dimOffsets[i] = 0;
|
||||
order[i] = i;
|
||||
}
|
||||
dataConfigOut.desc = TensorDesc(Precision::FP32, out_dims, { blocks, order, offset, dimOffsets, strides });
|
||||
config.outConfs.push_back(dataConfigOut);
|
||||
config.dynBatchSupport = false;
|
||||
confs.push_back(config);
|
||||
} else {
|
||||
if (mayiuse(avx512_common)) {
|
||||
blk_layout = ConfLayout::BLK16;
|
||||
interp_kernel.reset(new jit_uni_interp_kernel_f32<avx512_common>());
|
||||
addConfig(layer, { DataConfigurator(blk_layout, Precision::FP32) }, { DataConfigurator(blk_layout, Precision::FP32) });
|
||||
} else if (mayiuse(avx2)) {
|
||||
blk_layout = ConfLayout::BLK8;
|
||||
interp_kernel.reset(new jit_uni_interp_kernel_f32<avx2>());
|
||||
addConfig(layer, { DataConfigurator(blk_layout, Precision::FP32) }, { DataConfigurator(blk_layout, Precision::FP32) });
|
||||
} else {
|
||||
blk_layout = ConfLayout::BLK8;
|
||||
interp_kernel.reset(new jit_uni_interp_kernel_f32<sse42>());
|
||||
addConfig(layer, { DataConfigurator(blk_layout, Precision::FP32) }, { DataConfigurator(blk_layout, Precision::FP32) });
|
||||
}
|
||||
}
|
||||
} catch (InferenceEngine::details::InferenceEngineException &ex) {
|
||||
errorMsg = ex.what();
|
||||
}
|
||||
}
|
||||
|
||||
StatusCode init(LayerConfig& config, ResponseDesc *resp) noexcept override {
|
||||
if (config.inConfs.size() != 1 || config.outConfs.size() != 1) {
|
||||
strncpy(resp->msg, "Interp layer has invalid configs", sizeof(resp->msg));
|
||||
return GENERAL_ERROR;
|
||||
}
|
||||
|
||||
if (config.inConfs[0].desc.getDims().size() != 4) {
|
||||
std::ostringstream result;
|
||||
result << "Interp layer has invalid layout: " << config.inConfs[0].desc.getLayout();
|
||||
strncpy(resp->msg, result.str().c_str(), sizeof(resp->msg) - 1);
|
||||
return GENERAL_ERROR;
|
||||
}
|
||||
|
||||
auto inPrecision = config.inConfs[0].desc.getPrecision();
|
||||
if (inPrecision != Precision::U8 && inPrecision != Precision::FP32) {
|
||||
strncpy(resp->msg, "Interp layer has unsupported input precision", sizeof(resp->msg));
|
||||
return GENERAL_ERROR;
|
||||
}
|
||||
|
||||
if (config.outConfs[0].desc.getPrecision() != Precision::FP32) {
|
||||
strncpy(resp->msg, "Interp layer has unsupported output precision", sizeof(resp->msg));
|
||||
return GENERAL_ERROR;
|
||||
}
|
||||
|
||||
return OK;
|
||||
}
|
||||
|
||||
StatusCode execute(std::vector<Blob::Ptr>& inputs, std::vector<Blob::Ptr>& outputs,
|
||||
ResponseDesc *resp) noexcept override {
|
||||
#ifdef WIN32
|
||||
#undef IN
|
||||
#endif
|
||||
size_t IN = inputs[0]->getTensorDesc().getDims()[0];
|
||||
size_t IH = inputs[0]->getTensorDesc().getDims()[2];
|
||||
size_t IW = inputs[0]->getTensorDesc().getDims()[3];
|
||||
size_t OH = outputs[0]->getTensorDesc().getDims()[2];
|
||||
size_t OW = outputs[0]->getTensorDesc().getDims()[3];
|
||||
|
||||
size_t IH_pad = IH + pad_beg + pad_end;
|
||||
size_t IW_pad = IW + pad_beg + pad_end;
|
||||
|
||||
auto *dst_data = outputs[0]->buffer().as<float *>() + outputs[0]->getTensorDesc().getBlockingDesc().getOffsetPadding();
|
||||
|
||||
switch (inputs[0]->getTensorDesc().getPrecision()) {
|
||||
case Precision::FP32:
|
||||
{
|
||||
const float* src_data = inputs[0]->cbuffer().as<const float *>() + inputs[0]->getTensorDesc().getBlockingDesc().getOffsetPadding();
|
||||
size_t IC = (inputs[0]->getTensorDesc().getLayout() == Layout::BLOCKED)
|
||||
? inputs[0]->getTensorDesc().getBlockingDesc().getBlockDims()[1] *
|
||||
inputs[0]->getTensorDesc().getBlockingDesc().getBlockDims()[4]
|
||||
: IC = inputs[0]->getTensorDesc().getDims()[1];
|
||||
interpolate(IN, IC, src_data,
|
||||
-pad_beg, -pad_beg, IH_pad, IW_pad, IH, IW, dst_data, 0, 0, OH, OW, OH, OW);
|
||||
}
|
||||
break;
|
||||
case Precision::U8:
|
||||
{
|
||||
const uint8_t* src_data = inputs[0]->cbuffer().as<const uint8_t *>() + inputs[0]->getTensorDesc().getBlockingDesc().getOffsetPadding();
|
||||
size_t IC = inputs[0]->getTensorDesc().getDims()[1];
|
||||
interpolate_8u(inputs[0]->getTensorDesc().getLayout(), IN, IC, src_data,
|
||||
-pad_beg, -pad_beg, IH_pad, IW_pad, IH, IW, dst_data, 0, 0, OH, OW, OH, OW);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
if (resp) {
|
||||
std::string errorMsg = "Incorrect input precision. Only U8 or FP32 are supported!";
|
||||
errorMsg.copy(resp->msg, sizeof(resp->msg) - 1);
|
||||
}
|
||||
return GENERAL_ERROR;
|
||||
}
|
||||
|
||||
return OK;
|
||||
}
|
||||
|
||||
private:
|
||||
int pad_beg;
|
||||
int pad_end;
|
||||
bool align_corners;
|
||||
std::shared_ptr<jit_uni_interp_kernel> interp_kernel;
|
||||
|
||||
void interpolate(const size_t N, const size_t C,
|
||||
const float *src, const int x1, const int y1,
|
||||
const int IH_pad, const int IW_pad, const size_t IH, const size_t IW,
|
||||
float *dst, const int x2, const int y2,
|
||||
const int OH_pad, const int OW_pad, const size_t OH, const size_t OW) {
|
||||
if (IH_pad == OH_pad && IW_pad == OW_pad) {
|
||||
for (size_t i = 0; i < N * C * OH * OW; i++) {
|
||||
dst[i] = src[i];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
float rh;
|
||||
float rw;
|
||||
if (align_corners) {
|
||||
rh = (OH_pad > 1) ? static_cast<float>(IH_pad - 1) / (OH_pad - 1) : 0.0f;
|
||||
rw = (OW_pad > 1) ? static_cast<float>(IW_pad - 1) / (OW_pad - 1) : 0.0f;
|
||||
} else {
|
||||
rh = static_cast<float>(IH_pad) / (OH_pad);
|
||||
rw = static_cast<float>(IW_pad) / (OW_pad);
|
||||
}
|
||||
|
||||
int block_size = 1;
|
||||
if (interp_kernel) {
|
||||
if (mayiuse(avx512_common)) {
|
||||
block_size = 16;
|
||||
} else {
|
||||
block_size = 8;
|
||||
}
|
||||
}
|
||||
|
||||
// Align channel number to block size to deal with channels padding in IE with multiple blobs
|
||||
size_t CB = (C + block_size - 1) & (-block_size);
|
||||
|
||||
size_t CH = (C + block_size - 1) / block_size;
|
||||
|
||||
parallel_for3d(N, CH, OH_pad, [&](size_t n, size_t cb, size_t h) {
|
||||
const float *psrc_n_cb = src + n * CB * IH * IW + cb * block_size * IW * IH; // n+cb src address
|
||||
|
||||
// h is output h
|
||||
float fh = rh * h;
|
||||
// ih0 is higher input h position
|
||||
int ih0 = static_cast<int>(fh);
|
||||
// ih1 is lower input h position
|
||||
int ih1 = (ih0 < IH_pad - 1) ? ih0 + 1 : ih0;
|
||||
|
||||
float h_lambda0 = fh - ih0; // for lower input h weight
|
||||
float h_lambda1 = 1.0f - h_lambda0; // for higher input h weight
|
||||
|
||||
const float *psrc_h0 = psrc_n_cb + (y1 + ih0) * IW * block_size + x1 * block_size;
|
||||
const float *psrc_h1 = psrc_n_cb + (y1 + ih1) * IW * block_size + x1 * block_size;
|
||||
float *pdst_h = dst + n * CB * OH * OW + cb * block_size * OW * OH + (y2 + h) * OW * block_size + x2 * block_size;
|
||||
|
||||
auto arg = jit_args_interp();
|
||||
arg.h_lambda0 = static_cast<float*>(&h_lambda0);
|
||||
arg.h_lambda1 = static_cast<float*>(&h_lambda1);
|
||||
for (int w = 0; w < OW_pad; ++w) {
|
||||
float fw = rw * w;
|
||||
int iw0 = static_cast<int>(fw);
|
||||
int iw1 = (iw0 < IW_pad - 1) ? iw0 + 1 : iw0;
|
||||
|
||||
float w_lambda0 = fw - iw0; // for right input w weight
|
||||
float w_lambda1 = 1.0f - w_lambda0; // for left input w weight
|
||||
|
||||
const float *psrc00 = psrc_h0 + iw0 * block_size;
|
||||
const float *psrc01 = psrc_h0 + iw1 * block_size;
|
||||
const float *psrc10 = psrc_h1 + iw0 * block_size;
|
||||
const float *psrc11 = psrc_h1 + iw1 * block_size;
|
||||
|
||||
float *pdst = pdst_h + w * block_size;
|
||||
|
||||
if (interp_kernel) {
|
||||
arg.src00 = psrc00;
|
||||
arg.src01 = psrc01;
|
||||
arg.src10 = psrc10;
|
||||
arg.src11 = psrc11;
|
||||
arg.dst = pdst;
|
||||
arg.w_lambda0 = static_cast<float*>(&w_lambda0);
|
||||
arg.w_lambda1 = static_cast<float*>(&w_lambda1);
|
||||
(*interp_kernel)(&arg);
|
||||
} else {
|
||||
for (int c = 0; c < block_size; ++c) {
|
||||
pdst[c] = h_lambda1 * (w_lambda1 * psrc00[c] + w_lambda0 * psrc01[c]) +
|
||||
h_lambda0 * (w_lambda1 * psrc10[c] + w_lambda0 * psrc11[c]);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void interpolate_8u(Layout layout, const size_t N, const size_t C,
|
||||
const uint8_t *src, const int x1, const int y1,
|
||||
const int IH_pad, const int IW_pad, const size_t IH, const size_t IW,
|
||||
float *dst, const int x2, const int y2,
|
||||
const int OH_pad, const int OW_pad, const size_t OH, const size_t OW) {
|
||||
if (IH_pad == OH_pad && IW_pad == OW_pad) {
|
||||
for (size_t i = 0; i < N * C * OH * OW; i++) {
|
||||
dst[i] = static_cast<float>(src[i]);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
float rh;
|
||||
float rw;
|
||||
if (align_corners) {
|
||||
rh = (OH_pad > 1) ? static_cast<float>(IH_pad - 1) / (OH_pad - 1) : 0.0f;
|
||||
rw = (OW_pad > 1) ? static_cast<float>(IW_pad - 1) / (OW_pad - 1) : 0.0f;
|
||||
} else {
|
||||
rh = static_cast<float>(IH_pad) / (OH_pad);
|
||||
rw = static_cast<float>(IW_pad) / (OW_pad);
|
||||
}
|
||||
|
||||
parallel_for3d(N, C, OH_pad, [&](size_t n, size_t cb, size_t h) {
|
||||
const uint8_t *psrc = src + n * C * IH * IW;
|
||||
|
||||
float fh = rh * h;
|
||||
int ih0 = static_cast<int>(fh);
|
||||
int ih1 = (ih0 < IH_pad - 1) ? ih0 + 1 : ih0;
|
||||
|
||||
float h_lambda0 = fh - ih0;
|
||||
float h_lambda1 = 1.0f - h_lambda0;
|
||||
|
||||
for (int w = 0; w < OW_pad; ++w) {
|
||||
float fw = rw * w;
|
||||
int iw0 = static_cast<int>(fw);
|
||||
int iw1 = (iw0 < IW_pad - 1) ? iw0 + 1 : iw0;
|
||||
|
||||
float w_lambda0 = fw - iw0;
|
||||
float w_lambda1 = 1.0f - w_lambda0;
|
||||
|
||||
dst[n * C * OH * OW + cb * OW * OH + (y2 + h) * OW + (x2 + w)] =
|
||||
h_lambda1 * (w_lambda1 * static_cast<float>(psrc[cb * IW * IH + (y1 + ih0) * IW + (x1 + iw0)]) +
|
||||
w_lambda0 * static_cast<float>(psrc[cb * IW * IH + (y1 + ih0) * IW + (x1 + iw1)])) +
|
||||
h_lambda0 * (w_lambda1 * static_cast<float>(psrc[cb * IW * IH + (y1 + ih1) * IW + (x1 + iw0)]) +
|
||||
w_lambda0 * static_cast<float>(psrc[cb * IW * IH + (y1 + ih1) * IW + (x1 + iw1)]));
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
REG_FACTORY_FOR(InterpImpl, Interp);
|
||||
|
||||
} // namespace Cpu
|
||||
} // namespace Extensions
|
||||
} // namespace InferenceEngine
|
@ -64,7 +64,6 @@ MKLDNN_EXTENSION_NODE(TopKImpl, TopK);
|
||||
MKLDNN_EXTENSION_NODE(ShuffleChannelsImpl, ShuffleChannels);
|
||||
MKLDNN_EXTENSION_NODE(SpaceToDepthImpl, SpaceToDepth);
|
||||
MKLDNN_EXTENSION_NODE(PowerFileImpl, PowerFile);
|
||||
MKLDNN_EXTENSION_NODE(InterpImpl, Interp);
|
||||
MKLDNN_EXTENSION_NODE(BatchToSpaceImpl, BatchToSpace);
|
||||
MKLDNN_EXTENSION_NODE(ExperimentalDetectronPriorGridGeneratorImpl, ExperimentalDetectronPriorGridGenerator);
|
||||
MKLDNN_EXTENSION_NODE(SimplerNMSImpl, SimplerNMS);
|
||||
|
@ -1872,7 +1872,6 @@ void MKLDNNInterpolateNode::buildTblLinearOnnx(SizeVector& srcDimPad5d, SizeVect
|
||||
size_t scratchLen = rnd_up(OW + OW + OH + OH, 16);
|
||||
int idxType = 2;
|
||||
indexTable.resize(idxType * scratchLen);
|
||||
std::vector<int> index(scratchLen, 0);
|
||||
int *indexLeft = static_cast<int*>(&indexTable[0]);
|
||||
int *indexRight = static_cast<int*>(&indexTable[OW]);
|
||||
int *indexTop = static_cast<int*>(&indexTable[2 * OW]);
|
||||
@ -2320,7 +2319,7 @@ void MKLDNNInterpolateNode::NNCGathered(const uint8_t *in_ptr_, uint8_t *out_ptr
|
||||
arg.src_ptr[0] = in_ptr_cbd + blk_size * IW * index_h[h] * srcDataSize;
|
||||
arg.index = static_cast<int*>(&(index_w_kernel[0]));
|
||||
arg.work_amount = static_cast<size_t>(OW);
|
||||
arg.oc_off = cb * blk_size;
|
||||
arg.oc_off = cb * blk_size * sizeof(float);
|
||||
(*interpolateKernel)(&arg);
|
||||
}
|
||||
});
|
||||
@ -2351,7 +2350,7 @@ void MKLDNNInterpolateNode::NNPlanar(const uint8_t *in_ptr_, uint8_t *out_ptr_,
|
||||
arg.src_ptr[0] = in_ptr;
|
||||
arg.dst = out_ptr;
|
||||
arg.index = static_cast<int*>(&index_kernel[0]); // need index_h and index_w in kernel, it's in continous memory so one param
|
||||
arg.oc_off = static_cast<size_t>(c);
|
||||
arg.oc_off = static_cast<size_t>(c * sizeof(float));
|
||||
// work_amount is OH(out loop) and OW(inner loop), can get in kernel from jcp.
|
||||
(*interpolateKernel)(&arg);
|
||||
});
|
||||
@ -2391,7 +2390,7 @@ void MKLDNNInterpolateNode::linearOnnxPlanar(const uint8_t *in_ptr_, uint8_t *ou
|
||||
arg.weight_ptr[0] = static_cast<float*>(&weight[0]);
|
||||
arg.dst = out_ptr_nc;
|
||||
arg.work_amount = OW * OH;
|
||||
arg.oc_off = c;
|
||||
arg.oc_off = static_cast<size_t>(c * sizeof(float));
|
||||
(*interpolateKernel)(&arg);
|
||||
});
|
||||
}
|
||||
@ -2666,7 +2665,7 @@ void MKLDNNInterpolateNode::cubicPlanar(const uint8_t *in_ptr_, uint8_t *out_ptr
|
||||
arg.weight_ptr[0] = xFactor;
|
||||
arg.weight_ptr[1] = yFactor;
|
||||
arg.work_amount = static_cast<size_t>(OW * OH);
|
||||
arg.oc_off = static_cast<size_t>(C);
|
||||
arg.oc_off = static_cast<size_t>(c * sizeof(float));
|
||||
(*interpolateKernel)(&arg);
|
||||
});
|
||||
}
|
||||
@ -2788,7 +2787,7 @@ inline float MKLDNNInterpolateNode::coordTransToInput(int outCoord, float scale,
|
||||
}
|
||||
case InterpolateCoordTransMode::align_corners: {
|
||||
if (outShape > 1)
|
||||
return outCoord * static_cast<float>(inShape - 1) / static_cast<float>(outShape - 1);
|
||||
return outCoord * (static_cast<float>(inShape - 1) / static_cast<float>(outShape - 1));
|
||||
else
|
||||
return 0;
|
||||
break;
|
||||
@ -2844,10 +2843,9 @@ bool MKLDNNInterpolateNode::canFuse(const MKLDNNNodePtr& node) const {
|
||||
return false;
|
||||
};
|
||||
|
||||
if (!mayiuse(cpu::sse42))
|
||||
return false;
|
||||
if (mode == InterpolateMode::linear || mode == InterpolateMode::cubic)
|
||||
if (!mayiuse(cpu::sse42) || mode == InterpolateMode::linear) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (node->getType() == Quantize) {
|
||||
auto* quantizeNode = dynamic_cast<MKLDNNQuantizeNode*>(node.get());
|
||||
@ -2858,10 +2856,9 @@ bool MKLDNNInterpolateNode::canFuse(const MKLDNNNodePtr& node) const {
|
||||
auto* eltwiseNode = dynamic_cast<MKLDNNEltwiseNode*>(node.get());
|
||||
if (eltwiseNode == nullptr)
|
||||
THROW_IE_EXCEPTION << "Cannot get eltwise node " << node->getName();
|
||||
return isOneOf(eltwiseNode->getOpType(), {MulAdd, Prelu, Relu, Gelu, Elu, Logistic, BoundedRelu, Clamp,
|
||||
return isOneOf(eltwiseNode->getOpType(), {Prelu, Relu, Gelu, Elu, Logistic, BoundedRelu, Clamp,
|
||||
Tanh, Swish, Hswish, Mish, Hsigmoid, Round, Linear, Abs, Square, Sqrt}) ||
|
||||
((eltwiseNode->getOpType() == MulAdd && eltwiseNode->getCnnLayer()->blobs.size() == 2) ||
|
||||
(eltwiseNode->getOpType() == Prelu));
|
||||
(eltwiseNode->getOpType() == MulAdd && eltwiseNode->getCnnLayer()->blobs.size() == 2);
|
||||
}
|
||||
|
||||
return false;
|
||||
|
@ -1,922 +0,0 @@
|
||||
// Copyright (C) 2018-2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "mkldnn_resample_node.h"
|
||||
#include "desc_iterator.hpp"
|
||||
#include "mkldnn_quantize_node.h"
|
||||
#include <legacy/ie_layers.h>
|
||||
#include "mkldnn_eltwise_node.h"
|
||||
#include <mkldnn.hpp>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <mkldnn_types.h>
|
||||
#include <mkldnn_extension_utils.h>
|
||||
#include "utils/bfloat16.hpp"
|
||||
#include <legacy/ie_layers_internal.hpp>
|
||||
#include "ie_parallel.hpp"
|
||||
#include <algorithm>
|
||||
|
||||
#include "jit_generator.hpp"
|
||||
#include "jit_uni_eltwise.hpp"
|
||||
#include "jit_uni_depthwise.hpp"
|
||||
#include "jit_uni_quantization.hpp"
|
||||
#include "common/cpu_memcpy.h"
|
||||
|
||||
using namespace mkldnn;
|
||||
using namespace MKLDNNPlugin;
|
||||
using namespace InferenceEngine;
|
||||
using namespace mkldnn::impl;
|
||||
using namespace mkldnn::impl::cpu;
|
||||
using namespace mkldnn::impl::utils;
|
||||
using namespace Xbyak;
|
||||
|
||||
|
||||
#define GET_OFF(field) offsetof(jit_resample_call_args, field)
|
||||
|
||||
static inline bool isFloatCompatible(Precision prc) {
|
||||
return Precision::FP32 == prc || Precision::BF16 == prc;
|
||||
}
|
||||
|
||||
static inline bool isFloatCompatible(memory::data_type type) {
|
||||
return memory::f32 == type || memory::bf16 == type;
|
||||
}
|
||||
|
||||
template <cpu_isa_t isa>
|
||||
struct jit_uni_resample_nearest_kernel_f32 : public jit_uni_resample_nearest_kernel, public jit_generator {
|
||||
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_resample_nearest_kernel_f32)
|
||||
|
||||
explicit jit_uni_resample_nearest_kernel_f32(jit_resample_config_params jcp, const mkldnn_primitive_attr &attr)
|
||||
: jit_uni_resample_nearest_kernel(jcp, attr), jit_generator() {
|
||||
const auto &p = attr_.post_ops_;
|
||||
for (int i = 0; i < p.len_; i++) {
|
||||
auto &post_op = p.entry_[i];
|
||||
if (post_op.is_eltwise()) {
|
||||
eltwise_injectors.push_back(std::make_shared<jit_uni_eltwise_injector_f32<isa>>(
|
||||
this,
|
||||
post_op.eltwise.alg,
|
||||
post_op.eltwise.alpha,
|
||||
post_op.eltwise.beta));
|
||||
} else if (post_op.is_depthwise()) {
|
||||
depthwise_injectors.push_back(std::make_shared<jit_uni_depthwise_injector_f32<isa>>(
|
||||
this,
|
||||
post_op.depthwise.alg));
|
||||
} else if (post_op.is_quantization()) {
|
||||
quantization_injectors.push_back(std::make_shared<jit_uni_quantization_injector_f32<isa>>(
|
||||
this, post_op, vmm_d_weights, vmm_d_bias, reg_d_weights, reg_d_bias));
|
||||
}
|
||||
}
|
||||
|
||||
this->preamble();
|
||||
|
||||
mov(reg_src, ptr[reg_params + GET_OFF(src)]);
|
||||
mov(reg_dst, ptr[reg_params + GET_OFF(dst)]);
|
||||
mov(reg_index, ptr[reg_params + GET_OFF(index)]);
|
||||
mov(reg_work_amount, ptr[reg_params + GET_OFF(work_amount)]);
|
||||
mov(reg_src_stride, ptr[reg_params + GET_OFF(src_stride)]);
|
||||
mov(reg_index_stride, ptr[reg_params + GET_OFF(index_stride)]);
|
||||
mov(reg_dst_stride, ptr[reg_params + GET_OFF(dst_stride)]);
|
||||
if (attr_.post_ops_.len_ != 0)
|
||||
mov(reg_oc_off, ptr[reg_params + GET_OFF(oc_off)]);
|
||||
|
||||
if (isa == cpu::avx512_common)
|
||||
uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
|
||||
|
||||
int blk_size = jcp_.src_dt == memory::bf16 ? 16 : (vlen / sizeof(float));
|
||||
if (isa == cpu::sse42)
|
||||
blk_size *= 2;
|
||||
|
||||
Xbyak::Label resample_nearest_loop_label;
|
||||
Xbyak::Label resample_nearest_loop_end_label;
|
||||
L(resample_nearest_loop_label);
|
||||
{
|
||||
cmp(reg_work_amount, 0);
|
||||
jle(resample_nearest_loop_end_label, T_NEAR);
|
||||
|
||||
if (jcp_.planar_layout) {
|
||||
uni_vmovdqu(vmm_index, ptr[reg_index]);
|
||||
uni_vpcmpeqd(vmm_mask, vmm_mask, vmm_mask);
|
||||
vgatherdps(vmm_val, ptr[reg_src + vmm_index * jcp.src_data_size], vmm_mask);
|
||||
store_vector(ptr[reg_dst], vmm_val, jcp_.dst_dt);
|
||||
|
||||
add(reg_dst, reg_dst_stride);
|
||||
add(reg_index, reg_index_stride);
|
||||
sub(reg_work_amount, 1);
|
||||
} else if (jcp_.nhwc_format) { // support int8 and fusion for this format
|
||||
load_vector(vmm_val, ptr[reg_src], jcp_.src_dt);
|
||||
if (attr_.post_ops_.len_ != 0)
|
||||
apply_post_ops(jcp_.dst_dt);
|
||||
store_vector(ptr[reg_dst], vmm_val, jcp_.dst_dt);
|
||||
|
||||
if (isa == cpu::sse42) {
|
||||
int sse42_offset = 4;
|
||||
load_vector(vmm_val, ptr[reg_src + sse42_offset * jcp_.src_data_size], jcp_.src_dt);
|
||||
if (attr_.post_ops_.len_ != 0) {
|
||||
add(reg_oc_off, sse42_offset * sizeof(float));
|
||||
apply_post_ops(jcp_.dst_dt);
|
||||
sub(reg_oc_off, sse42_offset * sizeof(float));
|
||||
}
|
||||
store_vector(ptr[reg_dst + sse42_offset * jcp_.dst_data_size], vmm_val, jcp_.dst_dt);
|
||||
}
|
||||
|
||||
add(reg_dst, reg_dst_stride);
|
||||
add(reg_src, reg_src_stride);
|
||||
add(reg_oc_off, blk_size * sizeof(float));
|
||||
sub(reg_work_amount, 1);
|
||||
} else { // for blk
|
||||
mov(reg_src_aux, reg_src);
|
||||
mov(reg_index_oc, dword[reg_index]);
|
||||
add(reg_src_aux, reg_index_oc);
|
||||
|
||||
load_vector(vmm_val, ptr[reg_src_aux], jcp_.src_dt);
|
||||
if (attr_.post_ops_.len_ != 0)
|
||||
apply_post_ops(jcp_.dst_dt);
|
||||
store_vector(ptr[reg_dst], vmm_val, jcp_.dst_dt);
|
||||
|
||||
if (isa == cpu::sse42) {
|
||||
int sse42_offset = 4;
|
||||
add(reg_src_aux, sse42_offset * jcp_.src_data_size);
|
||||
load_vector(vmm_val, ptr[reg_src_aux], jcp_.src_dt);
|
||||
if (attr_.post_ops_.len_ != 0) {
|
||||
add(reg_oc_off, sse42_offset * sizeof(float));
|
||||
apply_post_ops(jcp_.dst_dt);
|
||||
sub(reg_oc_off, sse42_offset * sizeof(float));
|
||||
}
|
||||
store_vector(ptr[reg_dst + sse42_offset * jcp_.dst_data_size], vmm_val, jcp_.dst_dt);
|
||||
}
|
||||
|
||||
add(reg_dst, reg_dst_stride);
|
||||
add(reg_index, reg_index_stride);
|
||||
sub(reg_work_amount, 1);
|
||||
}
|
||||
|
||||
jmp(resample_nearest_loop_label, T_NEAR);
|
||||
}
|
||||
L(resample_nearest_loop_end_label);
|
||||
|
||||
this->postamble();
|
||||
|
||||
for (auto& inj : eltwise_injectors)
|
||||
inj->prepare_table();
|
||||
|
||||
ker_ = (decltype(ker_)) this->getCode();
|
||||
}
|
||||
|
||||
private:
|
||||
using Vmm = typename conditional3<isa == cpu::sse42, Xbyak::Xmm, isa == cpu::avx2,
|
||||
Xbyak::Ymm, Xbyak::Zmm>::type;
|
||||
|
||||
const int vlen = cpu_isa_traits<isa>::vlen;
|
||||
|
||||
Xbyak::Reg64 reg_src = r8;
|
||||
Xbyak::Reg64 reg_dst = r9;
|
||||
Xbyak::Reg64 reg_src_stride = r10;
|
||||
Xbyak::Reg64 reg_dst_stride = r11;
|
||||
Xbyak::Reg64 reg_index_stride = r12;
|
||||
Xbyak::Reg64 reg_work_amount = r13;
|
||||
Xbyak::Reg64 reg_index = r14;
|
||||
Xbyak::Reg64 reg_src_aux = r15;
|
||||
Xbyak::Reg64 reg_params = abi_param1;
|
||||
|
||||
Xbyak::Reg64 reg_oc_off = rax;
|
||||
Xbyak::Reg64 reg_d_weights = rbx;
|
||||
Xbyak::Reg64 reg_d_bias = rcx;
|
||||
Xbyak::Reg32 reg_index_oc = edx;
|
||||
|
||||
Vmm vmm_val = Vmm(0);
|
||||
Vmm vmm_index = Vmm(1);
|
||||
Vmm vmm_zero = Vmm(2);
|
||||
Vmm vmm_mask = Vmm(3);
|
||||
Vmm vmm_d_weights = Vmm(4);
|
||||
Vmm vmm_d_bias = Vmm(5);
|
||||
|
||||
std::vector<std::shared_ptr<jit_uni_eltwise_injector_f32<isa>>> eltwise_injectors;
|
||||
std::vector<std::shared_ptr<jit_uni_depthwise_injector_f32<isa>>> depthwise_injectors;
|
||||
std::vector<std::shared_ptr<jit_uni_quantization_injector_f32<isa>>> quantization_injectors;
|
||||
|
||||
inline void load_vector(Vmm vmm_src, const Xbyak::Address &op, memory::data_type src_dt) {
|
||||
switch (src_dt) {
|
||||
case memory::f32:
|
||||
case memory::s32:
|
||||
uni_vmovups(vmm_src, op);
|
||||
break;
|
||||
case memory::s8:
|
||||
uni_vpmovsxbd(vmm_src, op);
|
||||
break;
|
||||
case memory::u8:
|
||||
uni_vpmovzxbd(vmm_src, op);
|
||||
break;
|
||||
case memory::bf16:
|
||||
uni_vpmovzxwd(vmm_src, op);
|
||||
uni_vpslld(vmm_src, vmm_src, 16);
|
||||
break;
|
||||
default:
|
||||
assert(!"unknown dst_dt");
|
||||
}
|
||||
|
||||
if (!isFloatCompatible(src_dt))
|
||||
uni_vcvtdq2ps(vmm_src, vmm_src);
|
||||
}
|
||||
|
||||
inline void store_vector(const Xbyak::Address &op, Vmm vmm_dst, memory::data_type dst_dt) {
|
||||
Ymm ymm_dst = Ymm(vmm_dst.getIdx());
|
||||
Xmm xmm_dst = Xmm(vmm_dst.getIdx());
|
||||
|
||||
if (dst_dt == memory::f32) {
|
||||
uni_vmovups(op, vmm_dst);
|
||||
} else if (dst_dt == memory::bf16) {
|
||||
vcvtneps2bf16(ymm_dst, vmm_dst);
|
||||
vmovdqu16(op, ymm_dst);
|
||||
} else if (dst_dt == memory::u8) {
|
||||
uni_vcvtps2dq(vmm_dst, vmm_dst);
|
||||
if (isa == cpu::avx512_common) {
|
||||
vpmaxsd(vmm_dst, vmm_dst, vmm_zero);
|
||||
vpmovusdb(op, vmm_dst);
|
||||
} else {
|
||||
uni_vpackusdw(vmm_dst, vmm_dst, vmm_dst);
|
||||
if (isa != cpu::sse42)
|
||||
vpermq(ymm_dst, ymm_dst, 0x08);
|
||||
uni_vpackuswb(vmm_dst, vmm_dst, vmm_dst);
|
||||
if (isa != cpu::sse42)
|
||||
vmovq(op, xmm_dst);
|
||||
else
|
||||
movd(op, xmm_dst);
|
||||
}
|
||||
} else if (dst_dt == memory::s8) {
|
||||
uni_vcvtps2dq(vmm_dst, vmm_dst);
|
||||
if (isa == cpu::avx512_common) {
|
||||
vpmovsdb(op, vmm_dst);
|
||||
} else {
|
||||
uni_vpackssdw(vmm_dst, vmm_dst, vmm_dst);
|
||||
if (isa != cpu::sse42)
|
||||
vpermq(ymm_dst, ymm_dst, 0x08);
|
||||
uni_vpacksswb(vmm_dst, vmm_dst, vmm_dst);
|
||||
if (isa != cpu::sse42)
|
||||
vmovq(op, xmm_dst);
|
||||
else
|
||||
movd(op, xmm_dst);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void apply_post_ops(memory::data_type dst_dt) {
|
||||
const auto &p = attr_.post_ops_;
|
||||
int eltwise_inj_idx = 0;
|
||||
int depthwise_inj_idx = 0;
|
||||
int quantization_inj_idx = 0;
|
||||
for (int i = 0; i < p.len_; i++) {
|
||||
auto& post_op = p.entry_[i];
|
||||
if (post_op.is_eltwise()) {
|
||||
eltwise_injectors[eltwise_inj_idx]->compute_vector_range(vmm_val.getIdx(), vmm_val.getIdx() + 1);
|
||||
eltwise_inj_idx++;
|
||||
} else if (post_op.is_depthwise()) {
|
||||
mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
|
||||
mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
|
||||
add(reg_d_weights, reg_oc_off);
|
||||
add(reg_d_bias, reg_oc_off);
|
||||
depthwise_injectors[depthwise_inj_idx]->compute_vector_range(vmm_val.getIdx(), vmm_val.getIdx() + 1, reg_d_weights, reg_d_bias);
|
||||
depthwise_inj_idx++;
|
||||
} else if (post_op.is_quantization()) {
|
||||
bool do_dequantization = post_op.quantization.alg == alg_kind::quantization_quantize_dequantize;
|
||||
bool do_rounding = do_dequantization || isFloatCompatible(dst_dt) || i != p.len_ - 1;
|
||||
int s_idx = vmm_val.getIdx();
|
||||
|
||||
quantization_injectors[quantization_inj_idx]->init_crop_ptrs(reg_oc_off);
|
||||
quantization_injectors[quantization_inj_idx]->compute_crop(s_idx, s_idx + 1, 0);
|
||||
|
||||
quantization_injectors[quantization_inj_idx]->init_input_scale_shift_ptrs(reg_oc_off);
|
||||
quantization_injectors[quantization_inj_idx]->compute_input_scale_shift(s_idx, s_idx + 1, 0, do_rounding);
|
||||
|
||||
quantization_injectors[quantization_inj_idx]->init_output_scale_shift_ptrs(reg_oc_off);
|
||||
quantization_injectors[quantization_inj_idx]->compute_output_scale_shift(s_idx, s_idx + 1, 0);
|
||||
|
||||
quantization_inj_idx++;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
MKLDNNResampleNode::MKLDNNResampleNode(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng, MKLDNNWeightsSharing::Ptr &cache)
|
||||
: MKLDNNNode(layer, eng, cache) {}
|
||||
|
||||
void MKLDNNResampleNode::getSupportedDescriptors() {
|
||||
if (!descs.empty())
|
||||
return;
|
||||
|
||||
if (getParentEdges().size() != 1)
|
||||
THROW_IE_EXCEPTION << "Incorrect number of input edges for layer " << getName();
|
||||
if (getChildEdges().empty())
|
||||
THROW_IE_EXCEPTION << "Incorrect number of output edges for layer " << getName();
|
||||
|
||||
auto *layer = getCnnLayer().get();
|
||||
type = layer->GetParamAsString("type");
|
||||
antialias = layer->GetParamAsBool("antialias", false);
|
||||
factor = layer->GetParamAsFloat("factor");
|
||||
}
|
||||
|
||||
void MKLDNNResampleNode::initSupportedPrimitiveDescriptors() {
|
||||
if (!supportedPrimitiveDescriptors.empty())
|
||||
return;
|
||||
|
||||
if (getParentEdgeAt(0)->getDims().ndims() < 4 || getParentEdgeAt(0)->getDims().ndims() > 5) {
|
||||
return;
|
||||
}
|
||||
|
||||
setPostOps(attr, true);
|
||||
|
||||
Precision inputPrecision = getCnnLayer()->insData[0].lock()->getPrecision();
|
||||
Precision outputPrecision = getCnnLayer()->outData[0]->getPrecision();
|
||||
|
||||
if (!fusedWith.empty()) {
|
||||
auto lastFusedLayer = fusedWith[fusedWith.size() - 1].get()->getCnnLayer();
|
||||
if (lastFusedLayer) {
|
||||
outputPrecision = lastFusedLayer->outData[0]->getPrecision();
|
||||
}
|
||||
}
|
||||
|
||||
if (inputPrecision == Precision::BF16 || outputPrecision == Precision::BF16) {
|
||||
if (!mayiuse(avx512_core_bf16))
|
||||
inputPrecision = outputPrecision = Precision::FP32;
|
||||
else
|
||||
inputPrecision = outputPrecision = Precision::BF16;
|
||||
}
|
||||
|
||||
auto inputDataType = MKLDNNExtensionUtils::IEPrecisionToDataType(inputPrecision);
|
||||
auto outputDataType = MKLDNNExtensionUtils::IEPrecisionToDataType(outputPrecision);
|
||||
|
||||
input_prec = inputPrecision;
|
||||
output_prec = outputPrecision;
|
||||
src_data_size = MKLDNNExtensionUtils::sizeOfDataType(inputDataType);
|
||||
dst_data_size = MKLDNNExtensionUtils::sizeOfDataType(outputDataType);
|
||||
|
||||
InferenceEngine::LayerConfig config;
|
||||
config.dynBatchSupport = false;
|
||||
config.inConfs.resize(1);
|
||||
config.outConfs.resize(1);
|
||||
config.inConfs[0].constant = false;
|
||||
config.outConfs[0].constant = false;
|
||||
config.inConfs[0].inPlace = -1;
|
||||
config.outConfs[0].inPlace = -1;
|
||||
|
||||
auto pushDesc = [&](memory::format format) {
|
||||
config.inConfs[0].desc = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType, format);
|
||||
config.outConfs[0].desc = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, format);
|
||||
supportedPrimitiveDescriptors.push_back({config, impl_desc_type::unknown, format});
|
||||
};
|
||||
|
||||
if (type == "caffe.ResampleParameter.NEAREST") {
|
||||
if (getParentEdgeAt(0)->getDims().ndims() == 4) {
|
||||
pushDesc(memory::nhwc);
|
||||
} else if (getParentEdgeAt(0)->getDims().ndims() == 5) {
|
||||
pushDesc(memory::ndhwc);
|
||||
}
|
||||
|
||||
if (isFloatCompatible(inputPrecision) && isFloatCompatible(outputPrecision)) {
|
||||
if (getParentEdgeAt(0)->getDims().ndims() == 4) {
|
||||
if (mayiuse(cpu::avx512_common)) {
|
||||
pushDesc(memory::nChw16c);
|
||||
} else if (mayiuse(cpu::avx2) || mayiuse(cpu::sse42)) {
|
||||
pushDesc(memory::nChw8c);
|
||||
}
|
||||
} else if (getParentEdgeAt(0)->getDims().ndims() == 5) {
|
||||
if (mayiuse(cpu::avx512_common)) {
|
||||
pushDesc(memory::nCdhw16c);
|
||||
} else if (mayiuse(cpu::avx2) || mayiuse(cpu::sse42)) {
|
||||
pushDesc(memory::nCdhw8c);
|
||||
}
|
||||
}
|
||||
|
||||
if (fusedWith.empty()) {
|
||||
pushDesc(MKLDNNMemory::GetPlainFormat(getChildEdgeAt(0)->getDims()));
|
||||
}
|
||||
}
|
||||
}
|
||||
if (type == "caffe.ResampleParameter.LINEAR") {
|
||||
if (getParentEdgeAt(0)->getDims().ndims() == 4) {
|
||||
pushDesc(memory::nchw);
|
||||
} else if (getParentEdgeAt(0)->getDims().ndims() == 5) {
|
||||
pushDesc(memory::ncdhw);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void MKLDNNResampleNode::createPrimitive() {
|
||||
auto& dstMemPtr = getChildEdgeAt(0)->getMemoryPtr();
|
||||
auto& srcMemPtr = getParentEdgeAt(0)->getMemoryPtr();
|
||||
if (!dstMemPtr || !dstMemPtr->GetPrimitivePtr())
|
||||
THROW_IE_EXCEPTION << "Destination memory didn't allocate.";
|
||||
if (!srcMemPtr || !srcMemPtr->GetPrimitivePtr())
|
||||
THROW_IE_EXCEPTION << "Input memory didn't allocate.";
|
||||
if (getSelectedPrimitiveDescriptor() == nullptr)
|
||||
THROW_IE_EXCEPTION << "Preferable primitive descriptor is not set.";
|
||||
|
||||
auto selectedPD = getSelectedPrimitiveDescriptor();
|
||||
Layout selected_layout = selectedPD->getConfig().inConfs[0].desc.getLayout();
|
||||
auto jcp = jit_resample_config_params();
|
||||
jcp.src_dt = MKLDNNExtensionUtils::IEPrecisionToDataType(selectedPD->getConfig().inConfs[0].desc.getPrecision());
|
||||
jcp.dst_dt = MKLDNNExtensionUtils::IEPrecisionToDataType(selectedPD->getConfig().outConfs[0].desc.getPrecision());
|
||||
jcp.src_data_size = MKLDNNExtensionUtils::sizeOfDataType(jcp.src_dt);
|
||||
jcp.dst_data_size = MKLDNNExtensionUtils::sizeOfDataType(jcp.dst_dt);
|
||||
jcp.planar_layout = MKLDNNMemory::GetPlainLayout(getChildEdgeAt(0)->getDims()) == selected_layout;
|
||||
jcp.nhwc_format = (selected_layout == NHWC) || (selected_layout == NDHWC);
|
||||
|
||||
if (type == "caffe.ResampleParameter.NEAREST") {
|
||||
if (mayiuse(cpu::avx512_common)) {
|
||||
if (jcp.planar_layout) {
|
||||
resample_nearest_kernel.reset(new jit_uni_resample_nearest_kernel_f32<cpu::avx2>(jcp, *attr.get()));
|
||||
blk_size = 8;
|
||||
} else {
|
||||
resample_nearest_kernel.reset(new jit_uni_resample_nearest_kernel_f32<cpu::avx512_common>(jcp, *attr.get()));
|
||||
blk_size = 16;
|
||||
}
|
||||
} else if (mayiuse(cpu::avx2)) {
|
||||
resample_nearest_kernel.reset(new jit_uni_resample_nearest_kernel_f32<cpu::avx2>(jcp, *attr.get()));
|
||||
blk_size = 8;
|
||||
} else if (mayiuse(cpu::sse42) && !jcp.planar_layout) {
|
||||
resample_nearest_kernel.reset(new jit_uni_resample_nearest_kernel_f32<cpu::sse42>(jcp, *attr.get()));
|
||||
blk_size = 8;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void MKLDNNResampleNode::setPostOps(mkldnn::primitive_attr &attr, bool initWeights) {
|
||||
int blob_idx = 0;
|
||||
mkldnn::post_ops ops;
|
||||
|
||||
for (auto &node : fusedWith) {
|
||||
auto* quantizeNode = dynamic_cast<MKLDNNQuantizeNode *>(node.get());
|
||||
if (quantizeNode) {
|
||||
quantizeNode->appendPostOps(ops);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto* eltwiseNode = dynamic_cast<MKLDNNEltwiseNode *>(node.get());
|
||||
if (eltwiseNode) {
|
||||
eltwiseNode->appendPostOps(ops);
|
||||
continue;
|
||||
}
|
||||
|
||||
THROW_IE_EXCEPTION << "Fusing of " << NameFromType(node->getType()) << " operation to " << NameFromType(this->getType()) << " node is not implemented";
|
||||
}
|
||||
|
||||
attr.set_post_ops(ops);
|
||||
}
|
||||
|
||||
|
||||
void MKLDNNResampleNode::execute(mkldnn::stream strm) {
|
||||
auto &dstMemPtr = getChildEdgeAt(0)->getMemoryPtr();
|
||||
auto &srcMemPtr = getParentEdgeAt(0)->getMemoryPtr();
|
||||
|
||||
Layout layout = getParentEdgeAt(0)->getDesc().getLayout();
|
||||
|
||||
SizeVector src_dim = getParentEdgeAt(0)->getDesc().getDims();
|
||||
SizeVector dst_dim = getChildEdgeAt(0)->getDesc().getDims();
|
||||
|
||||
size_t dims_size = src_dim.size();
|
||||
size_t N = src_dim[0];
|
||||
size_t C = src_dim[1];
|
||||
size_t ID = (dims_size == 5) ? src_dim[dims_size - 3] : 1lu;
|
||||
size_t IH = src_dim[dims_size - 2];
|
||||
size_t IW = src_dim[dims_size - 1];
|
||||
|
||||
size_t OD = (dims_size == 5) ? dst_dim[dims_size - 3] : 1lu;
|
||||
size_t OH = dst_dim[dims_size - 2];
|
||||
size_t OW = dst_dim[dims_size - 1];
|
||||
|
||||
float fx = static_cast<float>(IW) / static_cast<float>(OW);
|
||||
float fy = static_cast<float>(IH) / static_cast<float>(OH);
|
||||
float fz = static_cast<float>(ID) / static_cast<float>(OD);
|
||||
|
||||
if (type == "caffe.ResampleParameter.NEAREST") {
|
||||
if (layout == NCHW || layout == NCDHW) {
|
||||
if (output_prec == Precision::FP32) {
|
||||
auto src_data = reinterpret_cast<const float*>(srcMemPtr->GetData());
|
||||
auto dst_data = reinterpret_cast<float*>(dstMemPtr->GetData());
|
||||
NearestNeighbor_PLN<float, float>(src_data, dst_data, N, C, ID, IH, IW, fx, fy, fz, OD, OH, OW);
|
||||
} else if (output_prec == Precision::BF16) {
|
||||
auto src_data = reinterpret_cast<const bfloat16_t*>(srcMemPtr->GetData());
|
||||
auto dst_data = reinterpret_cast<bfloat16_t*>(dstMemPtr->GetData());
|
||||
NearestNeighbor_PLN<bfloat16_t, bfloat16_t>(src_data, dst_data, N, C, ID, IH, IW, fx, fy, fz, OD, OH, OW);
|
||||
} else {
|
||||
THROW_IE_EXCEPTION << "Unsupported output precision: " << output_prec.name();
|
||||
}
|
||||
} else {
|
||||
if (output_prec == Precision::U8) {
|
||||
auto dst_data = reinterpret_cast<uint8_t *>(dstMemPtr->GetData());
|
||||
if (input_prec == Precision::U8) {
|
||||
auto src_data = reinterpret_cast<const uint8_t *>(srcMemPtr->GetData());
|
||||
NearestNeighbor_BLK<uint8_t, uint8_t>(src_data, dst_data, N, C, ID, IH, IW, fx, fy, fz, OD, OH, OW);
|
||||
} else if (input_prec == Precision::I8) {
|
||||
auto src_data = reinterpret_cast<const int8_t *>(srcMemPtr->GetData());
|
||||
NearestNeighbor_BLK<int8_t, uint8_t>(src_data, dst_data, N, C, ID, IH, IW, fx, fy, fz, OD, OH, OW);
|
||||
} else if (input_prec == Precision::FP32) {
|
||||
auto src_data = reinterpret_cast<const float *>(srcMemPtr->GetData());
|
||||
NearestNeighbor_BLK<float, uint8_t>(src_data, dst_data, N, C, ID, IH, IW, fx, fy, fz, OD, OH, OW);
|
||||
} else {
|
||||
THROW_IE_EXCEPTION << "Unsupported output precision: " << output_prec.name();
|
||||
}
|
||||
} else if (output_prec == Precision::I8) {
|
||||
auto dst_data = reinterpret_cast<int8_t *>(dstMemPtr->GetData());
|
||||
if (input_prec == Precision::U8) {
|
||||
auto src_data = reinterpret_cast<const uint8_t *>(srcMemPtr->GetData());
|
||||
NearestNeighbor_BLK<uint8_t, int8_t>(src_data, dst_data, N, C, ID, IH, IW, fx, fy, fz, OD, OH, OW);
|
||||
} else if (input_prec == Precision::I8) {
|
||||
auto src_data = reinterpret_cast<const int8_t *>(srcMemPtr->GetData());
|
||||
NearestNeighbor_BLK<int8_t, int8_t>(src_data, dst_data, N, C, ID, IH, IW, fx, fy, fz, OD, OH, OW);
|
||||
} else if (input_prec == Precision::FP32) {
|
||||
auto src_data = reinterpret_cast<const float *>(srcMemPtr->GetData());
|
||||
NearestNeighbor_BLK<float, int8_t>(src_data, dst_data, N, C, ID, IH, IW, fx, fy, fz, OD, OH, OW);
|
||||
} else {
|
||||
THROW_IE_EXCEPTION << "Unsupported output precision: " << output_prec.name();
|
||||
}
|
||||
} else if (output_prec == Precision::FP32) {
|
||||
auto dst_data = reinterpret_cast<float *>(dstMemPtr->GetData());
|
||||
if (input_prec == Precision::U8) {
|
||||
auto src_data = reinterpret_cast<const uint8_t *>(srcMemPtr->GetData());
|
||||
NearestNeighbor_BLK<uint8_t, float>(src_data, dst_data, N, C, ID, IH, IW, fx, fy, fz, OD, OH, OW);
|
||||
} else if (input_prec == Precision::I8) {
|
||||
auto src_data = reinterpret_cast<const int8_t *>(srcMemPtr->GetData());
|
||||
NearestNeighbor_BLK<int8_t, float>(src_data, dst_data, N, C, ID, IH, IW, fx, fy, fz, OD, OH, OW);
|
||||
} else if (input_prec == Precision::FP32) {
|
||||
auto src_data = reinterpret_cast<float *>(srcMemPtr->GetData());
|
||||
NearestNeighbor_BLK<float, float>(src_data, dst_data, N, C, ID, IH, IW, fx, fy, fz, OD, OH, OW);
|
||||
} else {
|
||||
THROW_IE_EXCEPTION << "Unsupported output precision: " << output_prec.name();
|
||||
}
|
||||
} else if (output_prec == Precision::BF16) {
|
||||
auto src_data = reinterpret_cast<const bfloat16_t*>(srcMemPtr->GetData());
|
||||
auto dst_data = reinterpret_cast<bfloat16_t*>(dstMemPtr->GetData());
|
||||
NearestNeighbor_BLK<bfloat16_t, bfloat16_t>(src_data, dst_data, N, C, ID, IH, IW, fx, fy, fz, OD, OH, OW);
|
||||
} else {
|
||||
THROW_IE_EXCEPTION << "Unsupported output precision: " << output_prec.name();
|
||||
}
|
||||
}
|
||||
} else if (type == "caffe.ResampleParameter.LINEAR") {
|
||||
// currently no fusion, the input and output precision is the same
|
||||
bool isDownsample = (fx > 1) || (fy > 1) || (fz > 1);
|
||||
int kernel_width = 2;
|
||||
if (input_prec == Precision::U8) {
|
||||
auto src_data = reinterpret_cast<const uint8_t *>(srcMemPtr->GetData());
|
||||
auto dst_data = reinterpret_cast<uint8_t *>(dstMemPtr->GetData());
|
||||
LinearInterpolation<uint8_t, uint8_t>(src_data, dst_data, N, C, ID, IH, IW, fx, fy, fz, OD, OH, OW, kernel_width, isDownsample && antialias);
|
||||
} else if (input_prec == Precision::I8) {
|
||||
auto src_data = reinterpret_cast<const int8_t *>(srcMemPtr->GetData());
|
||||
auto dst_data = reinterpret_cast<int8_t *>(dstMemPtr->GetData());
|
||||
LinearInterpolation<int8_t, int8_t>(src_data, dst_data, N, C, ID, IH, IW, fx, fy, fz, OD, OH, OW, kernel_width, isDownsample && antialias);
|
||||
} else if (input_prec == Precision::FP32) {
|
||||
auto src_data = reinterpret_cast<const float *>(srcMemPtr->GetData());
|
||||
auto dst_data = reinterpret_cast<float *>(dstMemPtr->GetData());
|
||||
LinearInterpolation<float, float>(src_data, dst_data, N, C, ID, IH, IW, fx, fy, fz, OD, OH, OW, kernel_width, isDownsample && antialias);
|
||||
} else if (input_prec == Precision::BF16) {
|
||||
auto src_data = reinterpret_cast<const bfloat16_t*>(srcMemPtr->GetData());
|
||||
auto dst_data = reinterpret_cast<bfloat16_t*>(dstMemPtr->GetData());
|
||||
LinearInterpolation<bfloat16_t, bfloat16_t>(src_data, dst_data, N, C, ID, IH, IW, fx, fy, fz, OD, OH, OW, kernel_width,
|
||||
isDownsample && antialias);
|
||||
} else {
|
||||
THROW_IE_EXCEPTION << "Unsupported input precision: " << input_prec.name();
|
||||
}
|
||||
} else {
|
||||
THROW_IE_EXCEPTION << "Unsupported resample parameter type: " << type;
|
||||
}
|
||||
}
|
||||
|
||||
// f32 and no fused, f32->input is f32, no fuse->output is f32
|
||||
template <typename in_data_t, typename out_data_t>
|
||||
void MKLDNNResampleNode::NearestNeighbor_PLN(const in_data_t *in_ptr_, out_data_t *out_ptr_, int B, int C, int ID, int IH, int IW,
|
||||
float fx, float fy, float fz, int OD, int OH, int OW) {
|
||||
std::vector<int> index_buffer(OD * OH * OW);
|
||||
for (int oz = 0; oz < OD; oz++) {
|
||||
float iz = oz * fz;
|
||||
int iz_offset = static_cast<int>(std::floor(iz)) * IH * IW;
|
||||
int oz_offset = oz * OH * OW;
|
||||
for (int oy = 0; oy < OH; oy++) {
|
||||
float iy = oy * fy;
|
||||
int iy_offset = static_cast<int>(std::floor(iy)) * IW + iz_offset;
|
||||
int oy_offset = oy * OW + oz_offset;
|
||||
for (int ox = 0; ox < OW; ox++) {
|
||||
float ix = ox * fx;
|
||||
int ix_index = static_cast<int>(std::floor(ix)) + iy_offset;
|
||||
index_buffer[oy_offset + ox] = ix_index;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (resample_nearest_kernel) {
|
||||
parallel_for2d(B, C, [&](size_t b, size_t c) {
|
||||
const in_data_t *in_ptr = in_ptr_ + IW * IH * ID * C * b + IW * IH * ID * c;
|
||||
out_data_t *out_ptr = out_ptr_ + OW * OH * OD * C * b + OW * OH * OD * c;
|
||||
|
||||
// for OW*OH*OD
|
||||
auto arg = jit_resample_call_args();
|
||||
arg.src = in_ptr;
|
||||
arg.dst = out_ptr;
|
||||
arg.index = static_cast<int*>(&index_buffer[0]);
|
||||
arg.index_stride = blk_size * sizeof(int);
|
||||
arg.dst_stride = blk_size * dst_data_size;
|
||||
arg.work_amount = OW * OH * OD / blk_size;
|
||||
(*resample_nearest_kernel)(&arg);
|
||||
|
||||
int tail_start = (OW * OH * OD / blk_size) * blk_size;
|
||||
for (int tail = tail_start; tail < OW * OH * OD; tail++) {
|
||||
out_ptr[tail] = in_ptr[index_buffer[tail]];
|
||||
}
|
||||
});
|
||||
} else {
|
||||
parallel_for2d(B, C, [&](size_t b, size_t c) {
|
||||
const in_data_t *in_ptr = in_ptr_ + IW * IH * ID * C * b + IW * IH * ID * c;
|
||||
out_data_t *out_ptr = out_ptr_ + OW * OH * OD * C * b + OW * OH * OD * c;
|
||||
|
||||
for (int i_dst = 0; i_dst < OW * OH * OD; i_dst++) {
|
||||
out_ptr[i_dst] = in_ptr[index_buffer[i_dst]];
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// for ndhwc and nCdhw8/16d
|
||||
// int8->input may be int8, fused->output may be int8
|
||||
template <typename in_data_t, typename out_data_t>
|
||||
void MKLDNNResampleNode::NearestNeighbor_BLK(const in_data_t *in_ptr_, out_data_t *out_ptr_, int B, int C, int ID, int IH, int IW,
|
||||
float fx, float fy, float fz, int OD, int OH, int OW) {
|
||||
std::vector<int> index_d(OD);
|
||||
std::vector<int> index_h(OH);
|
||||
std::vector<int> index_w(OW);
|
||||
for (int oz = 0; oz < OD; oz++) {
|
||||
float iz = oz * fz;
|
||||
index_d[oz] = static_cast<int>(std::floor(iz));
|
||||
}
|
||||
for (int oy = 0; oy < OH; oy++) {
|
||||
float iy = oy * fy;
|
||||
index_h[oy] = static_cast<int>(std::floor(iy));
|
||||
}
|
||||
for (int ox = 0; ox < OW; ox++) {
|
||||
float ix = ox * fx;
|
||||
index_w[ox] = static_cast<int>(std::floor(ix));
|
||||
}
|
||||
|
||||
Layout layout = getParentEdgeAt(0)->getDesc().getLayout();
|
||||
bool is_nhwc = (layout == NHWC || layout == NDHWC) ? true : false;
|
||||
|
||||
for (int b = 0; b < B; b++) {
|
||||
if (is_nhwc) {
|
||||
const in_data_t *in_ptr = in_ptr_ + IW * IH * ID * C * b;
|
||||
out_data_t *out_ptr = out_ptr_ + OW * OH * OD * C * b;
|
||||
if (resample_nearest_kernel) {
|
||||
int tail = (C / blk_size) * blk_size;
|
||||
parallel_for2d(OD, OH, [&](size_t d, size_t h) {
|
||||
// better that same core process continuous memory
|
||||
out_data_t *out_ptr_dh = out_ptr + C * OW * OH * d + C * OW * h;
|
||||
const in_data_t *in_ptr_dh = in_ptr + C * IW * IH * index_d[d] + C * IW * index_h[h];
|
||||
auto arg = jit_resample_call_args();
|
||||
for (int ox = 0; ox < OW; ox++) {
|
||||
// kernel for OC
|
||||
arg.dst = out_ptr_dh + C * ox;
|
||||
arg.src = in_ptr_dh + C * index_w[ox];
|
||||
arg.dst_stride = blk_size * sizeof(out_data_t);
|
||||
arg.src_stride = blk_size * sizeof(in_data_t);
|
||||
arg.work_amount = C / blk_size;
|
||||
arg.oc_off = 0;
|
||||
(*resample_nearest_kernel)(&arg);
|
||||
}
|
||||
// tail
|
||||
if (tail != C) {
|
||||
for (int ox = 0; ox < OW; ox++) {
|
||||
out_data_t *out_ptr_dhw = out_ptr_dh + C * ox;
|
||||
const in_data_t *in_ptr_dhw = in_ptr_dh + C * index_w[ox];
|
||||
if (fusedWith.empty() && output_prec == input_prec) {
|
||||
cpu_memcpy(out_ptr_dhw + tail, in_ptr_dhw + tail, (C - tail) * sizeof(in_data_t));
|
||||
} else {
|
||||
for (int c = tail; c < C; c++) {
|
||||
float dst_value = static_cast<float>(in_ptr_dhw[c]);
|
||||
apply_post_ops_scalar(dst_value, c);
|
||||
if (isFloatCompatible(output_prec)) {
|
||||
out_ptr_dhw[c] = dst_value;
|
||||
} else if (output_prec == Precision::U8) {
|
||||
out_ptr_dhw[c] = (dst_value >= 0) ? lroundf(dst_value) : 0;
|
||||
} else if (output_prec == Precision::I8) {
|
||||
out_ptr_dhw[c] = lroundf(dst_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
} else { // without kernel
|
||||
parallel_for2d(OD, OH, [&](size_t d, size_t h) {
|
||||
out_data_t *out_ptr_dh = out_ptr + C * OW * OH * d + C * OW * h;
|
||||
const in_data_t *in_ptr_dh = in_ptr + C * IW * IH * index_d[d] + C * IW * index_h[h];
|
||||
for (int ox = 0; ox < OW; ox++) {
|
||||
out_data_t *out_ptr_dhw = out_ptr_dh + C * ox;
|
||||
const in_data_t *in_ptr_dhw = in_ptr_dh + C * index_w[ox];
|
||||
if (fusedWith.empty() && output_prec == input_prec) {
|
||||
cpu_memcpy(out_ptr_dhw, in_ptr_dhw, C * sizeof(in_data_t));
|
||||
} else {
|
||||
for (int c = 0; c < C; c++) {
|
||||
float dst_value = static_cast<float>(in_ptr_dhw[c]);
|
||||
apply_post_ops_scalar(dst_value, c);
|
||||
if (isFloatCompatible(output_prec)) {
|
||||
out_ptr_dhw[c] = dst_value;
|
||||
} else if (output_prec == Precision::U8) {
|
||||
out_ptr_dhw[c] = (dst_value >= 0) ? lroundf(dst_value) : 0;
|
||||
} else if (output_prec == Precision::I8) {
|
||||
out_ptr_dhw[c] = lroundf(dst_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
} else { // for nC(d)hw8/16c
|
||||
int CB = div_up(C, blk_size);
|
||||
const in_data_t *in_ptr = in_ptr_ + IW * IH * ID * CB * blk_size * b;
|
||||
out_data_t *out_ptr = out_ptr_ + OW * OH * OD * CB * blk_size * b;
|
||||
if (resample_nearest_kernel) {
|
||||
std::vector<int> index_w_kernel(OW);
|
||||
for (int ox = 0; ox < OW; ox++) {
|
||||
index_w_kernel[ox] = index_w[ox] * blk_size * sizeof(in_data_t);
|
||||
}
|
||||
parallel_for2d(CB, OD, [&](size_t cb, size_t d) {
|
||||
out_data_t *out_ptr_cbd = out_ptr + blk_size * OW * OH * OD * cb + blk_size * OW * OH * d;
|
||||
const in_data_t *in_ptr_cbd = in_ptr + blk_size * IW * IH * ID * cb + blk_size * IW * IH * index_d[d];
|
||||
auto arg = jit_resample_call_args();
|
||||
for (int h = 0; h < OH; h++) { // kernel for blk_size * OW
|
||||
arg.dst = out_ptr_cbd + blk_size * OW * h;
|
||||
arg.src = in_ptr_cbd + blk_size * IW * index_h[h];
|
||||
arg.index = static_cast<int*>(&(index_w_kernel[0]));
|
||||
arg.dst_stride = static_cast<size_t>(blk_size * sizeof(out_data_t));
|
||||
arg.index_stride = static_cast<size_t>(1 * sizeof(int));
|
||||
arg.work_amount = static_cast<size_t>(OW);
|
||||
arg.oc_off = cb * blk_size;
|
||||
(*resample_nearest_kernel)(&arg);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
parallel_for2d(CB, OD, [&](int cb, int d) {
|
||||
out_data_t *out_ptr_cbd = out_ptr + blk_size * OW * OH * OD * cb + blk_size * OW * OH * d;
|
||||
const in_data_t *in_ptr_cbd = in_ptr + blk_size * IW * IH * ID * cb + blk_size * IW * IH * index_d[d];
|
||||
for (int h = 0; h < OH; h++) {
|
||||
out_data_t *out_ptr_cbdh = out_ptr_cbd + blk_size * OW * h;
|
||||
const in_data_t *in_ptr_cbdh = in_ptr_cbd + blk_size * IW * index_h[h];
|
||||
for (int w = 0; w < OW; w++) {
|
||||
out_data_t *out_ptr_cbdhw = out_ptr_cbdh + blk_size * w;
|
||||
const in_data_t *in_ptr_cbdhw = in_ptr_cbdh + blk_size * index_w[w];
|
||||
if (fusedWith.empty()) {
|
||||
cpu_memcpy(out_ptr_cbdhw, in_ptr_cbdhw, blk_size * sizeof(in_data_t));
|
||||
} else {
|
||||
for (int blk = 0; blk < blk_size; blk++) {
|
||||
float dst_value = static_cast<float>(in_ptr_cbdhw[blk]);
|
||||
apply_post_ops_scalar(dst_value, cb * blk_size + blk);
|
||||
if (isFloatCompatible(output_prec)) {
|
||||
out_ptr_cbdhw[blk] = dst_value;
|
||||
} else if (output_prec == Precision::U8) {
|
||||
out_ptr_cbdhw[blk] = (dst_value >= 0) ? lroundf(dst_value) : 0;
|
||||
} else if (output_prec == Precision::I8) {
|
||||
out_ptr_cbdhw[blk] = lroundf(dst_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
} // batch end
|
||||
}
|
||||
|
||||
static inline float triangleCoeff(float x) {
|
||||
return (std::max)(0.0f, 1 - std::abs(x));
|
||||
}
|
||||
|
||||
template <typename in_data_t, typename out_data_t>
|
||||
void MKLDNNResampleNode::LinearInterpolation(const in_data_t *in_ptr_, out_data_t *out_ptr_, int B, int C, int ID, int IH, int IW,
|
||||
float fx, float fy, float fz, int OD, int OH, int OW, int kernel_width, bool antialias) {
|
||||
if (IW == OW && IH == OH && ID == OD) {
|
||||
size_t size = B * C * ID * IH * IW;
|
||||
if (isFloatCompatible(input_prec)) {
|
||||
size *= sizeof(in_data_t);
|
||||
}
|
||||
cpu_memcpy(out_ptr_, in_ptr_, size);
|
||||
return;
|
||||
}
|
||||
|
||||
for (size_t b = 0; b < B; b++) {
|
||||
const in_data_t *in_ptr_n = in_ptr_ + IW * IH * ID * C * b;
|
||||
out_data_t *out_ptr_n = out_ptr_ + OW * OH * OD * C * b;
|
||||
for (size_t c = 0; c < C; c++) {
|
||||
const in_data_t *in_ptr_nc = in_ptr_n + IW * IH * ID * c;
|
||||
out_data_t *out_ptr_nc = out_ptr_n + OW * OH * OD * c;
|
||||
|
||||
for (size_t oz = 0; oz < OD; oz++) {
|
||||
out_data_t *out_ptr_ncd = out_ptr_nc + OW * OH * oz;
|
||||
for (size_t oy = 0; oy < OH; oy++) {
|
||||
out_data_t *out_ptr_ncdh = out_ptr_ncd + OW * oy;
|
||||
for (size_t ox = 0; ox < OW; ox++) {
|
||||
float ix = ox * fx + fx / 2.0f - 0.5f;
|
||||
float iy = oy * fy + fy / 2.0f - 0.5f;
|
||||
float iz = oz * fz + fz / 2.0f - 0.5f;
|
||||
|
||||
int ix_r = static_cast<int>(round(ix));
|
||||
int iy_r = static_cast<int>(round(iy));
|
||||
int iz_r = static_cast<int>(round(iz));
|
||||
|
||||
float sum = 0;
|
||||
float wsum = 0;
|
||||
|
||||
float ax = 1.0f / (antialias ? fx : 1.0f);
|
||||
float ay = 1.0f / (antialias ? fy : 1.0f);
|
||||
float az = 1.0f / (antialias ? fz : 1.0f);
|
||||
|
||||
int rx = (fx < 1.0f) ? 2 : static_cast<int>(ceil(static_cast<float>(kernel_width) / ax));
|
||||
int ry = (fy < 1.0f) ? 2 : static_cast<int>(ceil(static_cast<float>(kernel_width) / ay));
|
||||
int rz = (fz < 1.0f) ? 2 : static_cast<int>(ceil(static_cast<float>(kernel_width) / az));
|
||||
|
||||
for (int z = iz_r - rz; z <= iz_r + rz; z++) {
|
||||
for (int y = iy_r - ry; y <= iy_r + ry; y++) {
|
||||
for (int x = ix_r - rx; x <= ix_r + rx; x++) {
|
||||
bool is_continue = z < 0 ||
|
||||
y < 0 ||
|
||||
x < 0 ||
|
||||
z >= static_cast<int>(ID) ||
|
||||
y >= static_cast<int>(IH) ||
|
||||
x >= static_cast<int>(IW);
|
||||
if (is_continue)
|
||||
continue;
|
||||
|
||||
float dx = ix - x;
|
||||
float dy = iy - y;
|
||||
float dz = iz - z;
|
||||
|
||||
float w = ax * triangleCoeff(ax * dx) *
|
||||
ay * triangleCoeff(ay * dy) *
|
||||
az * triangleCoeff(az * dz);
|
||||
|
||||
sum += w * static_cast<float>(in_ptr_nc[z * IH * IW + y * IW + x]);
|
||||
wsum += w;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!wsum) {
|
||||
out_ptr_ncdh[ox] = 0;
|
||||
} else {
|
||||
float dst_value = sum / wsum;
|
||||
if (isFloatCompatible(output_prec)) {
|
||||
out_ptr_ncdh[ox] = dst_value;
|
||||
} else if (output_prec == Precision::U8) {
|
||||
out_ptr_ncdh[ox] = (dst_value >= 0) ? lroundf(dst_value) : 0;
|
||||
} else if (output_prec == Precision::I8) {
|
||||
out_ptr_ncdh[ox] = lroundf(dst_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline void MKLDNNResampleNode::apply_post_ops_scalar(float &dst_value, int index_c) {
|
||||
const auto &p = (*attr.get()).post_ops_;
|
||||
for (int i = 0; i < p.len_; i++) {
|
||||
auto &post_op = p.entry_[i];
|
||||
if (post_op.is_eltwise()) {
|
||||
// only eltwise_relu supported
|
||||
if (dst_value < 0) dst_value = 0;
|
||||
} else if (post_op.is_depthwise()) {
|
||||
// only ScaleShift supported
|
||||
float scale = post_op.depthwise.weights_data[index_c];
|
||||
float shift = post_op.depthwise.biases_data[index_c];
|
||||
dst_value = dst_value * scale + shift;
|
||||
} else if (post_op.is_quantization()) {
|
||||
bool do_dequantization = post_op.quantization.alg ==
|
||||
alg_kind::quantization_quantize_dequantize;
|
||||
bool do_rounding = do_dequantization || isFloatCompatible(output_prec) ||
|
||||
i != p.len_ - 1;
|
||||
|
||||
auto quant = post_op.quantization;
|
||||
|
||||
float crop_low = quant.crop_low_data->shifts_[quant.crop_low_data->count_ == 1 ? 0 : index_c];
|
||||
float crop_high = quant.crop_high_data->shifts_[quant.crop_high_data->count_ == 1 ? 0 : index_c];
|
||||
float input_scale = quant.input_scale_data->scales_[quant.input_scale_data->count_ == 1 ? 0 : index_c];
|
||||
float input_shift = quant.input_shift_data->shifts_[quant.input_shift_data->count_ == 1 ? 0 : index_c];
|
||||
|
||||
dst_value = nstl::min(crop_high, nstl::max(crop_low, dst_value));
|
||||
dst_value = dst_value * input_scale + input_shift;
|
||||
|
||||
if (do_rounding) {
|
||||
dst_value = roundf(dst_value);
|
||||
}
|
||||
|
||||
if (do_dequantization) {
|
||||
float output_scale = quant.output_scale_data->scales_[quant.output_scale_data->count_ == 1 ? 0 : index_c];
|
||||
float output_shift = quant.output_shift_data->shifts_[quant.output_shift_data->count_ == 1 ? 0 : index_c];
|
||||
dst_value = dst_value * output_scale + output_shift;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool MKLDNNResampleNode::created() const {
|
||||
return getType() == Resample;
|
||||
}
|
||||
|
||||
REG_MKLDNN_PRIM_FOR(MKLDNNResampleNode, Resample);
|
@ -1,109 +0,0 @@
|
||||
// Copyright (C) 2018-2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ie_common.h>
|
||||
#include <mkldnn_node.h>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
namespace MKLDNNPlugin {
|
||||
|
||||
struct jit_resample_config_params {
|
||||
bool planar_layout;
|
||||
bool nhwc_format;
|
||||
mkldnn::memory::data_type src_dt;
|
||||
mkldnn::memory::data_type dst_dt;
|
||||
int src_data_size;
|
||||
int dst_data_size;
|
||||
};
|
||||
|
||||
struct jit_resample_call_args {
|
||||
const void *src;
|
||||
const int *index;
|
||||
void *dst;
|
||||
size_t src_stride;
|
||||
size_t index_stride;
|
||||
size_t dst_stride;
|
||||
size_t work_amount;
|
||||
size_t oc_off;
|
||||
};
|
||||
|
||||
struct jit_uni_resample_nearest_kernel {
|
||||
void (*ker_)(const jit_resample_call_args *);
|
||||
|
||||
void operator()(const jit_resample_call_args *args) {
|
||||
assert(ker_);
|
||||
ker_(args);
|
||||
}
|
||||
|
||||
explicit jit_uni_resample_nearest_kernel(jit_resample_config_params jcp, const mkldnn_primitive_attr &attr) : ker_(nullptr), jcp_(jcp), attr_(attr) {}
|
||||
virtual ~jit_uni_resample_nearest_kernel() {}
|
||||
|
||||
jit_resample_config_params jcp_;
|
||||
const mkldnn_primitive_attr &attr_;
|
||||
};
|
||||
|
||||
struct jit_uni_resample_linear_kernel {
|
||||
void (*ker_)(const jit_resample_call_args *);
|
||||
|
||||
void operator()(const jit_resample_call_args *args) {
|
||||
assert(ker_);
|
||||
ker_(args);
|
||||
}
|
||||
|
||||
explicit jit_uni_resample_linear_kernel(jit_resample_config_params jcp, const mkldnn_primitive_attr &attr) : ker_(nullptr), jcp_(jcp), attr_(attr) {}
|
||||
virtual ~jit_uni_resample_linear_kernel() {}
|
||||
|
||||
jit_resample_config_params jcp_;
|
||||
const mkldnn_primitive_attr &attr_;
|
||||
};
|
||||
|
||||
|
||||
class MKLDNNResampleNode : public MKLDNNNode {
|
||||
public:
|
||||
MKLDNNResampleNode(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng, MKLDNNWeightsSharing::Ptr &cache);
|
||||
~MKLDNNResampleNode() override = default;
|
||||
|
||||
void getSupportedDescriptors() override;
|
||||
void initSupportedPrimitiveDescriptors() override;
|
||||
void createPrimitive() override;
|
||||
bool created() const override;
|
||||
void execute(mkldnn::stream strm) override;
|
||||
bool canBeInPlace() const override {
|
||||
return false;
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename in_data_t, typename out_data_t>
|
||||
void NearestNeighbor_PLN(const in_data_t *in_ptr_, out_data_t *out_ptr_, int B, int C, int ID, int IH, int IW,
|
||||
float fx, float fy, float fz, int OD, int OH, int OW);
|
||||
template <typename in_data_t, typename out_data_t>
|
||||
void NearestNeighbor_BLK(const in_data_t *in_ptr_, out_data_t *out_ptr_, int B, int C, int ID, int IH, int IW,
|
||||
float fx, float fy, float fz, int OD, int OH, int OW);
|
||||
template <typename in_data_t, typename out_data_t>
|
||||
void LinearInterpolation(const in_data_t *in_ptr_, out_data_t *out_ptr_, int B, int C, int ID, int IH, int IW,
|
||||
float fx, float fy, float fz, int OD, int OH, int OW, int kernel_width, bool antialias);
|
||||
void setPostOps(mkldnn::primitive_attr &attr, bool initWeights = false);
|
||||
inline void apply_post_ops_scalar(float &dst_value, int index_c);
|
||||
|
||||
int blk_size;
|
||||
|
||||
std::string type;
|
||||
bool antialias;
|
||||
float factor;
|
||||
|
||||
mkldnn::primitive_attr attr;
|
||||
std::vector<MKLDNNMemoryPtr> PostOpsIntBlobMemory;
|
||||
|
||||
InferenceEngine::Precision input_prec, output_prec;
|
||||
size_t src_data_size, dst_data_size;
|
||||
|
||||
std::shared_ptr<jit_uni_resample_nearest_kernel> resample_nearest_kernel;
|
||||
};
|
||||
|
||||
} // namespace MKLDNNPlugin
|
||||
|
@ -55,6 +55,7 @@ protected:
|
||||
std::vector<int64_t> axes;
|
||||
std::vector<float> scales;
|
||||
std:tie(mode, shapeCalcMode, coordinateTransformMode, nearestMode, antialias, padBegin, padEnd, cubeCoef, axes, scales) = interpolateParams;
|
||||
inPrc = outPrc = netPrecision;
|
||||
|
||||
using ShapeCalcMode = ngraph::op::v4::Interpolate::ShapeCalcMode;
|
||||
|
||||
@ -81,6 +82,8 @@ protected:
|
||||
interpolate->get_rt_info() = getCPUInfo();
|
||||
const ngraph::ResultVector results{std::make_shared<ngraph::opset3::Result>(interpolate)};
|
||||
function = std::make_shared<ngraph::Function>(results, params, "interpolate");
|
||||
|
||||
selectedType = getPrimitiveType() + "_" + inPrc.name();
|
||||
}
|
||||
};
|
||||
|
||||
@ -99,7 +102,6 @@ std::vector<CPUSpecificParams> filterCPUInfoForDevice() {
|
||||
if (with_cpu_x86_avx512f()) {
|
||||
resCPUParams.push_back(CPUSpecificParams{{nChw16c, x, x}, {nChw16c}, {"jit_avx512"}, "jit_avx512_FP32"});
|
||||
resCPUParams.push_back(CPUSpecificParams{{nhwc, x, x}, {nhwc}, {"jit_avx512"}, "jit_avx512_FP32"});
|
||||
resCPUParams.push_back(CPUSpecificParams{{nchw, x, x}, {nchw}, {"jit_avx2"}, "jit_avx2_FP32"});
|
||||
} else if (with_cpu_x86_avx2()) {
|
||||
resCPUParams.push_back(CPUSpecificParams{{nChw8c, x, x}, {nChw8c}, {"jit_avx2"}, "jit_avx2_FP32"});
|
||||
resCPUParams.push_back(CPUSpecificParams{{nhwc, x, x}, {nhwc}, {"jit_avx2"}, "jit_avx2_FP32"});
|
||||
@ -115,7 +117,8 @@ std::vector<CPUSpecificParams> filterCPUInfoForDevice() {
|
||||
/* ========== */
|
||||
|
||||
const std::vector<InferenceEngine::Precision> netPrecisions = {
|
||||
InferenceEngine::Precision::FP32
|
||||
InferenceEngine::Precision::FP32,
|
||||
InferenceEngine::Precision::BF16
|
||||
};
|
||||
|
||||
const std::vector<ngraph::op::v4::Interpolate::CoordinateTransformMode> coordinateTransformModes = {
|
||||
|
@ -1,254 +0,0 @@
|
||||
// Copyright (C) 2018-2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
|
||||
#include "test_graph.hpp"
|
||||
|
||||
#include "single_layer_common.hpp"
|
||||
#include "tests_common.hpp"
|
||||
#include <ie_core.hpp>
|
||||
|
||||
|
||||
using namespace ::testing;
|
||||
using namespace std;
|
||||
using namespace mkldnn;
|
||||
|
||||
|
||||
struct interp_test_params {
|
||||
struct {
|
||||
size_t n;
|
||||
size_t c;
|
||||
size_t h;
|
||||
size_t w;
|
||||
} in;
|
||||
|
||||
struct {
|
||||
size_t h;
|
||||
size_t w;
|
||||
} out;
|
||||
|
||||
int pad_beg;
|
||||
int pad_end;
|
||||
|
||||
size_t num_prim_desc;
|
||||
|
||||
int selectedType;
|
||||
|
||||
std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
|
||||
};
|
||||
|
||||
void interpolate(const int N, const int C, const float *src, const int x1, const int y1, const int IH_pad, const int IW_pad,
|
||||
const int IH, const int IW, float *dst, const int x2, const int y2, const int OH_pad, const int OW_pad, const int OH, const int OW) {
|
||||
if (IH_pad == OH_pad && IW_pad == OW_pad) {
|
||||
for (int i = 0; i < N * C * OH * OW; i++) {
|
||||
dst[i] = src[i];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
const float rh = (OH_pad > 1) ? static_cast<float>(IH_pad - 1) / (OH_pad - 1) : 0.0f;
|
||||
const float rw = (OW_pad > 1) ? static_cast<float>(IW_pad - 1) / (OW_pad - 1) : 0.0f;
|
||||
|
||||
const int block_size = 1;
|
||||
|
||||
// Align channel number to block size to deal with channels padding in IE with multiple blobs
|
||||
int CB = (C + block_size - 1) & (-block_size); // CB=n*block_size, i.e.:c=15,(block_size=8), then CB=16, CH=2
|
||||
|
||||
int CH = (C + block_size - 1) / block_size; // number of block:(n)
|
||||
|
||||
for (int n = 0; n < N; n++) {
|
||||
for (int cb = 0; cb < CH; ++cb) {
|
||||
for (int h = 0; h < OH_pad; ++h) {
|
||||
const float *psrc = src + n * CB * IH * IW; // should be nChw8c(16c) data format
|
||||
|
||||
float fh = rh * h;
|
||||
int ih0 = static_cast<int>(fh);
|
||||
int ih1 = (ih0 < IH_pad - 1) ? ih0 + 1 : ih0;
|
||||
|
||||
float h_lambda0 = fh - ih0;
|
||||
float h_lambda1 = 1.0f - h_lambda0;
|
||||
|
||||
for (int w = 0; w < OW_pad; ++w) {
|
||||
float fw = rw * w;
|
||||
int iw0 = static_cast<int>(fw);
|
||||
int iw1 = (iw0 < IW_pad - 1) ? iw0 + 1 : iw0;
|
||||
|
||||
float w_lambda0 = fw - iw0;
|
||||
float w_lambda1 = 1.0f - w_lambda0;
|
||||
|
||||
const float *psrc00 =
|
||||
psrc + cb * block_size * IW * IH + (y1 + ih0) * IW * block_size + (x1 + iw0) * block_size;
|
||||
const float *psrc01 =
|
||||
psrc + cb * block_size * IW * IH + (y1 + ih0) * IW * block_size + (x1 + iw1) * block_size;
|
||||
const float *psrc10 =
|
||||
psrc + cb * block_size * IW * IH + (y1 + ih1) * IW * block_size + (x1 + iw0) * block_size;
|
||||
const float *psrc11 =
|
||||
psrc + cb * block_size * IW * IH + (y1 + ih1) * IW * block_size + (x1 + iw1) * block_size;
|
||||
|
||||
float *pdst = dst + n * CB * OH * OW + cb * block_size * OW * OH + (y2 + h) * OW * block_size +
|
||||
(x2 + w) * block_size;
|
||||
|
||||
for (int c = 0; c < block_size; ++c) {
|
||||
pdst[c] = h_lambda1 * (w_lambda1 * psrc00[c] + w_lambda0 * psrc01[c]) +
|
||||
h_lambda0 * (w_lambda1 * psrc10[c] + w_lambda0 * psrc11[c]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename data_t>
|
||||
void ref_interp(const InferenceEngine::TBlob<data_t> &src, InferenceEngine::TBlob<data_t> &dst, interp_test_params prm) {
|
||||
int IB = static_cast<int>(src.getTensorDesc().getDims()[0]);
|
||||
int IC = static_cast<int>(src.getTensorDesc().getDims()[1]);
|
||||
int IH = static_cast<int>(src.getTensorDesc().getDims()[2]);
|
||||
int IW = static_cast<int>(src.getTensorDesc().getDims()[3]);
|
||||
|
||||
int OH = static_cast<int>(dst.getTensorDesc().getDims()[2]);
|
||||
int OW = static_cast<int>(dst.getTensorDesc().getDims()[3]);
|
||||
|
||||
int IH_pad = IH + prm.pad_beg + prm.pad_end;
|
||||
int IW_pad = IW + prm.pad_beg + prm.pad_end;
|
||||
|
||||
const data_t *src_data = src.readOnly();
|
||||
data_t *dst_data = dst.data();
|
||||
|
||||
interpolate(IB, IC, src_data, -prm.pad_beg, -prm.pad_beg, IH_pad, IW_pad, IH, IW, dst_data, 0, 0, OH, OW, OH, OW);
|
||||
}
|
||||
|
||||
class MKLDNNCPUExtInterpTests: public TestsCommon, public WithParamInterface<interp_test_params> {
|
||||
std::string model_t = R"V0G0N(
|
||||
<Net Name="Convolution_Only" version="2" precision="FP32" batch="1">
|
||||
<layers>
|
||||
<layer name="in1" type="Input" precision="FP32" id="0">
|
||||
<output>
|
||||
<port id="0">
|
||||
<dim>_IN_</dim>
|
||||
<dim>_IC_</dim>
|
||||
<dim>_IH_</dim>
|
||||
<dim>_IW_</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
<layer name="interp1" id="1" type="Interp" precision="FP32">
|
||||
<data pad_beg="_PB_" pad_end="_PE_" height="_OH_" width="_OW_"/>
|
||||
|
||||
<input>
|
||||
<port id="1">
|
||||
<dim>_IN_</dim>
|
||||
<dim>_IC_</dim>
|
||||
<dim>_IH_</dim>
|
||||
<dim>_IW_</dim>
|
||||
</port>
|
||||
</input>
|
||||
<output>
|
||||
<port id="2">
|
||||
<dim>_IN_</dim>
|
||||
<dim>_IC_</dim>
|
||||
<dim>_OH_</dim>
|
||||
<dim>_OW_</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
</layers>
|
||||
<edges>
|
||||
<edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
|
||||
</edges>
|
||||
</Net>
|
||||
)V0G0N";
|
||||
|
||||
std::string getModel(interp_test_params p) {
|
||||
std::string model = model_t;
|
||||
REPLACE_WITH_NUM(model, "_IW_", p.in.w);
|
||||
REPLACE_WITH_NUM(model, "_IH_", p.in.h);
|
||||
REPLACE_WITH_NUM(model, "_IC_", p.in.c);
|
||||
REPLACE_WITH_NUM(model, "_IN_", p.in.n);
|
||||
|
||||
REPLACE_WITH_NUM(model, "_OH_", p.out.h);
|
||||
REPLACE_WITH_NUM(model, "_OW_", p.out.w);
|
||||
|
||||
REPLACE_WITH_NUM(model, "_PB_", p.pad_beg);
|
||||
REPLACE_WITH_NUM(model, "_PE_", p.pad_end);
|
||||
return model;
|
||||
}
|
||||
|
||||
protected:
|
||||
virtual void TearDown() {
|
||||
}
|
||||
|
||||
virtual void SetUp() {
|
||||
try {
|
||||
TestsCommon::SetUp();
|
||||
interp_test_params p = ::testing::WithParamInterface<interp_test_params>::GetParam();
|
||||
std::string model = getModel(p);
|
||||
|
||||
InferenceEngine::Core core;
|
||||
InferenceEngine::CNNNetwork network;
|
||||
ASSERT_NO_THROW(network = core.ReadNetwork(model, InferenceEngine::Blob::CPtr()));
|
||||
|
||||
MKLDNNGraphTestClass graph;
|
||||
graph.CreateGraph(network);
|
||||
|
||||
auto& nodes = graph.getNodes();
|
||||
nodes = graph.getNodes();
|
||||
for (auto &node : nodes) {
|
||||
if (node->getName() == "interp1") {
|
||||
ASSERT_LE(p.num_prim_desc, node->getSupportedPrimitiveDescriptors().size());
|
||||
for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
|
||||
p.comp.at(j)(node->getSupportedPrimitiveDescriptors().at(j));
|
||||
}
|
||||
ASSERT_NE(nullptr, node->getSelectedPrimitiveDescriptor());
|
||||
ASSERT_EQ(p.selectedType,
|
||||
node->getSelectedPrimitiveDescriptor()->getImplementationType() & p.selectedType);
|
||||
}
|
||||
}
|
||||
ASSERT_LE(4, nodes.size());
|
||||
|
||||
InferenceEngine::SizeVector dims_src = {p.in.n, p.in.c, p.in.h, p.in.w};
|
||||
|
||||
InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float>({InferenceEngine::Precision::FP32, dims_src, InferenceEngine::NCHW});
|
||||
src->allocate();
|
||||
fill_data(src->buffer(), src->size());
|
||||
|
||||
auto * srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
|
||||
|
||||
if (srcPtr == nullptr)
|
||||
FAIL() << "Cannot cast blob to TBlob<float>.";
|
||||
|
||||
InferenceEngine::BlobMap srcs;
|
||||
srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
|
||||
|
||||
InferenceEngine::OutputsDataMap out;
|
||||
out = network.getOutputsInfo();
|
||||
InferenceEngine::BlobMap outputBlobs;
|
||||
|
||||
std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
|
||||
|
||||
InferenceEngine::TBlob<float>::Ptr output;
|
||||
output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
|
||||
output->allocate();
|
||||
outputBlobs[item.first] = output;
|
||||
|
||||
graph.Infer(srcs, outputBlobs);
|
||||
|
||||
|
||||
InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
|
||||
dst_ref.allocate();
|
||||
ref_interp(*srcPtr, dst_ref, p);
|
||||
compare(*output, dst_ref);
|
||||
} catch (const InferenceEngine::details::InferenceEngineException &e) {
|
||||
FAIL() << e.what();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(MKLDNNCPUExtInterpTests, TestsInterp) {}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
TestsInterp, MKLDNNCPUExtInterpTests,
|
||||
::testing::Values(
|
||||
interp_test_params{{1, 256, 1, 1}, {33, 65}, 0, 0, 1, MKLDNNPlugin::impl_desc_type::unknown },
|
||||
interp_test_params{{6, 128, 320, 320}, {23, 38}, 0, 0, 1, MKLDNNPlugin::impl_desc_type::unknown },
|
||||
interp_test_params{{1, 2, 33, 65}, {33, 65}, 0, 0, 1, MKLDNNPlugin::impl_desc_type::unknown }));
|
@ -1,367 +0,0 @@
|
||||
// Copyright (C) 2018-2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "test_graph.hpp"
|
||||
|
||||
#include "single_layer_common.hpp"
|
||||
#include "tests_common.hpp"
|
||||
#include "ir_gen_helper.hpp"
|
||||
#include <ie_core.hpp>
|
||||
|
||||
#include <nodes/base.hpp>
|
||||
|
||||
using namespace InferenceEngine;
|
||||
using namespace ::testing;
|
||||
using namespace std;
|
||||
using namespace single_layer_tests;
|
||||
|
||||
using namespace Extensions;
|
||||
using namespace ::Cpu;
|
||||
|
||||
struct resample_test_params {
|
||||
std::vector<size_t> in_dims;
|
||||
|
||||
float factor;
|
||||
int antialias;
|
||||
std::string type;
|
||||
|
||||
size_t num_prim_desc;
|
||||
bool isBlockedFormat;
|
||||
int selectedType;
|
||||
|
||||
std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
|
||||
};
|
||||
|
||||
|
||||
static inline float triangleCoeff(float x) {
|
||||
return max(0.0f, 1 - std::abs(x));
|
||||
}
|
||||
|
||||
extern InferenceEngine::IExtensionPtr make_FakeExtensions();
|
||||
|
||||
template <typename data_t>
|
||||
void ref_resample(const InferenceEngine::TBlob<data_t> &src, InferenceEngine::TBlob<data_t> &dst, resample_test_params prm) {
|
||||
const data_t *src_data = src.readOnly();
|
||||
data_t *dst_data = dst.data();
|
||||
|
||||
size_t ndims = prm.in_dims.size();
|
||||
|
||||
size_t N = prm.in_dims[0];
|
||||
size_t C = prm.in_dims[1];
|
||||
size_t ID = ndims == 5 ? prm.in_dims[ndims - 3] : 1;
|
||||
size_t IH = prm.in_dims[ndims - 2];
|
||||
size_t IW = prm.in_dims[ndims - 1];
|
||||
size_t OD = ndims == 5 ? ID / prm.factor : 1;
|
||||
size_t OH = IH / prm.factor;
|
||||
size_t OW = IW / prm.factor;
|
||||
|
||||
float fx = static_cast<float>(IW) / static_cast<float>(OW);
|
||||
float fy = static_cast<float>(IH) / static_cast<float>(OH);
|
||||
float fz = static_cast<float>(ID) / static_cast<float>(OD);
|
||||
|
||||
if (prm.type == "caffe.ResampleParameter.NEAREST") {
|
||||
for (size_t b = 0; b < N; b++) {
|
||||
for (size_t c = 0; c < C; c++) {
|
||||
const float *in_ptr = src_data + IW * IH * ID * C * b + IW * IH * ID * c;
|
||||
float *out_ptr = dst_data + OW * OH * OD * C * b + OW * OH * OD * c;
|
||||
|
||||
for (size_t oz = 0; oz < OD; oz++) {
|
||||
for (size_t oy = 0; oy < OH; oy++) {
|
||||
for (size_t ox = 0; ox < OW; ox++) {
|
||||
float ix = ox * fx;
|
||||
float iy = oy * fy;
|
||||
float iz = oz * fz;
|
||||
|
||||
size_t ix_r = static_cast<size_t>(std::floor(ix));
|
||||
size_t iy_r = static_cast<size_t>(std::floor(iy));
|
||||
size_t iz_r = static_cast<size_t>(std::floor(iz));
|
||||
|
||||
out_ptr[oz * OH * OW + oy * OW + ox] = in_ptr[iz_r * IH * IW + iy_r * IW + ix_r];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (prm.type == "caffe.ResampleParameter.LINEAR") {
|
||||
size_t kernel_width = 2;
|
||||
bool isDownsample = (fx > 1) || (fy > 1) || (fz > 1);
|
||||
bool antialias = isDownsample && prm.antialias;
|
||||
|
||||
for (size_t b = 0; b < N; b++) {
|
||||
for (size_t c = 0; c < C; c++) {
|
||||
const float *in_ptr = src_data + IW * IH * ID * C * b + IW * IH * ID * c;
|
||||
float *out_ptr = dst_data + OW * OH * OD * C * b + OW * OH * OD * c;
|
||||
|
||||
for (size_t oz = 0; oz < OD; oz++) {
|
||||
for (size_t oy = 0; oy < OH; oy++) {
|
||||
for (size_t ox = 0; ox < OW; ox++) {
|
||||
float ix = ox * fx + fx / 2.0f - 0.5f;
|
||||
float iy = oy * fy + fy / 2.0f - 0.5f;
|
||||
float iz = oz * fz + fz / 2.0f - 0.5f;
|
||||
|
||||
int ix_r = static_cast<int>(round(ix));
|
||||
int iy_r = static_cast<int>(round(iy));
|
||||
int iz_r = static_cast<int>(round(iz));
|
||||
|
||||
float sum = 0;
|
||||
float wsum = 0;
|
||||
|
||||
float ax = 1.0f / (antialias ? fx : 1.0f);
|
||||
float ay = 1.0f / (antialias ? fy : 1.0f);
|
||||
float az = 1.0f / (antialias ? fz : 1.0f);
|
||||
|
||||
int rx = (fx < 1.0f) ? 2 : static_cast<int>(ceil(static_cast<float>(kernel_width) / ax));
|
||||
int ry = (fy < 1.0f) ? 2 : static_cast<int>(ceil(static_cast<float>(kernel_width) / ay));
|
||||
int rz = (fz < 1.0f) ? 2 : static_cast<int>(ceil(static_cast<float>(kernel_width) / az));
|
||||
|
||||
for (int z = iz_r - rz; z <= iz_r + rz; z++) {
|
||||
for (int y = iy_r - ry; y <= iy_r + ry; y++) {
|
||||
for (int x = ix_r - rx; x <= ix_r + rx; x++) {
|
||||
if (z < 0 || y < 0 || x < 0 || z >= static_cast<int>(ID) ||y >= static_cast<int>(IH) || x >= static_cast<int>(IW))
|
||||
continue;
|
||||
|
||||
float dx = ix - x;
|
||||
float dy = iy - y;
|
||||
float dz = iz - z;
|
||||
|
||||
float w = ax * triangleCoeff(ax * dx) * ay * triangleCoeff(ay * dy) * az * triangleCoeff(az * dz);
|
||||
|
||||
sum += w * in_ptr[z * IH * IW + y * IW + x];
|
||||
wsum += w;
|
||||
}
|
||||
}
|
||||
}
|
||||
out_ptr[oz * OH * OW + oy * OW + ox] = (!wsum) ? 0 : (sum / wsum);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
assert(!"Unsupported resample operation type");
|
||||
}
|
||||
}
|
||||
|
||||
class MKLDNNCPUExtResampleTests: public TestsCommon, public WithParamInterface<resample_test_params> {
|
||||
std::string model_t = R"V0G0N(
|
||||
<Net Name="Resample_net" version="2" precision="FP32" batch="1">
|
||||
<layers>
|
||||
<layer name="in1" type="Input" precision="FP32" id="0">
|
||||
<output>
|
||||
<port id="0">
|
||||
<dim>_IN_</dim>
|
||||
<dim>_IC_</dim>
|
||||
<dim>_ID_</dim>
|
||||
<dim>_IH_</dim>
|
||||
<dim>_IW_</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
<layer name="fakeLayer" id="1" type="_FL_" precision="FP32">
|
||||
<input>
|
||||
<port id="1">
|
||||
<dim>_IN_</dim>
|
||||
<dim>_IC_</dim>
|
||||
<dim>_ID_</dim>
|
||||
<dim>_IH_</dim>
|
||||
<dim>_IW_</dim>
|
||||
</port>
|
||||
</input>
|
||||
<output>
|
||||
<port id="2">
|
||||
<dim>_IN_</dim>
|
||||
<dim>_IC_</dim>
|
||||
<dim>_ID_</dim>
|
||||
<dim>_IH_</dim>
|
||||
<dim>_IW_</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
<layer name="resample" id="2" type="Resample" precision="FP32">
|
||||
<data antialias="_AN_" factor="_F_" type="_T_"/>
|
||||
<input>
|
||||
<port id="3">
|
||||
<dim>_IN_</dim>
|
||||
<dim>_IC_</dim>
|
||||
<dim>_ID_</dim>
|
||||
<dim>_IH_</dim>
|
||||
<dim>_IW_</dim>
|
||||
</port>
|
||||
</input>
|
||||
<output>
|
||||
<port id="4">
|
||||
<dim>_IN_</dim>
|
||||
<dim>_IC_</dim>
|
||||
<dim>_OD_</dim>
|
||||
<dim>_OH_</dim>
|
||||
<dim>_OW_</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
</layers>
|
||||
<edges>
|
||||
<edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
|
||||
<edge from-layer="1" from-port="2" to-layer="2" to-port="3"/>
|
||||
</edges>
|
||||
</Net>
|
||||
)V0G0N";
|
||||
|
||||
std::string getModel(resample_test_params p) {
|
||||
std::string model = model_t;
|
||||
|
||||
auto dims_size = p.in_dims.size();
|
||||
if (dims_size == 4) {
|
||||
REMOVE_LINE(model, "<dim>_ID_</dim>");
|
||||
REMOVE_LINE(model, "<dim>_OD_</dim>");
|
||||
}
|
||||
|
||||
if (p.isBlockedFormat)
|
||||
REPLACE_WITH_STR(model, "_FL_", "FakeLayerBLK");
|
||||
else
|
||||
REPLACE_WITH_STR(model, "_FL_", "FakeLayerPLN");
|
||||
|
||||
REPLACE_WITH_NUM(model, "_IN_", p.in_dims[0]);
|
||||
REPLACE_WITH_NUM(model, "_IC_", p.in_dims[1]);
|
||||
if (dims_size == 5)
|
||||
REPLACE_WITH_NUM(model, "_ID_", p.in_dims[dims_size - 3]);
|
||||
REPLACE_WITH_NUM(model, "_IH_", p.in_dims[dims_size - 2]);
|
||||
REPLACE_WITH_NUM(model, "_IW_", p.in_dims[dims_size - 1]);
|
||||
|
||||
if (dims_size == 5)
|
||||
REPLACE_WITH_NUM(model, "_OD_", (int)(p.in_dims[dims_size - 3] / p.factor));
|
||||
REPLACE_WITH_NUM(model, "_OH_", (int)(p.in_dims[dims_size - 2] / p.factor));
|
||||
REPLACE_WITH_NUM(model, "_OW_", (int)(p.in_dims[dims_size - 1] / p.factor));
|
||||
|
||||
REPLACE_WITH_NUM(model, "_AN_", p.antialias);
|
||||
REPLACE_WITH_NUM(model, "_F_", p.factor);
|
||||
REPLACE_WITH_STR(model, "_T_", p.type);
|
||||
|
||||
return model;
|
||||
}
|
||||
|
||||
protected:
|
||||
virtual void TearDown() {
|
||||
}
|
||||
|
||||
virtual void SetUp() {
|
||||
try {
|
||||
TestsCommon::SetUp();
|
||||
resample_test_params p = ::testing::WithParamInterface<resample_test_params>::GetParam();
|
||||
std::string model = getModel(p);
|
||||
|
||||
MKLDNNPlugin::MKLDNNExtensionManager::Ptr extMgr(new MKLDNNPlugin::MKLDNNExtensionManager());
|
||||
auto defaultExtensions = std::make_shared<InferenceEngine::Extensions::Cpu::MKLDNNExtensions>();
|
||||
extMgr->AddExtension(defaultExtensions);
|
||||
extMgr->AddExtension(make_FakeExtensions());
|
||||
|
||||
InferenceEngine::Core core;
|
||||
InferenceEngine::CNNNetwork network;
|
||||
ASSERT_NO_THROW(network = core.ReadNetwork(model, InferenceEngine::Blob::CPtr()));
|
||||
|
||||
MKLDNNGraphTestClass graph;
|
||||
graph.CreateGraph(network, extMgr);
|
||||
|
||||
auto& nodes = graph.getNodes();
|
||||
nodes = graph.getNodes();
|
||||
|
||||
for (auto &node : nodes) {
|
||||
if (node->getName() == "resample") {
|
||||
ASSERT_EQ(p.num_prim_desc, node->getSupportedPrimitiveDescriptors().size());
|
||||
for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
|
||||
p.comp.at(j)(node->getSupportedPrimitiveDescriptors().at(j));
|
||||
}
|
||||
ASSERT_NE(nullptr, node->getSelectedPrimitiveDescriptor());
|
||||
ASSERT_EQ(p.selectedType,
|
||||
node->getSelectedPrimitiveDescriptor()->getImplementationType() & p.selectedType);
|
||||
}
|
||||
}
|
||||
|
||||
InferenceEngine::SizeVector dims_src = p.in_dims;
|
||||
|
||||
InferenceEngine::Layout layout = InferenceEngine::ANY;
|
||||
switch (p.in_dims.size()) {
|
||||
case 4: layout = InferenceEngine::NCHW; break;
|
||||
case 5: layout = InferenceEngine::NCDHW; break;
|
||||
}
|
||||
|
||||
InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float>({InferenceEngine::Precision::FP32, dims_src, layout});
|
||||
src->allocate();
|
||||
fill_data(src->buffer(), src->size());
|
||||
|
||||
auto * srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
|
||||
|
||||
if (srcPtr == nullptr)
|
||||
FAIL() << "Cannot cast blob to TBlob<float>.";
|
||||
|
||||
InferenceEngine::BlobMap srcs;
|
||||
srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
|
||||
|
||||
InferenceEngine::OutputsDataMap out;
|
||||
out = network.getOutputsInfo();
|
||||
InferenceEngine::BlobMap outputBlobs;
|
||||
|
||||
std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
|
||||
|
||||
InferenceEngine::TBlob<float>::Ptr output;
|
||||
output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
|
||||
output->allocate();
|
||||
outputBlobs[item.first] = output;
|
||||
|
||||
graph.Infer(srcs, outputBlobs);
|
||||
|
||||
InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
|
||||
dst_ref.allocate();
|
||||
ref_resample(*srcPtr, dst_ref, p);
|
||||
compare(*output, dst_ref);
|
||||
} catch (const InferenceEngine::details::InferenceEngineException &e) {
|
||||
FAIL() << e.what();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(MKLDNNCPUExtResampleTests, TestsResample) {}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
TestsResample, MKLDNNCPUExtResampleTests,
|
||||
::testing::Values(
|
||||
resample_test_params{{2, 64, 15, 25}, 1.f, 0, "caffe.ResampleParameter.NEAREST", 3, false, MKLDNNPlugin::impl_desc_type::unknown },
|
||||
resample_test_params{{2, 64, 15, 25}, 1.f, 0, "caffe.ResampleParameter.NEAREST", 3, true, MKLDNNPlugin::impl_desc_type::unknown },
|
||||
resample_test_params{{2, 64, 15, 25}, 1.f, 1, "caffe.ResampleParameter.LINEAR", 1, false, MKLDNNPlugin::impl_desc_type::unknown },
|
||||
resample_test_params{{2, 64, 10, 20}, 0.25f, 0, "caffe.ResampleParameter.NEAREST", 3, false, MKLDNNPlugin::impl_desc_type::unknown },
|
||||
resample_test_params{{2, 64, 10, 20}, 0.25f, 0, "caffe.ResampleParameter.NEAREST", 3, true, MKLDNNPlugin::impl_desc_type::unknown },
|
||||
resample_test_params{{2, 64, 10, 20}, 0.25f, 1, "caffe.ResampleParameter.LINEAR", 1, false, MKLDNNPlugin::impl_desc_type::unknown },
|
||||
resample_test_params{{2, 64, 10, 20}, 4.f, 0, "caffe.ResampleParameter.NEAREST", 3, false, MKLDNNPlugin::impl_desc_type::unknown },
|
||||
resample_test_params{{2, 64, 10, 20}, 4.f, 0, "caffe.ResampleParameter.NEAREST", 3, true, MKLDNNPlugin::impl_desc_type::unknown },
|
||||
resample_test_params{{2, 64, 10, 20}, 4.f, 1, "caffe.ResampleParameter.LINEAR", 1, false, MKLDNNPlugin::impl_desc_type::unknown },
|
||||
resample_test_params{{2, 3, 15, 25}, 1.f, 0, "caffe.ResampleParameter.NEAREST", 3, false, MKLDNNPlugin::impl_desc_type::unknown },
|
||||
resample_test_params{{2, 3, 15, 25}, 1.f, 0, "caffe.ResampleParameter.NEAREST", 3, true, MKLDNNPlugin::impl_desc_type::unknown },
|
||||
resample_test_params{{2, 3, 15, 25}, 1.f, 1, "caffe.ResampleParameter.LINEAR", 1, false, MKLDNNPlugin::impl_desc_type::unknown },
|
||||
resample_test_params{{2, 3, 10, 20}, 0.25f, 0, "caffe.ResampleParameter.NEAREST", 3, false, MKLDNNPlugin::impl_desc_type::unknown },
|
||||
resample_test_params{{2, 3, 10, 20}, 0.25f, 0, "caffe.ResampleParameter.NEAREST", 3, true, MKLDNNPlugin::impl_desc_type::unknown },
|
||||
resample_test_params{{2, 3, 10, 20}, 0.25f, 1, "caffe.ResampleParameter.LINEAR", 1, false, MKLDNNPlugin::impl_desc_type::unknown },
|
||||
resample_test_params{{2, 3, 10, 20}, 4.f, 0, "caffe.ResampleParameter.NEAREST", 3, false, MKLDNNPlugin::impl_desc_type::unknown },
|
||||
resample_test_params{{2, 3, 10, 20}, 4.f, 0, "caffe.ResampleParameter.NEAREST", 3, true, MKLDNNPlugin::impl_desc_type::unknown },
|
||||
resample_test_params{{2, 3, 10, 20}, 4.f, 1, "caffe.ResampleParameter.LINEAR", 1, false, MKLDNNPlugin::impl_desc_type::unknown },
|
||||
// 5D nearest
|
||||
resample_test_params{{2, 64, 20, 15, 25}, 1.f, 0, "caffe.ResampleParameter.NEAREST", 3, false, MKLDNNPlugin::impl_desc_type::unknown },
|
||||
resample_test_params{{2, 64, 20, 15, 25}, 1.f, 0, "caffe.ResampleParameter.NEAREST", 3, true, MKLDNNPlugin::impl_desc_type::unknown },
|
||||
resample_test_params{{2, 64, 15, 10, 20}, 0.25f, 0, "caffe.ResampleParameter.NEAREST", 3, false, MKLDNNPlugin::impl_desc_type::unknown },
|
||||
resample_test_params{{2, 64, 15, 10, 20}, 0.25f, 0, "caffe.ResampleParameter.NEAREST", 3, true, MKLDNNPlugin::impl_desc_type::unknown },
|
||||
resample_test_params{{2, 64, 15, 10, 20}, 4.f, 0, "caffe.ResampleParameter.NEAREST", 3, false, MKLDNNPlugin::impl_desc_type::unknown },
|
||||
resample_test_params{{2, 64, 15, 10, 20}, 4.f, 0, "caffe.ResampleParameter.NEAREST", 3, true, MKLDNNPlugin::impl_desc_type::unknown },
|
||||
resample_test_params{{2, 3, 20, 15, 25}, 1.f, 0, "caffe.ResampleParameter.NEAREST", 3, false, MKLDNNPlugin::impl_desc_type::unknown },
|
||||
resample_test_params{{2, 3, 20, 15, 25}, 1.f, 0, "caffe.ResampleParameter.NEAREST", 3, true, MKLDNNPlugin::impl_desc_type::unknown },
|
||||
resample_test_params{{2, 3, 15, 10, 20}, 0.25f, 0, "caffe.ResampleParameter.NEAREST", 3, false, MKLDNNPlugin::impl_desc_type::unknown },
|
||||
resample_test_params{{2, 3, 15, 10, 20}, 0.25f, 0, "caffe.ResampleParameter.NEAREST", 3, true, MKLDNNPlugin::impl_desc_type::unknown },
|
||||
resample_test_params{{2, 3, 15, 10, 20}, 4.f, 0, "caffe.ResampleParameter.NEAREST", 3, false, MKLDNNPlugin::impl_desc_type::unknown },
|
||||
resample_test_params{{2, 3, 15, 10, 20}, 4.f, 0, "caffe.ResampleParameter.NEAREST", 3, true, MKLDNNPlugin::impl_desc_type::unknown },
|
||||
// 5D linear
|
||||
resample_test_params{{2, 15, 15, 10, 20}, 9.f, 1, "caffe.ResampleParameter.LINEAR", 1, false, MKLDNNPlugin::impl_desc_type::unknown },
|
||||
resample_test_params{{2, 15, 15, 10, 20}, 1.f, 1, "caffe.ResampleParameter.LINEAR", 1, false, MKLDNNPlugin::impl_desc_type::unknown },
|
||||
resample_test_params{{2, 15, 15, 10, 20}, 4.f, 1, "caffe.ResampleParameter.LINEAR", 1, false, MKLDNNPlugin::impl_desc_type::unknown },
|
||||
resample_test_params{{2, 2, 15, 10, 20}, 0.25f, 1, "caffe.ResampleParameter.LINEAR", 1, false, MKLDNNPlugin::impl_desc_type::unknown },
|
||||
resample_test_params{{2, 15, 15, 10, 20}, 9.f, 0, "caffe.ResampleParameter.LINEAR", 1, false, MKLDNNPlugin::impl_desc_type::unknown },
|
||||
resample_test_params{{2, 15, 15, 10, 20}, 1.f, 0, "caffe.ResampleParameter.LINEAR", 1, false, MKLDNNPlugin::impl_desc_type::unknown },
|
||||
resample_test_params{{2, 8, 15, 10, 20}, 4.f, 0, "caffe.ResampleParameter.LINEAR", 1, false, MKLDNNPlugin::impl_desc_type::unknown },
|
||||
resample_test_params{{2, 2, 15, 10, 20}, 0.25f, 0, "caffe.ResampleParameter.LINEAR", 1, false, MKLDNNPlugin::impl_desc_type::unknown }));
|
@ -222,8 +222,8 @@ void op::v4::Interpolate::validate_and_infer_types()
|
||||
element::Type input_et = get_input_element_type(0);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
input_et == element::Type_t::f32 || input_et == element::Type_t::f16 ||
|
||||
input_et == element::Type_t::i8,
|
||||
"Input element type must be f32, f16, or i8");
|
||||
input_et == element::Type_t::i8 || input_et == element::Type_t::bf16,
|
||||
"Input element type must be f32, f16, bf16 or i8");
|
||||
|
||||
PartialShape input_shape = PartialShape(get_input_partial_shape(0));
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user