From 07ebc2d3bda872069190c9cb81b291921e9c54cf Mon Sep 17 00:00:00 2001 From: Vladislav Volkov Date: Wed, 26 Aug 2020 08:38:29 +0300 Subject: [PATCH] Improvements for MKLDNN Input node execution (#1913) --- .../mkldnn_plugin/nodes/mkldnn_input_node.cpp | 46 ++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_input_node.cpp b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_input_node.cpp index d8493c6effa..306f9812dd8 100644 --- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_input_node.cpp +++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_input_node.cpp @@ -5,6 +5,8 @@ #include "mkldnn_input_node.h" #include "../mkldnn_extension_utils.h" #include +#include +#include #include "caseless.hpp" #include "ie_memcpy.h" @@ -97,6 +99,47 @@ bool MKLDNNInputNode::created() const { return getType() == Input || getType() == Output; } +namespace { + bool isDefaultOrder(const InferenceEngine::SizeVector &order) { + return std::is_sorted(order.begin(), order.end(), + [](size_t a, size_t b) { return a + 1 == b; }); + } + + std::tuple isDefaultStrides(const InferenceEngine::SizeVector &strides, + const InferenceEngine::SizeVector &dims) { + if (strides.size() != dims.size()) + return std::make_tuple(false, 0); + + size_t dim = 1; + + for (size_t i = dims.size(); i-- > 0;) { + if (strides[i] != dim) + return std::make_tuple(false, 0); + dim *= dims[i]; + } + + return std::make_tuple(true, dim); + } + + bool isCompatibleTensors(const InferenceEngine::TensorDesc &lhs, const InferenceEngine::TensorDesc &rhs) { + auto const &lhsBlockingDesc = lhs.getBlockingDesc(); + auto const &rhsBlockingDesc = rhs.getBlockingDesc(); + + bool lhsDefaultStrides, rhsDefaultStrides; + size_t lhsSize, rhsSize; + + std::tie(lhsDefaultStrides, lhsSize) = isDefaultStrides(lhsBlockingDesc.getStrides(), lhs.getDims()); + std::tie(rhsDefaultStrides, rhsSize) = isDefaultStrides(rhsBlockingDesc.getStrides(), rhs.getDims()); + + return lhs.getPrecision() == rhs.getPrecision() + && lhsSize == rhsSize + && lhsDefaultStrides + && rhsDefaultStrides + && isDefaultOrder(lhsBlockingDesc.getOrder()) + && isDefaultOrder(rhsBlockingDesc.getOrder()); + } +} // namespace + void MKLDNNInputNode::execute(mkldnn::stream strm) { if (!constBlob) return; @@ -106,7 +149,8 @@ void MKLDNNInputNode::execute(mkldnn::stream strm) { THROW_IE_EXCEPTION << "Incorrect blob sizes for node " << getName(); } - if (constBlob->getTensorDesc() == dstBlob->getTensorDesc()) { + if (constBlob->getTensorDesc() == dstBlob->getTensorDesc() + || isCompatibleTensors(constBlob->getTensorDesc(), dstBlob->getTensorDesc())) { const int8_t *srcData = constBlob->cbuffer().as(); int8_t *dstData = dstBlob->buffer();