diff --git a/inference-engine/src/mkldnn_plugin/CMakeLists.txt b/inference-engine/src/mkldnn_plugin/CMakeLists.txt index 2b0743b1280..d33e3385f68 100644 --- a/inference-engine/src/mkldnn_plugin/CMakeLists.txt +++ b/inference-engine/src/mkldnn_plugin/CMakeLists.txt @@ -36,7 +36,6 @@ set(LAYERS ${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_tensoriterator_node.cpp ${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_tile_node.cpp ${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_mvn_node.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_resample_node.cpp ${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_normalize_node.cpp ${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_scatter_update_node.cpp ${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_interpolate_node.cpp @@ -93,7 +92,6 @@ set(LAYERS ${CMAKE_CURRENT_SOURCE_DIR}/nodes/unsqueeze.cpp ${CMAKE_CURRENT_SOURCE_DIR}/nodes/common/softmax.cpp ${CMAKE_CURRENT_SOURCE_DIR}/nodes/common/emitter.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/nodes/interp.cpp ${CMAKE_CURRENT_SOURCE_DIR}/nodes/jit_eltwise_emitters.cpp ${CMAKE_CURRENT_SOURCE_DIR}/nodes/jit_mkldnn_emitters.cpp diff --git a/inference-engine/src/mkldnn_plugin/bf16transformer.h b/inference-engine/src/mkldnn_plugin/bf16transformer.h index 3f302348e47..02cf8316610 100644 --- a/inference-engine/src/mkldnn_plugin/bf16transformer.h +++ b/inference-engine/src/mkldnn_plugin/bf16transformer.h @@ -14,7 +14,7 @@ namespace MKLDNNPlugin { class BF16Transformer { const InferenceEngine::details::caseless_set _initbf16 = - { "convolution", "fullyconnected", "innerproduct", "gemm", "RegionYolo" }; + { "convolution", "fullyconnected", "innerproduct", "gemm", "RegionYolo", "Interpolate" }; const InferenceEngine::details::caseless_set _complementbf16 = { "relu", "tanh", "elu", "square", "abs", "sqrt", "linear", "bounded_relu", "soft_relu", "normalize", "sigmoid", "ReLU6", "not", "activation", "HSwish", "mish", "logistic", "mod", "resample", diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_graph_optimizer.cpp b/inference-engine/src/mkldnn_plugin/mkldnn_graph_optimizer.cpp index d5c4e4db1db..d6017ff5c1f 100644 --- a/inference-engine/src/mkldnn_plugin/mkldnn_graph_optimizer.cpp +++ b/inference-engine/src/mkldnn_plugin/mkldnn_graph_optimizer.cpp @@ -15,7 +15,6 @@ #include "nodes/mkldnn_quantize_node.h" #include "nodes/mkldnn_mvn_node.h" #include -#include "nodes/mkldnn_resample_node.h" #include "nodes/mkldnn_interpolate_node.h" #include "nodes/mkldnn_input_node.h" @@ -123,9 +122,6 @@ void MKLDNNGraphOptimizer::ApplyCommonGraphOptimizations(MKLDNNGraph &graph) { FuseMVNAndSimpleOperation(graph); graph.RemoveDroppedNodes(); - FuseResampleAndSimpleOperation(graph); - graph.RemoveDroppedNodes(); - FuseInterpolateAndSimpleOperation(graph); graph.RemoveDroppedNodes(); @@ -1491,74 +1487,6 @@ void MKLDNNGraphOptimizer::FuseMVNAndSimpleOperation(MKLDNNGraph &graph) { } } -void MKLDNNGraphOptimizer::FuseResampleAndSimpleOperation(MKLDNNGraph &graph) { - auto& graphNodes = graph.GetNodes(); - - auto isSutableParentNode = [](MKLDNNNodePtr node) { - bool isSutableResample = (node->getType() == Resample) && (node->inDims[0].ndims() == 4 || node->inDims[0].ndims() == 5); - - if (isSutableResample) { - auto *resampleLayer = node->getCnnLayer().get(); - if (resampleLayer == nullptr) - THROW_IE_EXCEPTION << "Cannot get Resample layer " << node->getName(); - - return node->getChildEdges().size() == 1 && resampleLayer->GetParamAsString("type") == "caffe.ResampleParameter.NEAREST"; - } else { - return false; - } - }; - - auto isSutableChildNode = [](MKLDNNNodePtr node) { - if (!node->getCnnLayer()) - return false; - - if (node->getType() == Quantize) { - auto* quantizeNode = dynamic_cast(node.get()); - if (quantizeNode == nullptr) - THROW_IE_EXCEPTION << "Cannot get quantize layer " << node->getName(); - return !quantizeNode->isBinarization(); - } else if (node->getType() == Eltwise) { - auto *eltwiseNode = dynamic_cast(node.get()); - if (eltwiseNode == nullptr) - THROW_IE_EXCEPTION << "Cannot get Eltwise node " << node->getName(); - return eltwiseNode->getOpType() == Relu || - eltwiseNode->getOpType() == MulAdd; - } - - return false; - }; - - auto parent = graphNodes.begin(); - while (parent != graphNodes.end()) { - auto parentNode = *parent; - if (!isSutableParentNode(parentNode)) { - parent++; - continue; - } - - auto childNode = parentNode->getChildEdgeAt(0)->getChild(); - if (!isSutableChildNode(childNode)) { - parent++; - continue; - } - - parentNode->fuseWith(childNode); - - if (childNode->getType() == Quantize || childNode->getType() == Eltwise) { - auto parentEdges = childNode->parentEdges; - for (auto &parentEdge : parentEdges) { - auto p_edge = parentEdge.lock(); - if (p_edge->getParent()->getType() == Resample) - continue; - - removeEdge(graph, p_edge); - } - } - - graph.DropNode(childNode); - } -} - void MKLDNNGraphOptimizer::FuseInterpolateAndSimpleOperation(MKLDNNGraph &graph) { auto& graphNodes = graph.GetNodes(); diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_graph_optimizer.h b/inference-engine/src/mkldnn_plugin/mkldnn_graph_optimizer.h index 025b79c9b7e..e74786c3391 100644 --- a/inference-engine/src/mkldnn_plugin/mkldnn_graph_optimizer.h +++ b/inference-engine/src/mkldnn_plugin/mkldnn_graph_optimizer.h @@ -37,7 +37,6 @@ private: void FuseConvolutionSumAndConvolutionSumActivation(MKLDNNGraph &graph); #endif void FuseMVNAndSimpleOperation(MKLDNNGraph &graph); - void FuseResampleAndSimpleOperation(MKLDNNGraph &graph); void FuseInterpolateAndSimpleOperation(MKLDNNGraph &graph); void FuseNormalizeAndSimpleOperation(MKLDNNGraph &graph); void RemoveIdentityOperator(MKLDNNGraph& graph); diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_node.cpp b/inference-engine/src/mkldnn_plugin/mkldnn_node.cpp index 55abb10a6e3..a316784c17b 100644 --- a/inference-engine/src/mkldnn_plugin/mkldnn_node.cpp +++ b/inference-engine/src/mkldnn_plugin/mkldnn_node.cpp @@ -39,7 +39,6 @@ #include #include #include -#include #include #include #include @@ -123,7 +122,6 @@ static const InferenceEngine::details::caseless_unordered_map { "Memory", MemoryOutput }, // for construction from layer ctor { "Convert", Convert }, { "MVN", MVN}, - { "Resample", Resample}, { "Normalize", Normalize}, { "ScatterUpdate", ScatterUpdate}, { "ScatterElementsUpdate", ScatterElementsUpdate}, diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_node.h b/inference-engine/src/mkldnn_plugin/mkldnn_node.h index b804f547177..1319a713a7d 100644 --- a/inference-engine/src/mkldnn_plugin/mkldnn_node.h +++ b/inference-engine/src/mkldnn_plugin/mkldnn_node.h @@ -66,7 +66,6 @@ enum Type { TensorIterator, Convert, MVN, - Resample, Normalize, ScatterUpdate, ScatterElementsUpdate, @@ -162,8 +161,6 @@ static std::string NameFromType(Type type) { return "TensorIterator"; case Convert: return "Convert"; - case Resample: - return "Resample"; case Normalize: return "Normalize"; case ScatterUpdate: diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp b/inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp index 1f4553a3ca3..582e0f27c41 100644 --- a/inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp +++ b/inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp @@ -26,6 +26,7 @@ #include #include #include +#include #include #include @@ -52,6 +53,7 @@ #include #include #include +#include #include #include #include @@ -200,8 +202,10 @@ static void Transformation(ICNNNetwork::Ptr& clonedNetwork, const Config& conf) pass_config->disable(); pass_config->disable(); pass_config->disable(); + pass_config->disable(); pass_config->enable(); + pass_config->enable(); manager.run_passes(nGraphFunc); diff --git a/inference-engine/src/mkldnn_plugin/nodes/interp.cpp b/inference-engine/src/mkldnn_plugin/nodes/interp.cpp deleted file mode 100644 index 6e2186899c3..00000000000 --- a/inference-engine/src/mkldnn_plugin/nodes/interp.cpp +++ /dev/null @@ -1,432 +0,0 @@ -// Copyright (C) 2018-2020 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include "base.hpp" -#include -#include -#include -#include -#include "ie_parallel.hpp" -#include "jit_generator.hpp" - -using namespace mkldnn::impl::cpu; -using namespace mkldnn::impl::utils; - -namespace InferenceEngine { -namespace Extensions { -namespace Cpu { - -#define GET_OFF(field) offsetof(jit_args_interp, field) - -struct jit_args_interp { - const float *src00; - const float *src01; - const float *src10; - const float *src11; - float *dst; - float *h_lambda0; - float *h_lambda1; - float *w_lambda0; - float *w_lambda1; -}; - -struct jit_uni_interp_kernel { - void (*ker_)(const jit_args_interp *); - - void operator()(const jit_args_interp *args) { assert(ker_); ker_(args); } - - jit_uni_interp_kernel() : ker_(nullptr) {} - virtual ~jit_uni_interp_kernel() {} -}; - -template -struct jit_uni_interp_kernel_f32 : public jit_uni_interp_kernel, public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_interp_kernel_f32) - - jit_uni_interp_kernel_f32() : jit_uni_interp_kernel(), jit_generator() { - this->preamble(); - - mov(reg_src00, ptr[reg_params + GET_OFF(src00)]); - mov(reg_src01, ptr[reg_params + GET_OFF(src01)]); - mov(reg_src10, ptr[reg_params + GET_OFF(src10)]); - mov(reg_src11, ptr[reg_params + GET_OFF(src11)]); - mov(reg_dst, ptr[reg_params + GET_OFF(dst)]); - mov(reg_h_lambda0, ptr[reg_params + GET_OFF(h_lambda0)]); - mov(reg_h_lambda1, ptr[reg_params + GET_OFF(h_lambda1)]); - mov(reg_w_lambda0, ptr[reg_params + GET_OFF(w_lambda0)]); - mov(reg_w_lambda1, ptr[reg_params + GET_OFF(w_lambda1)]); - - uni_vmovups(vmm_src00, ptr[reg_src00]); - uni_vmovups(vmm_src01, ptr[reg_src01]); - uni_vmovups(vmm_src10, ptr[reg_src10]); - uni_vmovups(vmm_src11, ptr[reg_src11]); - - uni_vbroadcastss(vmm_h_lambda0, ptr[reg_h_lambda0]); - uni_vbroadcastss(vmm_h_lambda1, ptr[reg_h_lambda1]); - uni_vbroadcastss(vmm_w_lambda0, ptr[reg_w_lambda0]); - uni_vbroadcastss(vmm_w_lambda1, ptr[reg_w_lambda1]); - - if (isa != sse42) { - uni_vmulps(vmm_src01, vmm_src01, vmm_w_lambda0); - uni_vmulps(vmm_src11, vmm_src11, vmm_w_lambda0); - uni_vfmadd231ps(vmm_src01, vmm_w_lambda1, vmm_src00); - uni_vfmadd231ps(vmm_src11, vmm_w_lambda1, vmm_src10); - uni_vmulps(vmm_src01, vmm_src01, vmm_h_lambda1); - uni_vfmadd231ps(vmm_src01, vmm_h_lambda0, vmm_src11); - uni_vmovups(ptr[reg_dst], vmm_src01); - } else { - uni_vmulps(vmm_src01, vmm_src01, vmm_w_lambda0); - uni_vmulps(vmm_src11, vmm_src11, vmm_w_lambda0); - uni_vfmadd231ps(vmm_src01, vmm_w_lambda1, vmm_src00); - // uni_vfmadd231ps affects XMM (vmm_w_lambda1) register. Need to initialize again. - uni_vbroadcastss(vmm_w_lambda1, ptr[reg_w_lambda1]); - uni_vfmadd231ps(vmm_src11, vmm_w_lambda1, vmm_src10); - uni_vmulps(vmm_src01, vmm_src01, vmm_h_lambda1); - uni_vfmadd231ps(vmm_src01, vmm_h_lambda0, vmm_src11); - uni_vmovups(ptr[reg_dst], vmm_src01); - - // Next 4 elements - size_t stride = 4 * sizeof(float); - - add(reg_src00, stride); - add(reg_src01, stride); - add(reg_src10, stride); - add(reg_src11, stride); - add(reg_dst, stride); - - uni_vmovups(vmm_src00, ptr[reg_src00]); - uni_vmovups(vmm_src01, ptr[reg_src01]); - uni_vmovups(vmm_src10, ptr[reg_src10]); - uni_vmovups(vmm_src11, ptr[reg_src11]); - - uni_vbroadcastss(vmm_h_lambda0, ptr[reg_h_lambda0]); - uni_vbroadcastss(vmm_w_lambda1, ptr[reg_w_lambda1]); - - uni_vmulps(vmm_src01, vmm_src01, vmm_w_lambda0); - uni_vmulps(vmm_src11, vmm_src11, vmm_w_lambda0); - uni_vfmadd231ps(vmm_src01, vmm_w_lambda1, vmm_src00); - uni_vbroadcastss(vmm_w_lambda1, ptr[reg_w_lambda1]); - uni_vfmadd231ps(vmm_src11, vmm_w_lambda1, vmm_src10); - uni_vmulps(vmm_src01, vmm_src01, vmm_h_lambda1); - uni_vfmadd231ps(vmm_src01, vmm_h_lambda0, vmm_src11); - uni_vmovups(ptr[reg_dst], vmm_src01); - } - - this->postamble(); - ker_ = (decltype(ker_))this->getCode(); - } - -private: - using Vmm = typename conditional3::type; - size_t vlen = cpu_isa_traits::vlen; - - Xbyak::Reg64 reg_src00 = r8; - Xbyak::Reg64 reg_src01 = r9; - Xbyak::Reg64 reg_src10 = r10; - Xbyak::Reg64 reg_src11 = r11; - Xbyak::Reg64 reg_dst = rbp; - Xbyak::Reg64 reg_h_lambda0 = r12; - Xbyak::Reg64 reg_h_lambda1 = r13; - Xbyak::Reg64 reg_w_lambda0 = r14; - Xbyak::Reg64 reg_w_lambda1 = r15; - Xbyak::Reg64 reg_params = abi_param1; - - Vmm vmm_src00 = Vmm(0); - Vmm vmm_src01 = Vmm(1); - Vmm vmm_src10 = Vmm(2); - Vmm vmm_src11 = Vmm(3); - Vmm vmm_h_lambda0 = Vmm(4); - Vmm vmm_h_lambda1 = Vmm(5); - Vmm vmm_w_lambda0 = Vmm(6); - Vmm vmm_w_lambda1 = Vmm(7); - Vmm vmm_dst = Vmm(8); -}; - -class InterpImpl: public ExtLayerBase { -public: - explicit InterpImpl(const CNNLayer* layer) { - try { - if (layer->insData.size() != 1 || layer->outData.empty()) - THROW_IE_EXCEPTION << "Incorrect number of input/output edges!"; - - auto inData = layer->insData[0].lock(); - if (inData == nullptr) { - THROW_IE_EXCEPTION << "Layer '" << layer->name << "' has nullable input data."; - } - if (inData->getTensorDesc().getDims().size() != 4) - THROW_IE_EXCEPTION << "Interp supports only 4d blobs!"; - - // We don't read other parameters since they are needed only for dst reshape in caffe - pad_beg = layer->GetParamAsInt("pad_beg"); - pad_end = layer->GetParamAsInt("pad_end"); - align_corners = layer->GetParamAsBool("align_corners", true); - - ConfLayout blk_layout; - if (inData->getTensorDesc().getPrecision() == Precision::U8) { - LayerConfig config; - DataConfig dataConfigDct; - dataConfigDct.desc = TensorDesc(Precision::U8, inData->getTensorDesc().getDims(), Layout::NCHW); - config.inConfs.push_back(dataConfigDct); - - DataConfig dataConfigOut; - const SizeVector& out_dims = layer->outData[0]->getTensorDesc().getDims(); - SizeVector blocks = out_dims; - SizeVector order(blocks.size()); - SizeVector dimOffsets(blocks.size()); - SizeVector strides(blocks.size()); - size_t offset((std::numeric_limits::max)()); - for (size_t i = 0; i < order.size(); i++) { - strides[i] = (std::numeric_limits::max)(); - dimOffsets[i] = 0; - order[i] = i; - } - dataConfigOut.desc = TensorDesc(Precision::FP32, out_dims, { blocks, order, offset, dimOffsets, strides }); - config.outConfs.push_back(dataConfigOut); - config.dynBatchSupport = false; - confs.push_back(config); - } else { - if (mayiuse(avx512_common)) { - blk_layout = ConfLayout::BLK16; - interp_kernel.reset(new jit_uni_interp_kernel_f32()); - addConfig(layer, { DataConfigurator(blk_layout, Precision::FP32) }, { DataConfigurator(blk_layout, Precision::FP32) }); - } else if (mayiuse(avx2)) { - blk_layout = ConfLayout::BLK8; - interp_kernel.reset(new jit_uni_interp_kernel_f32()); - addConfig(layer, { DataConfigurator(blk_layout, Precision::FP32) }, { DataConfigurator(blk_layout, Precision::FP32) }); - } else { - blk_layout = ConfLayout::BLK8; - interp_kernel.reset(new jit_uni_interp_kernel_f32()); - addConfig(layer, { DataConfigurator(blk_layout, Precision::FP32) }, { DataConfigurator(blk_layout, Precision::FP32) }); - } - } - } catch (InferenceEngine::details::InferenceEngineException &ex) { - errorMsg = ex.what(); - } - } - - StatusCode init(LayerConfig& config, ResponseDesc *resp) noexcept override { - if (config.inConfs.size() != 1 || config.outConfs.size() != 1) { - strncpy(resp->msg, "Interp layer has invalid configs", sizeof(resp->msg)); - return GENERAL_ERROR; - } - - if (config.inConfs[0].desc.getDims().size() != 4) { - std::ostringstream result; - result << "Interp layer has invalid layout: " << config.inConfs[0].desc.getLayout(); - strncpy(resp->msg, result.str().c_str(), sizeof(resp->msg) - 1); - return GENERAL_ERROR; - } - - auto inPrecision = config.inConfs[0].desc.getPrecision(); - if (inPrecision != Precision::U8 && inPrecision != Precision::FP32) { - strncpy(resp->msg, "Interp layer has unsupported input precision", sizeof(resp->msg)); - return GENERAL_ERROR; - } - - if (config.outConfs[0].desc.getPrecision() != Precision::FP32) { - strncpy(resp->msg, "Interp layer has unsupported output precision", sizeof(resp->msg)); - return GENERAL_ERROR; - } - - return OK; - } - - StatusCode execute(std::vector& inputs, std::vector& outputs, - ResponseDesc *resp) noexcept override { -#ifdef WIN32 -#undef IN -#endif - size_t IN = inputs[0]->getTensorDesc().getDims()[0]; - size_t IH = inputs[0]->getTensorDesc().getDims()[2]; - size_t IW = inputs[0]->getTensorDesc().getDims()[3]; - size_t OH = outputs[0]->getTensorDesc().getDims()[2]; - size_t OW = outputs[0]->getTensorDesc().getDims()[3]; - - size_t IH_pad = IH + pad_beg + pad_end; - size_t IW_pad = IW + pad_beg + pad_end; - - auto *dst_data = outputs[0]->buffer().as() + outputs[0]->getTensorDesc().getBlockingDesc().getOffsetPadding(); - - switch (inputs[0]->getTensorDesc().getPrecision()) { - case Precision::FP32: - { - const float* src_data = inputs[0]->cbuffer().as() + inputs[0]->getTensorDesc().getBlockingDesc().getOffsetPadding(); - size_t IC = (inputs[0]->getTensorDesc().getLayout() == Layout::BLOCKED) - ? inputs[0]->getTensorDesc().getBlockingDesc().getBlockDims()[1] * - inputs[0]->getTensorDesc().getBlockingDesc().getBlockDims()[4] - : IC = inputs[0]->getTensorDesc().getDims()[1]; - interpolate(IN, IC, src_data, - -pad_beg, -pad_beg, IH_pad, IW_pad, IH, IW, dst_data, 0, 0, OH, OW, OH, OW); - } - break; - case Precision::U8: - { - const uint8_t* src_data = inputs[0]->cbuffer().as() + inputs[0]->getTensorDesc().getBlockingDesc().getOffsetPadding(); - size_t IC = inputs[0]->getTensorDesc().getDims()[1]; - interpolate_8u(inputs[0]->getTensorDesc().getLayout(), IN, IC, src_data, - -pad_beg, -pad_beg, IH_pad, IW_pad, IH, IW, dst_data, 0, 0, OH, OW, OH, OW); - } - break; - default: - if (resp) { - std::string errorMsg = "Incorrect input precision. Only U8 or FP32 are supported!"; - errorMsg.copy(resp->msg, sizeof(resp->msg) - 1); - } - return GENERAL_ERROR; - } - - return OK; - } - -private: - int pad_beg; - int pad_end; - bool align_corners; - std::shared_ptr interp_kernel; - - void interpolate(const size_t N, const size_t C, - const float *src, const int x1, const int y1, - const int IH_pad, const int IW_pad, const size_t IH, const size_t IW, - float *dst, const int x2, const int y2, - const int OH_pad, const int OW_pad, const size_t OH, const size_t OW) { - if (IH_pad == OH_pad && IW_pad == OW_pad) { - for (size_t i = 0; i < N * C * OH * OW; i++) { - dst[i] = src[i]; - } - return; - } - - float rh; - float rw; - if (align_corners) { - rh = (OH_pad > 1) ? static_cast(IH_pad - 1) / (OH_pad - 1) : 0.0f; - rw = (OW_pad > 1) ? static_cast(IW_pad - 1) / (OW_pad - 1) : 0.0f; - } else { - rh = static_cast(IH_pad) / (OH_pad); - rw = static_cast(IW_pad) / (OW_pad); - } - - int block_size = 1; - if (interp_kernel) { - if (mayiuse(avx512_common)) { - block_size = 16; - } else { - block_size = 8; - } - } - - // Align channel number to block size to deal with channels padding in IE with multiple blobs - size_t CB = (C + block_size - 1) & (-block_size); - - size_t CH = (C + block_size - 1) / block_size; - - parallel_for3d(N, CH, OH_pad, [&](size_t n, size_t cb, size_t h) { - const float *psrc_n_cb = src + n * CB * IH * IW + cb * block_size * IW * IH; // n+cb src address - - // h is output h - float fh = rh * h; - // ih0 is higher input h position - int ih0 = static_cast(fh); - // ih1 is lower input h position - int ih1 = (ih0 < IH_pad - 1) ? ih0 + 1 : ih0; - - float h_lambda0 = fh - ih0; // for lower input h weight - float h_lambda1 = 1.0f - h_lambda0; // for higher input h weight - - const float *psrc_h0 = psrc_n_cb + (y1 + ih0) * IW * block_size + x1 * block_size; - const float *psrc_h1 = psrc_n_cb + (y1 + ih1) * IW * block_size + x1 * block_size; - float *pdst_h = dst + n * CB * OH * OW + cb * block_size * OW * OH + (y2 + h) * OW * block_size + x2 * block_size; - - auto arg = jit_args_interp(); - arg.h_lambda0 = static_cast(&h_lambda0); - arg.h_lambda1 = static_cast(&h_lambda1); - for (int w = 0; w < OW_pad; ++w) { - float fw = rw * w; - int iw0 = static_cast(fw); - int iw1 = (iw0 < IW_pad - 1) ? iw0 + 1 : iw0; - - float w_lambda0 = fw - iw0; // for right input w weight - float w_lambda1 = 1.0f - w_lambda0; // for left input w weight - - const float *psrc00 = psrc_h0 + iw0 * block_size; - const float *psrc01 = psrc_h0 + iw1 * block_size; - const float *psrc10 = psrc_h1 + iw0 * block_size; - const float *psrc11 = psrc_h1 + iw1 * block_size; - - float *pdst = pdst_h + w * block_size; - - if (interp_kernel) { - arg.src00 = psrc00; - arg.src01 = psrc01; - arg.src10 = psrc10; - arg.src11 = psrc11; - arg.dst = pdst; - arg.w_lambda0 = static_cast(&w_lambda0); - arg.w_lambda1 = static_cast(&w_lambda1); - (*interp_kernel)(&arg); - } else { - for (int c = 0; c < block_size; ++c) { - pdst[c] = h_lambda1 * (w_lambda1 * psrc00[c] + w_lambda0 * psrc01[c]) + - h_lambda0 * (w_lambda1 * psrc10[c] + w_lambda0 * psrc11[c]); - } - } - } - }); - } - - void interpolate_8u(Layout layout, const size_t N, const size_t C, - const uint8_t *src, const int x1, const int y1, - const int IH_pad, const int IW_pad, const size_t IH, const size_t IW, - float *dst, const int x2, const int y2, - const int OH_pad, const int OW_pad, const size_t OH, const size_t OW) { - if (IH_pad == OH_pad && IW_pad == OW_pad) { - for (size_t i = 0; i < N * C * OH * OW; i++) { - dst[i] = static_cast(src[i]); - } - return; - } - - float rh; - float rw; - if (align_corners) { - rh = (OH_pad > 1) ? static_cast(IH_pad - 1) / (OH_pad - 1) : 0.0f; - rw = (OW_pad > 1) ? static_cast(IW_pad - 1) / (OW_pad - 1) : 0.0f; - } else { - rh = static_cast(IH_pad) / (OH_pad); - rw = static_cast(IW_pad) / (OW_pad); - } - - parallel_for3d(N, C, OH_pad, [&](size_t n, size_t cb, size_t h) { - const uint8_t *psrc = src + n * C * IH * IW; - - float fh = rh * h; - int ih0 = static_cast(fh); - int ih1 = (ih0 < IH_pad - 1) ? ih0 + 1 : ih0; - - float h_lambda0 = fh - ih0; - float h_lambda1 = 1.0f - h_lambda0; - - for (int w = 0; w < OW_pad; ++w) { - float fw = rw * w; - int iw0 = static_cast(fw); - int iw1 = (iw0 < IW_pad - 1) ? iw0 + 1 : iw0; - - float w_lambda0 = fw - iw0; - float w_lambda1 = 1.0f - w_lambda0; - - dst[n * C * OH * OW + cb * OW * OH + (y2 + h) * OW + (x2 + w)] = - h_lambda1 * (w_lambda1 * static_cast(psrc[cb * IW * IH + (y1 + ih0) * IW + (x1 + iw0)]) + - w_lambda0 * static_cast(psrc[cb * IW * IH + (y1 + ih0) * IW + (x1 + iw1)])) + - h_lambda0 * (w_lambda1 * static_cast(psrc[cb * IW * IH + (y1 + ih1) * IW + (x1 + iw0)]) + - w_lambda0 * static_cast(psrc[cb * IW * IH + (y1 + ih1) * IW + (x1 + iw1)])); - } - }); - } -}; - -REG_FACTORY_FOR(InterpImpl, Interp); - -} // namespace Cpu -} // namespace Extensions -} // namespace InferenceEngine diff --git a/inference-engine/src/mkldnn_plugin/nodes/list_tbl.hpp b/inference-engine/src/mkldnn_plugin/nodes/list_tbl.hpp index ccfeba3bdc2..054a837eff3 100644 --- a/inference-engine/src/mkldnn_plugin/nodes/list_tbl.hpp +++ b/inference-engine/src/mkldnn_plugin/nodes/list_tbl.hpp @@ -64,7 +64,6 @@ MKLDNN_EXTENSION_NODE(TopKImpl, TopK); MKLDNN_EXTENSION_NODE(ShuffleChannelsImpl, ShuffleChannels); MKLDNN_EXTENSION_NODE(SpaceToDepthImpl, SpaceToDepth); MKLDNN_EXTENSION_NODE(PowerFileImpl, PowerFile); -MKLDNN_EXTENSION_NODE(InterpImpl, Interp); MKLDNN_EXTENSION_NODE(BatchToSpaceImpl, BatchToSpace); MKLDNN_EXTENSION_NODE(ExperimentalDetectronPriorGridGeneratorImpl, ExperimentalDetectronPriorGridGenerator); MKLDNN_EXTENSION_NODE(SimplerNMSImpl, SimplerNMS); diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_interpolate_node.cpp b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_interpolate_node.cpp index a2be1f5e80b..a2813b0331b 100644 --- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_interpolate_node.cpp +++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_interpolate_node.cpp @@ -1872,7 +1872,6 @@ void MKLDNNInterpolateNode::buildTblLinearOnnx(SizeVector& srcDimPad5d, SizeVect size_t scratchLen = rnd_up(OW + OW + OH + OH, 16); int idxType = 2; indexTable.resize(idxType * scratchLen); - std::vector index(scratchLen, 0); int *indexLeft = static_cast(&indexTable[0]); int *indexRight = static_cast(&indexTable[OW]); int *indexTop = static_cast(&indexTable[2 * OW]); @@ -2320,7 +2319,7 @@ void MKLDNNInterpolateNode::NNCGathered(const uint8_t *in_ptr_, uint8_t *out_ptr arg.src_ptr[0] = in_ptr_cbd + blk_size * IW * index_h[h] * srcDataSize; arg.index = static_cast(&(index_w_kernel[0])); arg.work_amount = static_cast(OW); - arg.oc_off = cb * blk_size; + arg.oc_off = cb * blk_size * sizeof(float); (*interpolateKernel)(&arg); } }); @@ -2351,7 +2350,7 @@ void MKLDNNInterpolateNode::NNPlanar(const uint8_t *in_ptr_, uint8_t *out_ptr_, arg.src_ptr[0] = in_ptr; arg.dst = out_ptr; arg.index = static_cast(&index_kernel[0]); // need index_h and index_w in kernel, it's in continous memory so one param - arg.oc_off = static_cast(c); + arg.oc_off = static_cast(c * sizeof(float)); // work_amount is OH(out loop) and OW(inner loop), can get in kernel from jcp. (*interpolateKernel)(&arg); }); @@ -2391,7 +2390,7 @@ void MKLDNNInterpolateNode::linearOnnxPlanar(const uint8_t *in_ptr_, uint8_t *ou arg.weight_ptr[0] = static_cast(&weight[0]); arg.dst = out_ptr_nc; arg.work_amount = OW * OH; - arg.oc_off = c; + arg.oc_off = static_cast(c * sizeof(float)); (*interpolateKernel)(&arg); }); } @@ -2666,7 +2665,7 @@ void MKLDNNInterpolateNode::cubicPlanar(const uint8_t *in_ptr_, uint8_t *out_ptr arg.weight_ptr[0] = xFactor; arg.weight_ptr[1] = yFactor; arg.work_amount = static_cast(OW * OH); - arg.oc_off = static_cast(C); + arg.oc_off = static_cast(c * sizeof(float)); (*interpolateKernel)(&arg); }); } @@ -2788,7 +2787,7 @@ inline float MKLDNNInterpolateNode::coordTransToInput(int outCoord, float scale, } case InterpolateCoordTransMode::align_corners: { if (outShape > 1) - return outCoord * static_cast(inShape - 1) / static_cast(outShape - 1); + return outCoord * (static_cast(inShape - 1) / static_cast(outShape - 1)); else return 0; break; @@ -2844,10 +2843,9 @@ bool MKLDNNInterpolateNode::canFuse(const MKLDNNNodePtr& node) const { return false; }; - if (!mayiuse(cpu::sse42)) - return false; - if (mode == InterpolateMode::linear || mode == InterpolateMode::cubic) + if (!mayiuse(cpu::sse42) || mode == InterpolateMode::linear) { return false; + } if (node->getType() == Quantize) { auto* quantizeNode = dynamic_cast(node.get()); @@ -2858,10 +2856,9 @@ bool MKLDNNInterpolateNode::canFuse(const MKLDNNNodePtr& node) const { auto* eltwiseNode = dynamic_cast(node.get()); if (eltwiseNode == nullptr) THROW_IE_EXCEPTION << "Cannot get eltwise node " << node->getName(); - return isOneOf(eltwiseNode->getOpType(), {MulAdd, Prelu, Relu, Gelu, Elu, Logistic, BoundedRelu, Clamp, + return isOneOf(eltwiseNode->getOpType(), {Prelu, Relu, Gelu, Elu, Logistic, BoundedRelu, Clamp, Tanh, Swish, Hswish, Mish, Hsigmoid, Round, Linear, Abs, Square, Sqrt}) || - ((eltwiseNode->getOpType() == MulAdd && eltwiseNode->getCnnLayer()->blobs.size() == 2) || - (eltwiseNode->getOpType() == Prelu)); + (eltwiseNode->getOpType() == MulAdd && eltwiseNode->getCnnLayer()->blobs.size() == 2); } return false; diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_resample_node.cpp b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_resample_node.cpp deleted file mode 100644 index 7ae7ce809c8..00000000000 --- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_resample_node.cpp +++ /dev/null @@ -1,922 +0,0 @@ -// Copyright (C) 2018-2020 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include "mkldnn_resample_node.h" -#include "desc_iterator.hpp" -#include "mkldnn_quantize_node.h" -#include -#include "mkldnn_eltwise_node.h" -#include -#include -#include -#include -#include -#include "utils/bfloat16.hpp" -#include -#include "ie_parallel.hpp" -#include - -#include "jit_generator.hpp" -#include "jit_uni_eltwise.hpp" -#include "jit_uni_depthwise.hpp" -#include "jit_uni_quantization.hpp" -#include "common/cpu_memcpy.h" - -using namespace mkldnn; -using namespace MKLDNNPlugin; -using namespace InferenceEngine; -using namespace mkldnn::impl; -using namespace mkldnn::impl::cpu; -using namespace mkldnn::impl::utils; -using namespace Xbyak; - - -#define GET_OFF(field) offsetof(jit_resample_call_args, field) - -static inline bool isFloatCompatible(Precision prc) { - return Precision::FP32 == prc || Precision::BF16 == prc; -} - -static inline bool isFloatCompatible(memory::data_type type) { - return memory::f32 == type || memory::bf16 == type; -} - -template -struct jit_uni_resample_nearest_kernel_f32 : public jit_uni_resample_nearest_kernel, public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_resample_nearest_kernel_f32) - - explicit jit_uni_resample_nearest_kernel_f32(jit_resample_config_params jcp, const mkldnn_primitive_attr &attr) - : jit_uni_resample_nearest_kernel(jcp, attr), jit_generator() { - const auto &p = attr_.post_ops_; - for (int i = 0; i < p.len_; i++) { - auto &post_op = p.entry_[i]; - if (post_op.is_eltwise()) { - eltwise_injectors.push_back(std::make_shared>( - this, - post_op.eltwise.alg, - post_op.eltwise.alpha, - post_op.eltwise.beta)); - } else if (post_op.is_depthwise()) { - depthwise_injectors.push_back(std::make_shared>( - this, - post_op.depthwise.alg)); - } else if (post_op.is_quantization()) { - quantization_injectors.push_back(std::make_shared>( - this, post_op, vmm_d_weights, vmm_d_bias, reg_d_weights, reg_d_bias)); - } - } - - this->preamble(); - - mov(reg_src, ptr[reg_params + GET_OFF(src)]); - mov(reg_dst, ptr[reg_params + GET_OFF(dst)]); - mov(reg_index, ptr[reg_params + GET_OFF(index)]); - mov(reg_work_amount, ptr[reg_params + GET_OFF(work_amount)]); - mov(reg_src_stride, ptr[reg_params + GET_OFF(src_stride)]); - mov(reg_index_stride, ptr[reg_params + GET_OFF(index_stride)]); - mov(reg_dst_stride, ptr[reg_params + GET_OFF(dst_stride)]); - if (attr_.post_ops_.len_ != 0) - mov(reg_oc_off, ptr[reg_params + GET_OFF(oc_off)]); - - if (isa == cpu::avx512_common) - uni_vpxor(vmm_zero, vmm_zero, vmm_zero); - - int blk_size = jcp_.src_dt == memory::bf16 ? 16 : (vlen / sizeof(float)); - if (isa == cpu::sse42) - blk_size *= 2; - - Xbyak::Label resample_nearest_loop_label; - Xbyak::Label resample_nearest_loop_end_label; - L(resample_nearest_loop_label); - { - cmp(reg_work_amount, 0); - jle(resample_nearest_loop_end_label, T_NEAR); - - if (jcp_.planar_layout) { - uni_vmovdqu(vmm_index, ptr[reg_index]); - uni_vpcmpeqd(vmm_mask, vmm_mask, vmm_mask); - vgatherdps(vmm_val, ptr[reg_src + vmm_index * jcp.src_data_size], vmm_mask); - store_vector(ptr[reg_dst], vmm_val, jcp_.dst_dt); - - add(reg_dst, reg_dst_stride); - add(reg_index, reg_index_stride); - sub(reg_work_amount, 1); - } else if (jcp_.nhwc_format) { // support int8 and fusion for this format - load_vector(vmm_val, ptr[reg_src], jcp_.src_dt); - if (attr_.post_ops_.len_ != 0) - apply_post_ops(jcp_.dst_dt); - store_vector(ptr[reg_dst], vmm_val, jcp_.dst_dt); - - if (isa == cpu::sse42) { - int sse42_offset = 4; - load_vector(vmm_val, ptr[reg_src + sse42_offset * jcp_.src_data_size], jcp_.src_dt); - if (attr_.post_ops_.len_ != 0) { - add(reg_oc_off, sse42_offset * sizeof(float)); - apply_post_ops(jcp_.dst_dt); - sub(reg_oc_off, sse42_offset * sizeof(float)); - } - store_vector(ptr[reg_dst + sse42_offset * jcp_.dst_data_size], vmm_val, jcp_.dst_dt); - } - - add(reg_dst, reg_dst_stride); - add(reg_src, reg_src_stride); - add(reg_oc_off, blk_size * sizeof(float)); - sub(reg_work_amount, 1); - } else { // for blk - mov(reg_src_aux, reg_src); - mov(reg_index_oc, dword[reg_index]); - add(reg_src_aux, reg_index_oc); - - load_vector(vmm_val, ptr[reg_src_aux], jcp_.src_dt); - if (attr_.post_ops_.len_ != 0) - apply_post_ops(jcp_.dst_dt); - store_vector(ptr[reg_dst], vmm_val, jcp_.dst_dt); - - if (isa == cpu::sse42) { - int sse42_offset = 4; - add(reg_src_aux, sse42_offset * jcp_.src_data_size); - load_vector(vmm_val, ptr[reg_src_aux], jcp_.src_dt); - if (attr_.post_ops_.len_ != 0) { - add(reg_oc_off, sse42_offset * sizeof(float)); - apply_post_ops(jcp_.dst_dt); - sub(reg_oc_off, sse42_offset * sizeof(float)); - } - store_vector(ptr[reg_dst + sse42_offset * jcp_.dst_data_size], vmm_val, jcp_.dst_dt); - } - - add(reg_dst, reg_dst_stride); - add(reg_index, reg_index_stride); - sub(reg_work_amount, 1); - } - - jmp(resample_nearest_loop_label, T_NEAR); - } - L(resample_nearest_loop_end_label); - - this->postamble(); - - for (auto& inj : eltwise_injectors) - inj->prepare_table(); - - ker_ = (decltype(ker_)) this->getCode(); - } - -private: - using Vmm = typename conditional3::type; - - const int vlen = cpu_isa_traits::vlen; - - Xbyak::Reg64 reg_src = r8; - Xbyak::Reg64 reg_dst = r9; - Xbyak::Reg64 reg_src_stride = r10; - Xbyak::Reg64 reg_dst_stride = r11; - Xbyak::Reg64 reg_index_stride = r12; - Xbyak::Reg64 reg_work_amount = r13; - Xbyak::Reg64 reg_index = r14; - Xbyak::Reg64 reg_src_aux = r15; - Xbyak::Reg64 reg_params = abi_param1; - - Xbyak::Reg64 reg_oc_off = rax; - Xbyak::Reg64 reg_d_weights = rbx; - Xbyak::Reg64 reg_d_bias = rcx; - Xbyak::Reg32 reg_index_oc = edx; - - Vmm vmm_val = Vmm(0); - Vmm vmm_index = Vmm(1); - Vmm vmm_zero = Vmm(2); - Vmm vmm_mask = Vmm(3); - Vmm vmm_d_weights = Vmm(4); - Vmm vmm_d_bias = Vmm(5); - - std::vector>> eltwise_injectors; - std::vector>> depthwise_injectors; - std::vector>> quantization_injectors; - - inline void load_vector(Vmm vmm_src, const Xbyak::Address &op, memory::data_type src_dt) { - switch (src_dt) { - case memory::f32: - case memory::s32: - uni_vmovups(vmm_src, op); - break; - case memory::s8: - uni_vpmovsxbd(vmm_src, op); - break; - case memory::u8: - uni_vpmovzxbd(vmm_src, op); - break; - case memory::bf16: - uni_vpmovzxwd(vmm_src, op); - uni_vpslld(vmm_src, vmm_src, 16); - break; - default: - assert(!"unknown dst_dt"); - } - - if (!isFloatCompatible(src_dt)) - uni_vcvtdq2ps(vmm_src, vmm_src); - } - - inline void store_vector(const Xbyak::Address &op, Vmm vmm_dst, memory::data_type dst_dt) { - Ymm ymm_dst = Ymm(vmm_dst.getIdx()); - Xmm xmm_dst = Xmm(vmm_dst.getIdx()); - - if (dst_dt == memory::f32) { - uni_vmovups(op, vmm_dst); - } else if (dst_dt == memory::bf16) { - vcvtneps2bf16(ymm_dst, vmm_dst); - vmovdqu16(op, ymm_dst); - } else if (dst_dt == memory::u8) { - uni_vcvtps2dq(vmm_dst, vmm_dst); - if (isa == cpu::avx512_common) { - vpmaxsd(vmm_dst, vmm_dst, vmm_zero); - vpmovusdb(op, vmm_dst); - } else { - uni_vpackusdw(vmm_dst, vmm_dst, vmm_dst); - if (isa != cpu::sse42) - vpermq(ymm_dst, ymm_dst, 0x08); - uni_vpackuswb(vmm_dst, vmm_dst, vmm_dst); - if (isa != cpu::sse42) - vmovq(op, xmm_dst); - else - movd(op, xmm_dst); - } - } else if (dst_dt == memory::s8) { - uni_vcvtps2dq(vmm_dst, vmm_dst); - if (isa == cpu::avx512_common) { - vpmovsdb(op, vmm_dst); - } else { - uni_vpackssdw(vmm_dst, vmm_dst, vmm_dst); - if (isa != cpu::sse42) - vpermq(ymm_dst, ymm_dst, 0x08); - uni_vpacksswb(vmm_dst, vmm_dst, vmm_dst); - if (isa != cpu::sse42) - vmovq(op, xmm_dst); - else - movd(op, xmm_dst); - } - } - } - - void apply_post_ops(memory::data_type dst_dt) { - const auto &p = attr_.post_ops_; - int eltwise_inj_idx = 0; - int depthwise_inj_idx = 0; - int quantization_inj_idx = 0; - for (int i = 0; i < p.len_; i++) { - auto& post_op = p.entry_[i]; - if (post_op.is_eltwise()) { - eltwise_injectors[eltwise_inj_idx]->compute_vector_range(vmm_val.getIdx(), vmm_val.getIdx() + 1); - eltwise_inj_idx++; - } else if (post_op.is_depthwise()) { - mov(reg_d_weights, reinterpret_cast(post_op.depthwise.weights_data)); - mov(reg_d_bias, reinterpret_cast(post_op.depthwise.biases_data)); - add(reg_d_weights, reg_oc_off); - add(reg_d_bias, reg_oc_off); - depthwise_injectors[depthwise_inj_idx]->compute_vector_range(vmm_val.getIdx(), vmm_val.getIdx() + 1, reg_d_weights, reg_d_bias); - depthwise_inj_idx++; - } else if (post_op.is_quantization()) { - bool do_dequantization = post_op.quantization.alg == alg_kind::quantization_quantize_dequantize; - bool do_rounding = do_dequantization || isFloatCompatible(dst_dt) || i != p.len_ - 1; - int s_idx = vmm_val.getIdx(); - - quantization_injectors[quantization_inj_idx]->init_crop_ptrs(reg_oc_off); - quantization_injectors[quantization_inj_idx]->compute_crop(s_idx, s_idx + 1, 0); - - quantization_injectors[quantization_inj_idx]->init_input_scale_shift_ptrs(reg_oc_off); - quantization_injectors[quantization_inj_idx]->compute_input_scale_shift(s_idx, s_idx + 1, 0, do_rounding); - - quantization_injectors[quantization_inj_idx]->init_output_scale_shift_ptrs(reg_oc_off); - quantization_injectors[quantization_inj_idx]->compute_output_scale_shift(s_idx, s_idx + 1, 0); - - quantization_inj_idx++; - } - } - } -}; - - -MKLDNNResampleNode::MKLDNNResampleNode(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng, MKLDNNWeightsSharing::Ptr &cache) - : MKLDNNNode(layer, eng, cache) {} - -void MKLDNNResampleNode::getSupportedDescriptors() { - if (!descs.empty()) - return; - - if (getParentEdges().size() != 1) - THROW_IE_EXCEPTION << "Incorrect number of input edges for layer " << getName(); - if (getChildEdges().empty()) - THROW_IE_EXCEPTION << "Incorrect number of output edges for layer " << getName(); - - auto *layer = getCnnLayer().get(); - type = layer->GetParamAsString("type"); - antialias = layer->GetParamAsBool("antialias", false); - factor = layer->GetParamAsFloat("factor"); -} - -void MKLDNNResampleNode::initSupportedPrimitiveDescriptors() { - if (!supportedPrimitiveDescriptors.empty()) - return; - - if (getParentEdgeAt(0)->getDims().ndims() < 4 || getParentEdgeAt(0)->getDims().ndims() > 5) { - return; - } - - setPostOps(attr, true); - - Precision inputPrecision = getCnnLayer()->insData[0].lock()->getPrecision(); - Precision outputPrecision = getCnnLayer()->outData[0]->getPrecision(); - - if (!fusedWith.empty()) { - auto lastFusedLayer = fusedWith[fusedWith.size() - 1].get()->getCnnLayer(); - if (lastFusedLayer) { - outputPrecision = lastFusedLayer->outData[0]->getPrecision(); - } - } - - if (inputPrecision == Precision::BF16 || outputPrecision == Precision::BF16) { - if (!mayiuse(avx512_core_bf16)) - inputPrecision = outputPrecision = Precision::FP32; - else - inputPrecision = outputPrecision = Precision::BF16; - } - - auto inputDataType = MKLDNNExtensionUtils::IEPrecisionToDataType(inputPrecision); - auto outputDataType = MKLDNNExtensionUtils::IEPrecisionToDataType(outputPrecision); - - input_prec = inputPrecision; - output_prec = outputPrecision; - src_data_size = MKLDNNExtensionUtils::sizeOfDataType(inputDataType); - dst_data_size = MKLDNNExtensionUtils::sizeOfDataType(outputDataType); - - InferenceEngine::LayerConfig config; - config.dynBatchSupport = false; - config.inConfs.resize(1); - config.outConfs.resize(1); - config.inConfs[0].constant = false; - config.outConfs[0].constant = false; - config.inConfs[0].inPlace = -1; - config.outConfs[0].inPlace = -1; - - auto pushDesc = [&](memory::format format) { - config.inConfs[0].desc = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType, format); - config.outConfs[0].desc = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, format); - supportedPrimitiveDescriptors.push_back({config, impl_desc_type::unknown, format}); - }; - - if (type == "caffe.ResampleParameter.NEAREST") { - if (getParentEdgeAt(0)->getDims().ndims() == 4) { - pushDesc(memory::nhwc); - } else if (getParentEdgeAt(0)->getDims().ndims() == 5) { - pushDesc(memory::ndhwc); - } - - if (isFloatCompatible(inputPrecision) && isFloatCompatible(outputPrecision)) { - if (getParentEdgeAt(0)->getDims().ndims() == 4) { - if (mayiuse(cpu::avx512_common)) { - pushDesc(memory::nChw16c); - } else if (mayiuse(cpu::avx2) || mayiuse(cpu::sse42)) { - pushDesc(memory::nChw8c); - } - } else if (getParentEdgeAt(0)->getDims().ndims() == 5) { - if (mayiuse(cpu::avx512_common)) { - pushDesc(memory::nCdhw16c); - } else if (mayiuse(cpu::avx2) || mayiuse(cpu::sse42)) { - pushDesc(memory::nCdhw8c); - } - } - - if (fusedWith.empty()) { - pushDesc(MKLDNNMemory::GetPlainFormat(getChildEdgeAt(0)->getDims())); - } - } - } - if (type == "caffe.ResampleParameter.LINEAR") { - if (getParentEdgeAt(0)->getDims().ndims() == 4) { - pushDesc(memory::nchw); - } else if (getParentEdgeAt(0)->getDims().ndims() == 5) { - pushDesc(memory::ncdhw); - } - } -} - -void MKLDNNResampleNode::createPrimitive() { - auto& dstMemPtr = getChildEdgeAt(0)->getMemoryPtr(); - auto& srcMemPtr = getParentEdgeAt(0)->getMemoryPtr(); - if (!dstMemPtr || !dstMemPtr->GetPrimitivePtr()) - THROW_IE_EXCEPTION << "Destination memory didn't allocate."; - if (!srcMemPtr || !srcMemPtr->GetPrimitivePtr()) - THROW_IE_EXCEPTION << "Input memory didn't allocate."; - if (getSelectedPrimitiveDescriptor() == nullptr) - THROW_IE_EXCEPTION << "Preferable primitive descriptor is not set."; - - auto selectedPD = getSelectedPrimitiveDescriptor(); - Layout selected_layout = selectedPD->getConfig().inConfs[0].desc.getLayout(); - auto jcp = jit_resample_config_params(); - jcp.src_dt = MKLDNNExtensionUtils::IEPrecisionToDataType(selectedPD->getConfig().inConfs[0].desc.getPrecision()); - jcp.dst_dt = MKLDNNExtensionUtils::IEPrecisionToDataType(selectedPD->getConfig().outConfs[0].desc.getPrecision()); - jcp.src_data_size = MKLDNNExtensionUtils::sizeOfDataType(jcp.src_dt); - jcp.dst_data_size = MKLDNNExtensionUtils::sizeOfDataType(jcp.dst_dt); - jcp.planar_layout = MKLDNNMemory::GetPlainLayout(getChildEdgeAt(0)->getDims()) == selected_layout; - jcp.nhwc_format = (selected_layout == NHWC) || (selected_layout == NDHWC); - - if (type == "caffe.ResampleParameter.NEAREST") { - if (mayiuse(cpu::avx512_common)) { - if (jcp.planar_layout) { - resample_nearest_kernel.reset(new jit_uni_resample_nearest_kernel_f32(jcp, *attr.get())); - blk_size = 8; - } else { - resample_nearest_kernel.reset(new jit_uni_resample_nearest_kernel_f32(jcp, *attr.get())); - blk_size = 16; - } - } else if (mayiuse(cpu::avx2)) { - resample_nearest_kernel.reset(new jit_uni_resample_nearest_kernel_f32(jcp, *attr.get())); - blk_size = 8; - } else if (mayiuse(cpu::sse42) && !jcp.planar_layout) { - resample_nearest_kernel.reset(new jit_uni_resample_nearest_kernel_f32(jcp, *attr.get())); - blk_size = 8; - } - } -} - -void MKLDNNResampleNode::setPostOps(mkldnn::primitive_attr &attr, bool initWeights) { - int blob_idx = 0; - mkldnn::post_ops ops; - - for (auto &node : fusedWith) { - auto* quantizeNode = dynamic_cast(node.get()); - if (quantizeNode) { - quantizeNode->appendPostOps(ops); - continue; - } - - auto* eltwiseNode = dynamic_cast(node.get()); - if (eltwiseNode) { - eltwiseNode->appendPostOps(ops); - continue; - } - - THROW_IE_EXCEPTION << "Fusing of " << NameFromType(node->getType()) << " operation to " << NameFromType(this->getType()) << " node is not implemented"; - } - - attr.set_post_ops(ops); -} - - -void MKLDNNResampleNode::execute(mkldnn::stream strm) { - auto &dstMemPtr = getChildEdgeAt(0)->getMemoryPtr(); - auto &srcMemPtr = getParentEdgeAt(0)->getMemoryPtr(); - - Layout layout = getParentEdgeAt(0)->getDesc().getLayout(); - - SizeVector src_dim = getParentEdgeAt(0)->getDesc().getDims(); - SizeVector dst_dim = getChildEdgeAt(0)->getDesc().getDims(); - - size_t dims_size = src_dim.size(); - size_t N = src_dim[0]; - size_t C = src_dim[1]; - size_t ID = (dims_size == 5) ? src_dim[dims_size - 3] : 1lu; - size_t IH = src_dim[dims_size - 2]; - size_t IW = src_dim[dims_size - 1]; - - size_t OD = (dims_size == 5) ? dst_dim[dims_size - 3] : 1lu; - size_t OH = dst_dim[dims_size - 2]; - size_t OW = dst_dim[dims_size - 1]; - - float fx = static_cast(IW) / static_cast(OW); - float fy = static_cast(IH) / static_cast(OH); - float fz = static_cast(ID) / static_cast(OD); - - if (type == "caffe.ResampleParameter.NEAREST") { - if (layout == NCHW || layout == NCDHW) { - if (output_prec == Precision::FP32) { - auto src_data = reinterpret_cast(srcMemPtr->GetData()); - auto dst_data = reinterpret_cast(dstMemPtr->GetData()); - NearestNeighbor_PLN(src_data, dst_data, N, C, ID, IH, IW, fx, fy, fz, OD, OH, OW); - } else if (output_prec == Precision::BF16) { - auto src_data = reinterpret_cast(srcMemPtr->GetData()); - auto dst_data = reinterpret_cast(dstMemPtr->GetData()); - NearestNeighbor_PLN(src_data, dst_data, N, C, ID, IH, IW, fx, fy, fz, OD, OH, OW); - } else { - THROW_IE_EXCEPTION << "Unsupported output precision: " << output_prec.name(); - } - } else { - if (output_prec == Precision::U8) { - auto dst_data = reinterpret_cast(dstMemPtr->GetData()); - if (input_prec == Precision::U8) { - auto src_data = reinterpret_cast(srcMemPtr->GetData()); - NearestNeighbor_BLK(src_data, dst_data, N, C, ID, IH, IW, fx, fy, fz, OD, OH, OW); - } else if (input_prec == Precision::I8) { - auto src_data = reinterpret_cast(srcMemPtr->GetData()); - NearestNeighbor_BLK(src_data, dst_data, N, C, ID, IH, IW, fx, fy, fz, OD, OH, OW); - } else if (input_prec == Precision::FP32) { - auto src_data = reinterpret_cast(srcMemPtr->GetData()); - NearestNeighbor_BLK(src_data, dst_data, N, C, ID, IH, IW, fx, fy, fz, OD, OH, OW); - } else { - THROW_IE_EXCEPTION << "Unsupported output precision: " << output_prec.name(); - } - } else if (output_prec == Precision::I8) { - auto dst_data = reinterpret_cast(dstMemPtr->GetData()); - if (input_prec == Precision::U8) { - auto src_data = reinterpret_cast(srcMemPtr->GetData()); - NearestNeighbor_BLK(src_data, dst_data, N, C, ID, IH, IW, fx, fy, fz, OD, OH, OW); - } else if (input_prec == Precision::I8) { - auto src_data = reinterpret_cast(srcMemPtr->GetData()); - NearestNeighbor_BLK(src_data, dst_data, N, C, ID, IH, IW, fx, fy, fz, OD, OH, OW); - } else if (input_prec == Precision::FP32) { - auto src_data = reinterpret_cast(srcMemPtr->GetData()); - NearestNeighbor_BLK(src_data, dst_data, N, C, ID, IH, IW, fx, fy, fz, OD, OH, OW); - } else { - THROW_IE_EXCEPTION << "Unsupported output precision: " << output_prec.name(); - } - } else if (output_prec == Precision::FP32) { - auto dst_data = reinterpret_cast(dstMemPtr->GetData()); - if (input_prec == Precision::U8) { - auto src_data = reinterpret_cast(srcMemPtr->GetData()); - NearestNeighbor_BLK(src_data, dst_data, N, C, ID, IH, IW, fx, fy, fz, OD, OH, OW); - } else if (input_prec == Precision::I8) { - auto src_data = reinterpret_cast(srcMemPtr->GetData()); - NearestNeighbor_BLK(src_data, dst_data, N, C, ID, IH, IW, fx, fy, fz, OD, OH, OW); - } else if (input_prec == Precision::FP32) { - auto src_data = reinterpret_cast(srcMemPtr->GetData()); - NearestNeighbor_BLK(src_data, dst_data, N, C, ID, IH, IW, fx, fy, fz, OD, OH, OW); - } else { - THROW_IE_EXCEPTION << "Unsupported output precision: " << output_prec.name(); - } - } else if (output_prec == Precision::BF16) { - auto src_data = reinterpret_cast(srcMemPtr->GetData()); - auto dst_data = reinterpret_cast(dstMemPtr->GetData()); - NearestNeighbor_BLK(src_data, dst_data, N, C, ID, IH, IW, fx, fy, fz, OD, OH, OW); - } else { - THROW_IE_EXCEPTION << "Unsupported output precision: " << output_prec.name(); - } - } - } else if (type == "caffe.ResampleParameter.LINEAR") { - // currently no fusion, the input and output precision is the same - bool isDownsample = (fx > 1) || (fy > 1) || (fz > 1); - int kernel_width = 2; - if (input_prec == Precision::U8) { - auto src_data = reinterpret_cast(srcMemPtr->GetData()); - auto dst_data = reinterpret_cast(dstMemPtr->GetData()); - LinearInterpolation(src_data, dst_data, N, C, ID, IH, IW, fx, fy, fz, OD, OH, OW, kernel_width, isDownsample && antialias); - } else if (input_prec == Precision::I8) { - auto src_data = reinterpret_cast(srcMemPtr->GetData()); - auto dst_data = reinterpret_cast(dstMemPtr->GetData()); - LinearInterpolation(src_data, dst_data, N, C, ID, IH, IW, fx, fy, fz, OD, OH, OW, kernel_width, isDownsample && antialias); - } else if (input_prec == Precision::FP32) { - auto src_data = reinterpret_cast(srcMemPtr->GetData()); - auto dst_data = reinterpret_cast(dstMemPtr->GetData()); - LinearInterpolation(src_data, dst_data, N, C, ID, IH, IW, fx, fy, fz, OD, OH, OW, kernel_width, isDownsample && antialias); - } else if (input_prec == Precision::BF16) { - auto src_data = reinterpret_cast(srcMemPtr->GetData()); - auto dst_data = reinterpret_cast(dstMemPtr->GetData()); - LinearInterpolation(src_data, dst_data, N, C, ID, IH, IW, fx, fy, fz, OD, OH, OW, kernel_width, - isDownsample && antialias); - } else { - THROW_IE_EXCEPTION << "Unsupported input precision: " << input_prec.name(); - } - } else { - THROW_IE_EXCEPTION << "Unsupported resample parameter type: " << type; - } -} - -// f32 and no fused, f32->input is f32, no fuse->output is f32 -template -void MKLDNNResampleNode::NearestNeighbor_PLN(const in_data_t *in_ptr_, out_data_t *out_ptr_, int B, int C, int ID, int IH, int IW, - float fx, float fy, float fz, int OD, int OH, int OW) { - std::vector index_buffer(OD * OH * OW); - for (int oz = 0; oz < OD; oz++) { - float iz = oz * fz; - int iz_offset = static_cast(std::floor(iz)) * IH * IW; - int oz_offset = oz * OH * OW; - for (int oy = 0; oy < OH; oy++) { - float iy = oy * fy; - int iy_offset = static_cast(std::floor(iy)) * IW + iz_offset; - int oy_offset = oy * OW + oz_offset; - for (int ox = 0; ox < OW; ox++) { - float ix = ox * fx; - int ix_index = static_cast(std::floor(ix)) + iy_offset; - index_buffer[oy_offset + ox] = ix_index; - } - } - } - if (resample_nearest_kernel) { - parallel_for2d(B, C, [&](size_t b, size_t c) { - const in_data_t *in_ptr = in_ptr_ + IW * IH * ID * C * b + IW * IH * ID * c; - out_data_t *out_ptr = out_ptr_ + OW * OH * OD * C * b + OW * OH * OD * c; - - // for OW*OH*OD - auto arg = jit_resample_call_args(); - arg.src = in_ptr; - arg.dst = out_ptr; - arg.index = static_cast(&index_buffer[0]); - arg.index_stride = blk_size * sizeof(int); - arg.dst_stride = blk_size * dst_data_size; - arg.work_amount = OW * OH * OD / blk_size; - (*resample_nearest_kernel)(&arg); - - int tail_start = (OW * OH * OD / blk_size) * blk_size; - for (int tail = tail_start; tail < OW * OH * OD; tail++) { - out_ptr[tail] = in_ptr[index_buffer[tail]]; - } - }); - } else { - parallel_for2d(B, C, [&](size_t b, size_t c) { - const in_data_t *in_ptr = in_ptr_ + IW * IH * ID * C * b + IW * IH * ID * c; - out_data_t *out_ptr = out_ptr_ + OW * OH * OD * C * b + OW * OH * OD * c; - - for (int i_dst = 0; i_dst < OW * OH * OD; i_dst++) { - out_ptr[i_dst] = in_ptr[index_buffer[i_dst]]; - } - }); - } -} - -// for ndhwc and nCdhw8/16d -// int8->input may be int8, fused->output may be int8 -template -void MKLDNNResampleNode::NearestNeighbor_BLK(const in_data_t *in_ptr_, out_data_t *out_ptr_, int B, int C, int ID, int IH, int IW, - float fx, float fy, float fz, int OD, int OH, int OW) { - std::vector index_d(OD); - std::vector index_h(OH); - std::vector index_w(OW); - for (int oz = 0; oz < OD; oz++) { - float iz = oz * fz; - index_d[oz] = static_cast(std::floor(iz)); - } - for (int oy = 0; oy < OH; oy++) { - float iy = oy * fy; - index_h[oy] = static_cast(std::floor(iy)); - } - for (int ox = 0; ox < OW; ox++) { - float ix = ox * fx; - index_w[ox] = static_cast(std::floor(ix)); - } - - Layout layout = getParentEdgeAt(0)->getDesc().getLayout(); - bool is_nhwc = (layout == NHWC || layout == NDHWC) ? true : false; - - for (int b = 0; b < B; b++) { - if (is_nhwc) { - const in_data_t *in_ptr = in_ptr_ + IW * IH * ID * C * b; - out_data_t *out_ptr = out_ptr_ + OW * OH * OD * C * b; - if (resample_nearest_kernel) { - int tail = (C / blk_size) * blk_size; - parallel_for2d(OD, OH, [&](size_t d, size_t h) { - // better that same core process continuous memory - out_data_t *out_ptr_dh = out_ptr + C * OW * OH * d + C * OW * h; - const in_data_t *in_ptr_dh = in_ptr + C * IW * IH * index_d[d] + C * IW * index_h[h]; - auto arg = jit_resample_call_args(); - for (int ox = 0; ox < OW; ox++) { - // kernel for OC - arg.dst = out_ptr_dh + C * ox; - arg.src = in_ptr_dh + C * index_w[ox]; - arg.dst_stride = blk_size * sizeof(out_data_t); - arg.src_stride = blk_size * sizeof(in_data_t); - arg.work_amount = C / blk_size; - arg.oc_off = 0; - (*resample_nearest_kernel)(&arg); - } - // tail - if (tail != C) { - for (int ox = 0; ox < OW; ox++) { - out_data_t *out_ptr_dhw = out_ptr_dh + C * ox; - const in_data_t *in_ptr_dhw = in_ptr_dh + C * index_w[ox]; - if (fusedWith.empty() && output_prec == input_prec) { - cpu_memcpy(out_ptr_dhw + tail, in_ptr_dhw + tail, (C - tail) * sizeof(in_data_t)); - } else { - for (int c = tail; c < C; c++) { - float dst_value = static_cast(in_ptr_dhw[c]); - apply_post_ops_scalar(dst_value, c); - if (isFloatCompatible(output_prec)) { - out_ptr_dhw[c] = dst_value; - } else if (output_prec == Precision::U8) { - out_ptr_dhw[c] = (dst_value >= 0) ? lroundf(dst_value) : 0; - } else if (output_prec == Precision::I8) { - out_ptr_dhw[c] = lroundf(dst_value); - } - } - } - } - } - }); - } else { // without kernel - parallel_for2d(OD, OH, [&](size_t d, size_t h) { - out_data_t *out_ptr_dh = out_ptr + C * OW * OH * d + C * OW * h; - const in_data_t *in_ptr_dh = in_ptr + C * IW * IH * index_d[d] + C * IW * index_h[h]; - for (int ox = 0; ox < OW; ox++) { - out_data_t *out_ptr_dhw = out_ptr_dh + C * ox; - const in_data_t *in_ptr_dhw = in_ptr_dh + C * index_w[ox]; - if (fusedWith.empty() && output_prec == input_prec) { - cpu_memcpy(out_ptr_dhw, in_ptr_dhw, C * sizeof(in_data_t)); - } else { - for (int c = 0; c < C; c++) { - float dst_value = static_cast(in_ptr_dhw[c]); - apply_post_ops_scalar(dst_value, c); - if (isFloatCompatible(output_prec)) { - out_ptr_dhw[c] = dst_value; - } else if (output_prec == Precision::U8) { - out_ptr_dhw[c] = (dst_value >= 0) ? lroundf(dst_value) : 0; - } else if (output_prec == Precision::I8) { - out_ptr_dhw[c] = lroundf(dst_value); - } - } - } - } - }); - } - } else { // for nC(d)hw8/16c - int CB = div_up(C, blk_size); - const in_data_t *in_ptr = in_ptr_ + IW * IH * ID * CB * blk_size * b; - out_data_t *out_ptr = out_ptr_ + OW * OH * OD * CB * blk_size * b; - if (resample_nearest_kernel) { - std::vector index_w_kernel(OW); - for (int ox = 0; ox < OW; ox++) { - index_w_kernel[ox] = index_w[ox] * blk_size * sizeof(in_data_t); - } - parallel_for2d(CB, OD, [&](size_t cb, size_t d) { - out_data_t *out_ptr_cbd = out_ptr + blk_size * OW * OH * OD * cb + blk_size * OW * OH * d; - const in_data_t *in_ptr_cbd = in_ptr + blk_size * IW * IH * ID * cb + blk_size * IW * IH * index_d[d]; - auto arg = jit_resample_call_args(); - for (int h = 0; h < OH; h++) { // kernel for blk_size * OW - arg.dst = out_ptr_cbd + blk_size * OW * h; - arg.src = in_ptr_cbd + blk_size * IW * index_h[h]; - arg.index = static_cast(&(index_w_kernel[0])); - arg.dst_stride = static_cast(blk_size * sizeof(out_data_t)); - arg.index_stride = static_cast(1 * sizeof(int)); - arg.work_amount = static_cast(OW); - arg.oc_off = cb * blk_size; - (*resample_nearest_kernel)(&arg); - } - }); - } else { - parallel_for2d(CB, OD, [&](int cb, int d) { - out_data_t *out_ptr_cbd = out_ptr + blk_size * OW * OH * OD * cb + blk_size * OW * OH * d; - const in_data_t *in_ptr_cbd = in_ptr + blk_size * IW * IH * ID * cb + blk_size * IW * IH * index_d[d]; - for (int h = 0; h < OH; h++) { - out_data_t *out_ptr_cbdh = out_ptr_cbd + blk_size * OW * h; - const in_data_t *in_ptr_cbdh = in_ptr_cbd + blk_size * IW * index_h[h]; - for (int w = 0; w < OW; w++) { - out_data_t *out_ptr_cbdhw = out_ptr_cbdh + blk_size * w; - const in_data_t *in_ptr_cbdhw = in_ptr_cbdh + blk_size * index_w[w]; - if (fusedWith.empty()) { - cpu_memcpy(out_ptr_cbdhw, in_ptr_cbdhw, blk_size * sizeof(in_data_t)); - } else { - for (int blk = 0; blk < blk_size; blk++) { - float dst_value = static_cast(in_ptr_cbdhw[blk]); - apply_post_ops_scalar(dst_value, cb * blk_size + blk); - if (isFloatCompatible(output_prec)) { - out_ptr_cbdhw[blk] = dst_value; - } else if (output_prec == Precision::U8) { - out_ptr_cbdhw[blk] = (dst_value >= 0) ? lroundf(dst_value) : 0; - } else if (output_prec == Precision::I8) { - out_ptr_cbdhw[blk] = lroundf(dst_value); - } - } - } - } - } - }); - } - } - } // batch end -} - -static inline float triangleCoeff(float x) { - return (std::max)(0.0f, 1 - std::abs(x)); -} - -template -void MKLDNNResampleNode::LinearInterpolation(const in_data_t *in_ptr_, out_data_t *out_ptr_, int B, int C, int ID, int IH, int IW, - float fx, float fy, float fz, int OD, int OH, int OW, int kernel_width, bool antialias) { - if (IW == OW && IH == OH && ID == OD) { - size_t size = B * C * ID * IH * IW; - if (isFloatCompatible(input_prec)) { - size *= sizeof(in_data_t); - } - cpu_memcpy(out_ptr_, in_ptr_, size); - return; - } - - for (size_t b = 0; b < B; b++) { - const in_data_t *in_ptr_n = in_ptr_ + IW * IH * ID * C * b; - out_data_t *out_ptr_n = out_ptr_ + OW * OH * OD * C * b; - for (size_t c = 0; c < C; c++) { - const in_data_t *in_ptr_nc = in_ptr_n + IW * IH * ID * c; - out_data_t *out_ptr_nc = out_ptr_n + OW * OH * OD * c; - - for (size_t oz = 0; oz < OD; oz++) { - out_data_t *out_ptr_ncd = out_ptr_nc + OW * OH * oz; - for (size_t oy = 0; oy < OH; oy++) { - out_data_t *out_ptr_ncdh = out_ptr_ncd + OW * oy; - for (size_t ox = 0; ox < OW; ox++) { - float ix = ox * fx + fx / 2.0f - 0.5f; - float iy = oy * fy + fy / 2.0f - 0.5f; - float iz = oz * fz + fz / 2.0f - 0.5f; - - int ix_r = static_cast(round(ix)); - int iy_r = static_cast(round(iy)); - int iz_r = static_cast(round(iz)); - - float sum = 0; - float wsum = 0; - - float ax = 1.0f / (antialias ? fx : 1.0f); - float ay = 1.0f / (antialias ? fy : 1.0f); - float az = 1.0f / (antialias ? fz : 1.0f); - - int rx = (fx < 1.0f) ? 2 : static_cast(ceil(static_cast(kernel_width) / ax)); - int ry = (fy < 1.0f) ? 2 : static_cast(ceil(static_cast(kernel_width) / ay)); - int rz = (fz < 1.0f) ? 2 : static_cast(ceil(static_cast(kernel_width) / az)); - - for (int z = iz_r - rz; z <= iz_r + rz; z++) { - for (int y = iy_r - ry; y <= iy_r + ry; y++) { - for (int x = ix_r - rx; x <= ix_r + rx; x++) { - bool is_continue = z < 0 || - y < 0 || - x < 0 || - z >= static_cast(ID) || - y >= static_cast(IH) || - x >= static_cast(IW); - if (is_continue) - continue; - - float dx = ix - x; - float dy = iy - y; - float dz = iz - z; - - float w = ax * triangleCoeff(ax * dx) * - ay * triangleCoeff(ay * dy) * - az * triangleCoeff(az * dz); - - sum += w * static_cast(in_ptr_nc[z * IH * IW + y * IW + x]); - wsum += w; - } - } - } - if (!wsum) { - out_ptr_ncdh[ox] = 0; - } else { - float dst_value = sum / wsum; - if (isFloatCompatible(output_prec)) { - out_ptr_ncdh[ox] = dst_value; - } else if (output_prec == Precision::U8) { - out_ptr_ncdh[ox] = (dst_value >= 0) ? lroundf(dst_value) : 0; - } else if (output_prec == Precision::I8) { - out_ptr_ncdh[ox] = lroundf(dst_value); - } - } - } - } - } - } - } -} - -inline void MKLDNNResampleNode::apply_post_ops_scalar(float &dst_value, int index_c) { - const auto &p = (*attr.get()).post_ops_; - for (int i = 0; i < p.len_; i++) { - auto &post_op = p.entry_[i]; - if (post_op.is_eltwise()) { - // only eltwise_relu supported - if (dst_value < 0) dst_value = 0; - } else if (post_op.is_depthwise()) { - // only ScaleShift supported - float scale = post_op.depthwise.weights_data[index_c]; - float shift = post_op.depthwise.biases_data[index_c]; - dst_value = dst_value * scale + shift; - } else if (post_op.is_quantization()) { - bool do_dequantization = post_op.quantization.alg == - alg_kind::quantization_quantize_dequantize; - bool do_rounding = do_dequantization || isFloatCompatible(output_prec) || - i != p.len_ - 1; - - auto quant = post_op.quantization; - - float crop_low = quant.crop_low_data->shifts_[quant.crop_low_data->count_ == 1 ? 0 : index_c]; - float crop_high = quant.crop_high_data->shifts_[quant.crop_high_data->count_ == 1 ? 0 : index_c]; - float input_scale = quant.input_scale_data->scales_[quant.input_scale_data->count_ == 1 ? 0 : index_c]; - float input_shift = quant.input_shift_data->shifts_[quant.input_shift_data->count_ == 1 ? 0 : index_c]; - - dst_value = nstl::min(crop_high, nstl::max(crop_low, dst_value)); - dst_value = dst_value * input_scale + input_shift; - - if (do_rounding) { - dst_value = roundf(dst_value); - } - - if (do_dequantization) { - float output_scale = quant.output_scale_data->scales_[quant.output_scale_data->count_ == 1 ? 0 : index_c]; - float output_shift = quant.output_shift_data->shifts_[quant.output_shift_data->count_ == 1 ? 0 : index_c]; - dst_value = dst_value * output_scale + output_shift; - } - } - } -} - -bool MKLDNNResampleNode::created() const { - return getType() == Resample; -} - -REG_MKLDNN_PRIM_FOR(MKLDNNResampleNode, Resample); \ No newline at end of file diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_resample_node.h b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_resample_node.h deleted file mode 100644 index 47137a0dfef..00000000000 --- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_resample_node.h +++ /dev/null @@ -1,109 +0,0 @@ -// Copyright (C) 2018-2020 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#pragma once - -#include -#include -#include -#include -#include - -namespace MKLDNNPlugin { - -struct jit_resample_config_params { - bool planar_layout; - bool nhwc_format; - mkldnn::memory::data_type src_dt; - mkldnn::memory::data_type dst_dt; - int src_data_size; - int dst_data_size; -}; - -struct jit_resample_call_args { - const void *src; - const int *index; - void *dst; - size_t src_stride; - size_t index_stride; - size_t dst_stride; - size_t work_amount; - size_t oc_off; -}; - -struct jit_uni_resample_nearest_kernel { - void (*ker_)(const jit_resample_call_args *); - - void operator()(const jit_resample_call_args *args) { - assert(ker_); - ker_(args); - } - - explicit jit_uni_resample_nearest_kernel(jit_resample_config_params jcp, const mkldnn_primitive_attr &attr) : ker_(nullptr), jcp_(jcp), attr_(attr) {} - virtual ~jit_uni_resample_nearest_kernel() {} - - jit_resample_config_params jcp_; - const mkldnn_primitive_attr &attr_; -}; - -struct jit_uni_resample_linear_kernel { - void (*ker_)(const jit_resample_call_args *); - - void operator()(const jit_resample_call_args *args) { - assert(ker_); - ker_(args); - } - - explicit jit_uni_resample_linear_kernel(jit_resample_config_params jcp, const mkldnn_primitive_attr &attr) : ker_(nullptr), jcp_(jcp), attr_(attr) {} - virtual ~jit_uni_resample_linear_kernel() {} - - jit_resample_config_params jcp_; - const mkldnn_primitive_attr &attr_; -}; - - -class MKLDNNResampleNode : public MKLDNNNode { -public: - MKLDNNResampleNode(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng, MKLDNNWeightsSharing::Ptr &cache); - ~MKLDNNResampleNode() override = default; - - void getSupportedDescriptors() override; - void initSupportedPrimitiveDescriptors() override; - void createPrimitive() override; - bool created() const override; - void execute(mkldnn::stream strm) override; - bool canBeInPlace() const override { - return false; - } - -private: - template - void NearestNeighbor_PLN(const in_data_t *in_ptr_, out_data_t *out_ptr_, int B, int C, int ID, int IH, int IW, - float fx, float fy, float fz, int OD, int OH, int OW); - template - void NearestNeighbor_BLK(const in_data_t *in_ptr_, out_data_t *out_ptr_, int B, int C, int ID, int IH, int IW, - float fx, float fy, float fz, int OD, int OH, int OW); - template - void LinearInterpolation(const in_data_t *in_ptr_, out_data_t *out_ptr_, int B, int C, int ID, int IH, int IW, - float fx, float fy, float fz, int OD, int OH, int OW, int kernel_width, bool antialias); - void setPostOps(mkldnn::primitive_attr &attr, bool initWeights = false); - inline void apply_post_ops_scalar(float &dst_value, int index_c); - - int blk_size; - - std::string type; - bool antialias; - float factor; - - mkldnn::primitive_attr attr; - std::vector PostOpsIntBlobMemory; - - InferenceEngine::Precision input_prec, output_prec; - size_t src_data_size, dst_data_size; - - std::shared_ptr resample_nearest_kernel; -}; - -} // namespace MKLDNNPlugin - diff --git a/inference-engine/tests/functional/plugin/cpu/single_layer_tests/interpolate.cpp b/inference-engine/tests/functional/plugin/cpu/single_layer_tests/interpolate.cpp index 9d153429994..a1cebccf2a6 100644 --- a/inference-engine/tests/functional/plugin/cpu/single_layer_tests/interpolate.cpp +++ b/inference-engine/tests/functional/plugin/cpu/single_layer_tests/interpolate.cpp @@ -55,6 +55,7 @@ protected: std::vector axes; std::vector scales; std:tie(mode, shapeCalcMode, coordinateTransformMode, nearestMode, antialias, padBegin, padEnd, cubeCoef, axes, scales) = interpolateParams; + inPrc = outPrc = netPrecision; using ShapeCalcMode = ngraph::op::v4::Interpolate::ShapeCalcMode; @@ -81,6 +82,8 @@ protected: interpolate->get_rt_info() = getCPUInfo(); const ngraph::ResultVector results{std::make_shared(interpolate)}; function = std::make_shared(results, params, "interpolate"); + + selectedType = getPrimitiveType() + "_" + inPrc.name(); } }; @@ -99,7 +102,6 @@ std::vector filterCPUInfoForDevice() { if (with_cpu_x86_avx512f()) { resCPUParams.push_back(CPUSpecificParams{{nChw16c, x, x}, {nChw16c}, {"jit_avx512"}, "jit_avx512_FP32"}); resCPUParams.push_back(CPUSpecificParams{{nhwc, x, x}, {nhwc}, {"jit_avx512"}, "jit_avx512_FP32"}); - resCPUParams.push_back(CPUSpecificParams{{nchw, x, x}, {nchw}, {"jit_avx2"}, "jit_avx2_FP32"}); } else if (with_cpu_x86_avx2()) { resCPUParams.push_back(CPUSpecificParams{{nChw8c, x, x}, {nChw8c}, {"jit_avx2"}, "jit_avx2_FP32"}); resCPUParams.push_back(CPUSpecificParams{{nhwc, x, x}, {nhwc}, {"jit_avx2"}, "jit_avx2_FP32"}); @@ -115,7 +117,8 @@ std::vector filterCPUInfoForDevice() { /* ========== */ const std::vector netPrecisions = { - InferenceEngine::Precision::FP32 + InferenceEngine::Precision::FP32, + InferenceEngine::Precision::BF16 }; const std::vector coordinateTransformModes = { diff --git a/inference-engine/tests_deprecated/unit/engines/mkldnn/graph/layers/extensions/interp_tests.cpp b/inference-engine/tests_deprecated/unit/engines/mkldnn/graph/layers/extensions/interp_tests.cpp deleted file mode 100644 index e7e900a7d01..00000000000 --- a/inference-engine/tests_deprecated/unit/engines/mkldnn/graph/layers/extensions/interp_tests.cpp +++ /dev/null @@ -1,254 +0,0 @@ -// Copyright (C) 2018-2020 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - - -#include "test_graph.hpp" - -#include "single_layer_common.hpp" -#include "tests_common.hpp" -#include - - -using namespace ::testing; -using namespace std; -using namespace mkldnn; - - -struct interp_test_params { - struct { - size_t n; - size_t c; - size_t h; - size_t w; - } in; - - struct { - size_t h; - size_t w; - } out; - - int pad_beg; - int pad_end; - - size_t num_prim_desc; - - int selectedType; - - std::vector> comp; -}; - -void interpolate(const int N, const int C, const float *src, const int x1, const int y1, const int IH_pad, const int IW_pad, - const int IH, const int IW, float *dst, const int x2, const int y2, const int OH_pad, const int OW_pad, const int OH, const int OW) { - if (IH_pad == OH_pad && IW_pad == OW_pad) { - for (int i = 0; i < N * C * OH * OW; i++) { - dst[i] = src[i]; - } - return; - } - - const float rh = (OH_pad > 1) ? static_cast(IH_pad - 1) / (OH_pad - 1) : 0.0f; - const float rw = (OW_pad > 1) ? static_cast(IW_pad - 1) / (OW_pad - 1) : 0.0f; - - const int block_size = 1; - - // Align channel number to block size to deal with channels padding in IE with multiple blobs - int CB = (C + block_size - 1) & (-block_size); // CB=n*block_size, i.e.:c=15,(block_size=8), then CB=16, CH=2 - - int CH = (C + block_size - 1) / block_size; // number of block:(n) - - for (int n = 0; n < N; n++) { - for (int cb = 0; cb < CH; ++cb) { - for (int h = 0; h < OH_pad; ++h) { - const float *psrc = src + n * CB * IH * IW; // should be nChw8c(16c) data format - - float fh = rh * h; - int ih0 = static_cast(fh); - int ih1 = (ih0 < IH_pad - 1) ? ih0 + 1 : ih0; - - float h_lambda0 = fh - ih0; - float h_lambda1 = 1.0f - h_lambda0; - - for (int w = 0; w < OW_pad; ++w) { - float fw = rw * w; - int iw0 = static_cast(fw); - int iw1 = (iw0 < IW_pad - 1) ? iw0 + 1 : iw0; - - float w_lambda0 = fw - iw0; - float w_lambda1 = 1.0f - w_lambda0; - - const float *psrc00 = - psrc + cb * block_size * IW * IH + (y1 + ih0) * IW * block_size + (x1 + iw0) * block_size; - const float *psrc01 = - psrc + cb * block_size * IW * IH + (y1 + ih0) * IW * block_size + (x1 + iw1) * block_size; - const float *psrc10 = - psrc + cb * block_size * IW * IH + (y1 + ih1) * IW * block_size + (x1 + iw0) * block_size; - const float *psrc11 = - psrc + cb * block_size * IW * IH + (y1 + ih1) * IW * block_size + (x1 + iw1) * block_size; - - float *pdst = dst + n * CB * OH * OW + cb * block_size * OW * OH + (y2 + h) * OW * block_size + - (x2 + w) * block_size; - - for (int c = 0; c < block_size; ++c) { - pdst[c] = h_lambda1 * (w_lambda1 * psrc00[c] + w_lambda0 * psrc01[c]) + - h_lambda0 * (w_lambda1 * psrc10[c] + w_lambda0 * psrc11[c]); - } - } - } - } - } -} - -template -void ref_interp(const InferenceEngine::TBlob &src, InferenceEngine::TBlob &dst, interp_test_params prm) { - int IB = static_cast(src.getTensorDesc().getDims()[0]); - int IC = static_cast(src.getTensorDesc().getDims()[1]); - int IH = static_cast(src.getTensorDesc().getDims()[2]); - int IW = static_cast(src.getTensorDesc().getDims()[3]); - - int OH = static_cast(dst.getTensorDesc().getDims()[2]); - int OW = static_cast(dst.getTensorDesc().getDims()[3]); - - int IH_pad = IH + prm.pad_beg + prm.pad_end; - int IW_pad = IW + prm.pad_beg + prm.pad_end; - - const data_t *src_data = src.readOnly(); - data_t *dst_data = dst.data(); - - interpolate(IB, IC, src_data, -prm.pad_beg, -prm.pad_beg, IH_pad, IW_pad, IH, IW, dst_data, 0, 0, OH, OW, OH, OW); -} - -class MKLDNNCPUExtInterpTests: public TestsCommon, public WithParamInterface { - std::string model_t = R"V0G0N( - - - - - - _IN_ - _IC_ - _IH_ - _IW_ - - - - - - - - - _IN_ - _IC_ - _IH_ - _IW_ - - - - - _IN_ - _IC_ - _OH_ - _OW_ - - - - - - - - -)V0G0N"; - - std::string getModel(interp_test_params p) { - std::string model = model_t; - REPLACE_WITH_NUM(model, "_IW_", p.in.w); - REPLACE_WITH_NUM(model, "_IH_", p.in.h); - REPLACE_WITH_NUM(model, "_IC_", p.in.c); - REPLACE_WITH_NUM(model, "_IN_", p.in.n); - - REPLACE_WITH_NUM(model, "_OH_", p.out.h); - REPLACE_WITH_NUM(model, "_OW_", p.out.w); - - REPLACE_WITH_NUM(model, "_PB_", p.pad_beg); - REPLACE_WITH_NUM(model, "_PE_", p.pad_end); - return model; - } - -protected: - virtual void TearDown() { - } - - virtual void SetUp() { - try { - TestsCommon::SetUp(); - interp_test_params p = ::testing::WithParamInterface::GetParam(); - std::string model = getModel(p); - - InferenceEngine::Core core; - InferenceEngine::CNNNetwork network; - ASSERT_NO_THROW(network = core.ReadNetwork(model, InferenceEngine::Blob::CPtr())); - - MKLDNNGraphTestClass graph; - graph.CreateGraph(network); - - auto& nodes = graph.getNodes(); - nodes = graph.getNodes(); - for (auto &node : nodes) { - if (node->getName() == "interp1") { - ASSERT_LE(p.num_prim_desc, node->getSupportedPrimitiveDescriptors().size()); - for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) { - p.comp.at(j)(node->getSupportedPrimitiveDescriptors().at(j)); - } - ASSERT_NE(nullptr, node->getSelectedPrimitiveDescriptor()); - ASSERT_EQ(p.selectedType, - node->getSelectedPrimitiveDescriptor()->getImplementationType() & p.selectedType); - } - } - ASSERT_LE(4, nodes.size()); - - InferenceEngine::SizeVector dims_src = {p.in.n, p.in.c, p.in.h, p.in.w}; - - InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob({InferenceEngine::Precision::FP32, dims_src, InferenceEngine::NCHW}); - src->allocate(); - fill_data(src->buffer(), src->size()); - - auto * srcPtr = dynamic_cast*>(src.get()); - - if (srcPtr == nullptr) - FAIL() << "Cannot cast blob to TBlob."; - - InferenceEngine::BlobMap srcs; - srcs.insert(std::pair("in1", src)); - - InferenceEngine::OutputsDataMap out; - out = network.getOutputsInfo(); - InferenceEngine::BlobMap outputBlobs; - - std::pair item = *out.begin(); - - InferenceEngine::TBlob::Ptr output; - output = InferenceEngine::make_shared_blob(item.second->getTensorDesc()); - output->allocate(); - outputBlobs[item.first] = output; - - graph.Infer(srcs, outputBlobs); - - - InferenceEngine::TBlob dst_ref(item.second->getTensorDesc()); - dst_ref.allocate(); - ref_interp(*srcPtr, dst_ref, p); - compare(*output, dst_ref); - } catch (const InferenceEngine::details::InferenceEngineException &e) { - FAIL() << e.what(); - } - } -}; - -TEST_P(MKLDNNCPUExtInterpTests, TestsInterp) {} - -INSTANTIATE_TEST_CASE_P( - TestsInterp, MKLDNNCPUExtInterpTests, - ::testing::Values( - interp_test_params{{1, 256, 1, 1}, {33, 65}, 0, 0, 1, MKLDNNPlugin::impl_desc_type::unknown }, - interp_test_params{{6, 128, 320, 320}, {23, 38}, 0, 0, 1, MKLDNNPlugin::impl_desc_type::unknown }, - interp_test_params{{1, 2, 33, 65}, {33, 65}, 0, 0, 1, MKLDNNPlugin::impl_desc_type::unknown })); diff --git a/inference-engine/tests_deprecated/unit/engines/mkldnn/graph/layers/extensions/resample_tests.cpp b/inference-engine/tests_deprecated/unit/engines/mkldnn/graph/layers/extensions/resample_tests.cpp deleted file mode 100644 index e46a14b7f31..00000000000 --- a/inference-engine/tests_deprecated/unit/engines/mkldnn/graph/layers/extensions/resample_tests.cpp +++ /dev/null @@ -1,367 +0,0 @@ -// Copyright (C) 2018-2020 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include "test_graph.hpp" - -#include "single_layer_common.hpp" -#include "tests_common.hpp" -#include "ir_gen_helper.hpp" -#include - -#include - -using namespace InferenceEngine; -using namespace ::testing; -using namespace std; -using namespace single_layer_tests; - -using namespace Extensions; -using namespace ::Cpu; - -struct resample_test_params { - std::vector in_dims; - - float factor; - int antialias; - std::string type; - - size_t num_prim_desc; - bool isBlockedFormat; - int selectedType; - - std::vector> comp; -}; - - -static inline float triangleCoeff(float x) { - return max(0.0f, 1 - std::abs(x)); -} - -extern InferenceEngine::IExtensionPtr make_FakeExtensions(); - -template -void ref_resample(const InferenceEngine::TBlob &src, InferenceEngine::TBlob &dst, resample_test_params prm) { - const data_t *src_data = src.readOnly(); - data_t *dst_data = dst.data(); - - size_t ndims = prm.in_dims.size(); - - size_t N = prm.in_dims[0]; - size_t C = prm.in_dims[1]; - size_t ID = ndims == 5 ? prm.in_dims[ndims - 3] : 1; - size_t IH = prm.in_dims[ndims - 2]; - size_t IW = prm.in_dims[ndims - 1]; - size_t OD = ndims == 5 ? ID / prm.factor : 1; - size_t OH = IH / prm.factor; - size_t OW = IW / prm.factor; - - float fx = static_cast(IW) / static_cast(OW); - float fy = static_cast(IH) / static_cast(OH); - float fz = static_cast(ID) / static_cast(OD); - - if (prm.type == "caffe.ResampleParameter.NEAREST") { - for (size_t b = 0; b < N; b++) { - for (size_t c = 0; c < C; c++) { - const float *in_ptr = src_data + IW * IH * ID * C * b + IW * IH * ID * c; - float *out_ptr = dst_data + OW * OH * OD * C * b + OW * OH * OD * c; - - for (size_t oz = 0; oz < OD; oz++) { - for (size_t oy = 0; oy < OH; oy++) { - for (size_t ox = 0; ox < OW; ox++) { - float ix = ox * fx; - float iy = oy * fy; - float iz = oz * fz; - - size_t ix_r = static_cast(std::floor(ix)); - size_t iy_r = static_cast(std::floor(iy)); - size_t iz_r = static_cast(std::floor(iz)); - - out_ptr[oz * OH * OW + oy * OW + ox] = in_ptr[iz_r * IH * IW + iy_r * IW + ix_r]; - } - } - } - } - } - } else if (prm.type == "caffe.ResampleParameter.LINEAR") { - size_t kernel_width = 2; - bool isDownsample = (fx > 1) || (fy > 1) || (fz > 1); - bool antialias = isDownsample && prm.antialias; - - for (size_t b = 0; b < N; b++) { - for (size_t c = 0; c < C; c++) { - const float *in_ptr = src_data + IW * IH * ID * C * b + IW * IH * ID * c; - float *out_ptr = dst_data + OW * OH * OD * C * b + OW * OH * OD * c; - - for (size_t oz = 0; oz < OD; oz++) { - for (size_t oy = 0; oy < OH; oy++) { - for (size_t ox = 0; ox < OW; ox++) { - float ix = ox * fx + fx / 2.0f - 0.5f; - float iy = oy * fy + fy / 2.0f - 0.5f; - float iz = oz * fz + fz / 2.0f - 0.5f; - - int ix_r = static_cast(round(ix)); - int iy_r = static_cast(round(iy)); - int iz_r = static_cast(round(iz)); - - float sum = 0; - float wsum = 0; - - float ax = 1.0f / (antialias ? fx : 1.0f); - float ay = 1.0f / (antialias ? fy : 1.0f); - float az = 1.0f / (antialias ? fz : 1.0f); - - int rx = (fx < 1.0f) ? 2 : static_cast(ceil(static_cast(kernel_width) / ax)); - int ry = (fy < 1.0f) ? 2 : static_cast(ceil(static_cast(kernel_width) / ay)); - int rz = (fz < 1.0f) ? 2 : static_cast(ceil(static_cast(kernel_width) / az)); - - for (int z = iz_r - rz; z <= iz_r + rz; z++) { - for (int y = iy_r - ry; y <= iy_r + ry; y++) { - for (int x = ix_r - rx; x <= ix_r + rx; x++) { - if (z < 0 || y < 0 || x < 0 || z >= static_cast(ID) ||y >= static_cast(IH) || x >= static_cast(IW)) - continue; - - float dx = ix - x; - float dy = iy - y; - float dz = iz - z; - - float w = ax * triangleCoeff(ax * dx) * ay * triangleCoeff(ay * dy) * az * triangleCoeff(az * dz); - - sum += w * in_ptr[z * IH * IW + y * IW + x]; - wsum += w; - } - } - } - out_ptr[oz * OH * OW + oy * OW + ox] = (!wsum) ? 0 : (sum / wsum); - } - } - } - } - } - } else { - assert(!"Unsupported resample operation type"); - } -} - -class MKLDNNCPUExtResampleTests: public TestsCommon, public WithParamInterface { - std::string model_t = R"V0G0N( - - - - - - _IN_ - _IC_ - _ID_ - _IH_ - _IW_ - - - - - - - _IN_ - _IC_ - _ID_ - _IH_ - _IW_ - - - - - _IN_ - _IC_ - _ID_ - _IH_ - _IW_ - - - - - - - - _IN_ - _IC_ - _ID_ - _IH_ - _IW_ - - - - - _IN_ - _IC_ - _OD_ - _OH_ - _OW_ - - - - - - - - - -)V0G0N"; - - std::string getModel(resample_test_params p) { - std::string model = model_t; - - auto dims_size = p.in_dims.size(); - if (dims_size == 4) { - REMOVE_LINE(model, "_ID_"); - REMOVE_LINE(model, "_OD_"); - } - - if (p.isBlockedFormat) - REPLACE_WITH_STR(model, "_FL_", "FakeLayerBLK"); - else - REPLACE_WITH_STR(model, "_FL_", "FakeLayerPLN"); - - REPLACE_WITH_NUM(model, "_IN_", p.in_dims[0]); - REPLACE_WITH_NUM(model, "_IC_", p.in_dims[1]); - if (dims_size == 5) - REPLACE_WITH_NUM(model, "_ID_", p.in_dims[dims_size - 3]); - REPLACE_WITH_NUM(model, "_IH_", p.in_dims[dims_size - 2]); - REPLACE_WITH_NUM(model, "_IW_", p.in_dims[dims_size - 1]); - - if (dims_size == 5) - REPLACE_WITH_NUM(model, "_OD_", (int)(p.in_dims[dims_size - 3] / p.factor)); - REPLACE_WITH_NUM(model, "_OH_", (int)(p.in_dims[dims_size - 2] / p.factor)); - REPLACE_WITH_NUM(model, "_OW_", (int)(p.in_dims[dims_size - 1] / p.factor)); - - REPLACE_WITH_NUM(model, "_AN_", p.antialias); - REPLACE_WITH_NUM(model, "_F_", p.factor); - REPLACE_WITH_STR(model, "_T_", p.type); - - return model; - } - -protected: - virtual void TearDown() { - } - - virtual void SetUp() { - try { - TestsCommon::SetUp(); - resample_test_params p = ::testing::WithParamInterface::GetParam(); - std::string model = getModel(p); - - MKLDNNPlugin::MKLDNNExtensionManager::Ptr extMgr(new MKLDNNPlugin::MKLDNNExtensionManager()); - auto defaultExtensions = std::make_shared(); - extMgr->AddExtension(defaultExtensions); - extMgr->AddExtension(make_FakeExtensions()); - - InferenceEngine::Core core; - InferenceEngine::CNNNetwork network; - ASSERT_NO_THROW(network = core.ReadNetwork(model, InferenceEngine::Blob::CPtr())); - - MKLDNNGraphTestClass graph; - graph.CreateGraph(network, extMgr); - - auto& nodes = graph.getNodes(); - nodes = graph.getNodes(); - - for (auto &node : nodes) { - if (node->getName() == "resample") { - ASSERT_EQ(p.num_prim_desc, node->getSupportedPrimitiveDescriptors().size()); - for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) { - p.comp.at(j)(node->getSupportedPrimitiveDescriptors().at(j)); - } - ASSERT_NE(nullptr, node->getSelectedPrimitiveDescriptor()); - ASSERT_EQ(p.selectedType, - node->getSelectedPrimitiveDescriptor()->getImplementationType() & p.selectedType); - } - } - - InferenceEngine::SizeVector dims_src = p.in_dims; - - InferenceEngine::Layout layout = InferenceEngine::ANY; - switch (p.in_dims.size()) { - case 4: layout = InferenceEngine::NCHW; break; - case 5: layout = InferenceEngine::NCDHW; break; - } - - InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob({InferenceEngine::Precision::FP32, dims_src, layout}); - src->allocate(); - fill_data(src->buffer(), src->size()); - - auto * srcPtr = dynamic_cast*>(src.get()); - - if (srcPtr == nullptr) - FAIL() << "Cannot cast blob to TBlob."; - - InferenceEngine::BlobMap srcs; - srcs.insert(std::pair("in1", src)); - - InferenceEngine::OutputsDataMap out; - out = network.getOutputsInfo(); - InferenceEngine::BlobMap outputBlobs; - - std::pair item = *out.begin(); - - InferenceEngine::TBlob::Ptr output; - output = InferenceEngine::make_shared_blob(item.second->getTensorDesc()); - output->allocate(); - outputBlobs[item.first] = output; - - graph.Infer(srcs, outputBlobs); - - InferenceEngine::TBlob dst_ref(item.second->getTensorDesc()); - dst_ref.allocate(); - ref_resample(*srcPtr, dst_ref, p); - compare(*output, dst_ref); - } catch (const InferenceEngine::details::InferenceEngineException &e) { - FAIL() << e.what(); - } - } -}; - -TEST_P(MKLDNNCPUExtResampleTests, TestsResample) {} - -INSTANTIATE_TEST_CASE_P( - TestsResample, MKLDNNCPUExtResampleTests, - ::testing::Values( - resample_test_params{{2, 64, 15, 25}, 1.f, 0, "caffe.ResampleParameter.NEAREST", 3, false, MKLDNNPlugin::impl_desc_type::unknown }, - resample_test_params{{2, 64, 15, 25}, 1.f, 0, "caffe.ResampleParameter.NEAREST", 3, true, MKLDNNPlugin::impl_desc_type::unknown }, - resample_test_params{{2, 64, 15, 25}, 1.f, 1, "caffe.ResampleParameter.LINEAR", 1, false, MKLDNNPlugin::impl_desc_type::unknown }, - resample_test_params{{2, 64, 10, 20}, 0.25f, 0, "caffe.ResampleParameter.NEAREST", 3, false, MKLDNNPlugin::impl_desc_type::unknown }, - resample_test_params{{2, 64, 10, 20}, 0.25f, 0, "caffe.ResampleParameter.NEAREST", 3, true, MKLDNNPlugin::impl_desc_type::unknown }, - resample_test_params{{2, 64, 10, 20}, 0.25f, 1, "caffe.ResampleParameter.LINEAR", 1, false, MKLDNNPlugin::impl_desc_type::unknown }, - resample_test_params{{2, 64, 10, 20}, 4.f, 0, "caffe.ResampleParameter.NEAREST", 3, false, MKLDNNPlugin::impl_desc_type::unknown }, - resample_test_params{{2, 64, 10, 20}, 4.f, 0, "caffe.ResampleParameter.NEAREST", 3, true, MKLDNNPlugin::impl_desc_type::unknown }, - resample_test_params{{2, 64, 10, 20}, 4.f, 1, "caffe.ResampleParameter.LINEAR", 1, false, MKLDNNPlugin::impl_desc_type::unknown }, - resample_test_params{{2, 3, 15, 25}, 1.f, 0, "caffe.ResampleParameter.NEAREST", 3, false, MKLDNNPlugin::impl_desc_type::unknown }, - resample_test_params{{2, 3, 15, 25}, 1.f, 0, "caffe.ResampleParameter.NEAREST", 3, true, MKLDNNPlugin::impl_desc_type::unknown }, - resample_test_params{{2, 3, 15, 25}, 1.f, 1, "caffe.ResampleParameter.LINEAR", 1, false, MKLDNNPlugin::impl_desc_type::unknown }, - resample_test_params{{2, 3, 10, 20}, 0.25f, 0, "caffe.ResampleParameter.NEAREST", 3, false, MKLDNNPlugin::impl_desc_type::unknown }, - resample_test_params{{2, 3, 10, 20}, 0.25f, 0, "caffe.ResampleParameter.NEAREST", 3, true, MKLDNNPlugin::impl_desc_type::unknown }, - resample_test_params{{2, 3, 10, 20}, 0.25f, 1, "caffe.ResampleParameter.LINEAR", 1, false, MKLDNNPlugin::impl_desc_type::unknown }, - resample_test_params{{2, 3, 10, 20}, 4.f, 0, "caffe.ResampleParameter.NEAREST", 3, false, MKLDNNPlugin::impl_desc_type::unknown }, - resample_test_params{{2, 3, 10, 20}, 4.f, 0, "caffe.ResampleParameter.NEAREST", 3, true, MKLDNNPlugin::impl_desc_type::unknown }, - resample_test_params{{2, 3, 10, 20}, 4.f, 1, "caffe.ResampleParameter.LINEAR", 1, false, MKLDNNPlugin::impl_desc_type::unknown }, - // 5D nearest - resample_test_params{{2, 64, 20, 15, 25}, 1.f, 0, "caffe.ResampleParameter.NEAREST", 3, false, MKLDNNPlugin::impl_desc_type::unknown }, - resample_test_params{{2, 64, 20, 15, 25}, 1.f, 0, "caffe.ResampleParameter.NEAREST", 3, true, MKLDNNPlugin::impl_desc_type::unknown }, - resample_test_params{{2, 64, 15, 10, 20}, 0.25f, 0, "caffe.ResampleParameter.NEAREST", 3, false, MKLDNNPlugin::impl_desc_type::unknown }, - resample_test_params{{2, 64, 15, 10, 20}, 0.25f, 0, "caffe.ResampleParameter.NEAREST", 3, true, MKLDNNPlugin::impl_desc_type::unknown }, - resample_test_params{{2, 64, 15, 10, 20}, 4.f, 0, "caffe.ResampleParameter.NEAREST", 3, false, MKLDNNPlugin::impl_desc_type::unknown }, - resample_test_params{{2, 64, 15, 10, 20}, 4.f, 0, "caffe.ResampleParameter.NEAREST", 3, true, MKLDNNPlugin::impl_desc_type::unknown }, - resample_test_params{{2, 3, 20, 15, 25}, 1.f, 0, "caffe.ResampleParameter.NEAREST", 3, false, MKLDNNPlugin::impl_desc_type::unknown }, - resample_test_params{{2, 3, 20, 15, 25}, 1.f, 0, "caffe.ResampleParameter.NEAREST", 3, true, MKLDNNPlugin::impl_desc_type::unknown }, - resample_test_params{{2, 3, 15, 10, 20}, 0.25f, 0, "caffe.ResampleParameter.NEAREST", 3, false, MKLDNNPlugin::impl_desc_type::unknown }, - resample_test_params{{2, 3, 15, 10, 20}, 0.25f, 0, "caffe.ResampleParameter.NEAREST", 3, true, MKLDNNPlugin::impl_desc_type::unknown }, - resample_test_params{{2, 3, 15, 10, 20}, 4.f, 0, "caffe.ResampleParameter.NEAREST", 3, false, MKLDNNPlugin::impl_desc_type::unknown }, - resample_test_params{{2, 3, 15, 10, 20}, 4.f, 0, "caffe.ResampleParameter.NEAREST", 3, true, MKLDNNPlugin::impl_desc_type::unknown }, - // 5D linear - resample_test_params{{2, 15, 15, 10, 20}, 9.f, 1, "caffe.ResampleParameter.LINEAR", 1, false, MKLDNNPlugin::impl_desc_type::unknown }, - resample_test_params{{2, 15, 15, 10, 20}, 1.f, 1, "caffe.ResampleParameter.LINEAR", 1, false, MKLDNNPlugin::impl_desc_type::unknown }, - resample_test_params{{2, 15, 15, 10, 20}, 4.f, 1, "caffe.ResampleParameter.LINEAR", 1, false, MKLDNNPlugin::impl_desc_type::unknown }, - resample_test_params{{2, 2, 15, 10, 20}, 0.25f, 1, "caffe.ResampleParameter.LINEAR", 1, false, MKLDNNPlugin::impl_desc_type::unknown }, - resample_test_params{{2, 15, 15, 10, 20}, 9.f, 0, "caffe.ResampleParameter.LINEAR", 1, false, MKLDNNPlugin::impl_desc_type::unknown }, - resample_test_params{{2, 15, 15, 10, 20}, 1.f, 0, "caffe.ResampleParameter.LINEAR", 1, false, MKLDNNPlugin::impl_desc_type::unknown }, - resample_test_params{{2, 8, 15, 10, 20}, 4.f, 0, "caffe.ResampleParameter.LINEAR", 1, false, MKLDNNPlugin::impl_desc_type::unknown }, - resample_test_params{{2, 2, 15, 10, 20}, 0.25f, 0, "caffe.ResampleParameter.LINEAR", 1, false, MKLDNNPlugin::impl_desc_type::unknown })); \ No newline at end of file diff --git a/ngraph/core/src/op/interpolate.cpp b/ngraph/core/src/op/interpolate.cpp index 14b58b9381b..5395c08c123 100644 --- a/ngraph/core/src/op/interpolate.cpp +++ b/ngraph/core/src/op/interpolate.cpp @@ -222,8 +222,8 @@ void op::v4::Interpolate::validate_and_infer_types() element::Type input_et = get_input_element_type(0); NODE_VALIDATION_CHECK(this, input_et == element::Type_t::f32 || input_et == element::Type_t::f16 || - input_et == element::Type_t::i8, - "Input element type must be f32, f16, or i8"); + input_et == element::Type_t::i8 || input_et == element::Type_t::bf16, + "Input element type must be f32, f16, bf16 or i8"); PartialShape input_shape = PartialShape(get_input_partial_shape(0));