Improvements for MKLDNN Input node execution (#1913)

This commit is contained in:
Vladislav Volkov 2020-08-26 08:38:29 +03:00 committed by GitHub
parent 12197a4800
commit 07ebc2d3bd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -5,6 +5,8 @@
#include "mkldnn_input_node.h"
#include "../mkldnn_extension_utils.h"
#include <string>
#include <tuple>
#include <algorithm>
#include "caseless.hpp"
#include "ie_memcpy.h"
@ -97,6 +99,47 @@ bool MKLDNNInputNode::created() const {
return getType() == Input || getType() == Output;
}
namespace {
bool isDefaultOrder(const InferenceEngine::SizeVector &order) {
return std::is_sorted(order.begin(), order.end(),
[](size_t a, size_t b) { return a + 1 == b; });
}
std::tuple<bool, size_t> isDefaultStrides(const InferenceEngine::SizeVector &strides,
const InferenceEngine::SizeVector &dims) {
if (strides.size() != dims.size())
return std::make_tuple(false, 0);
size_t dim = 1;
for (size_t i = dims.size(); i-- > 0;) {
if (strides[i] != dim)
return std::make_tuple(false, 0);
dim *= dims[i];
}
return std::make_tuple(true, dim);
}
bool isCompatibleTensors(const InferenceEngine::TensorDesc &lhs, const InferenceEngine::TensorDesc &rhs) {
auto const &lhsBlockingDesc = lhs.getBlockingDesc();
auto const &rhsBlockingDesc = rhs.getBlockingDesc();
bool lhsDefaultStrides, rhsDefaultStrides;
size_t lhsSize, rhsSize;
std::tie(lhsDefaultStrides, lhsSize) = isDefaultStrides(lhsBlockingDesc.getStrides(), lhs.getDims());
std::tie(rhsDefaultStrides, rhsSize) = isDefaultStrides(rhsBlockingDesc.getStrides(), rhs.getDims());
return lhs.getPrecision() == rhs.getPrecision()
&& lhsSize == rhsSize
&& lhsDefaultStrides
&& rhsDefaultStrides
&& isDefaultOrder(lhsBlockingDesc.getOrder())
&& isDefaultOrder(rhsBlockingDesc.getOrder());
}
} // namespace
void MKLDNNInputNode::execute(mkldnn::stream strm) {
if (!constBlob)
return;
@ -106,7 +149,8 @@ void MKLDNNInputNode::execute(mkldnn::stream strm) {
THROW_IE_EXCEPTION << "Incorrect blob sizes for node " << getName();
}
if (constBlob->getTensorDesc() == dstBlob->getTensorDesc()) {
if (constBlob->getTensorDesc() == dstBlob->getTensorDesc()
|| isCompatibleTensors(constBlob->getTensorDesc(), dstBlob->getTensorDesc())) {
const int8_t *srcData = constBlob->cbuffer().as<int8_t *>();
int8_t *dstData = dstBlob->buffer();