Improvements for MKLDNN Input node execution (#1913)
This commit is contained in:
parent
12197a4800
commit
07ebc2d3bd
@ -5,6 +5,8 @@
|
||||
#include "mkldnn_input_node.h"
|
||||
#include "../mkldnn_extension_utils.h"
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <algorithm>
|
||||
#include "caseless.hpp"
|
||||
#include "ie_memcpy.h"
|
||||
|
||||
@ -97,6 +99,47 @@ bool MKLDNNInputNode::created() const {
|
||||
return getType() == Input || getType() == Output;
|
||||
}
|
||||
|
||||
namespace {
|
||||
/**
 * @brief Checks whether @p order is the identity sequence 0, 1, ..., N-1.
 *
 * The previous implementation passed the predicate `a + 1 == b` to
 * std::is_sorted. That predicate is not a strict weak ordering, and the
 * algorithm effectively only rejected an adjacent pair whose second element
 * was exactly one less than the first; sequences such as {2, 0, 1} or
 * {0, 2, 3, 1} were therefore reported as "default". The check is now done
 * explicitly, element by element.
 *
 * @param order sequence of axis indices (presumably taken from a BlockingDesc
 *              - confirm against callers).
 * @return true if every element equals its own position; an empty sequence
 *         is treated as default.
 */
bool isDefaultOrder(const InferenceEngine::SizeVector &order) {
    for (size_t i = 0; i < order.size(); ++i) {
        if (order[i] != i)
            return false;
    }
    return true;
}
|
||||
|
||||
/**
 * @brief Tests whether @p strides are the dense strides implied by @p dims.
 *
 * Walking from the last dimension to the first, the stride at each position
 * must equal the product of all dimensions that follow it (1 for the last
 * position).
 *
 * @param strides per-dimension strides to validate.
 * @param dims    dimension sizes the strides are checked against.
 * @return (true, product of all dims) when the strides match;
 *         (false, 0) when the ranks differ or any stride deviates.
 */
std::tuple<bool, size_t> isDefaultStrides(const InferenceEngine::SizeVector &strides,
                                          const InferenceEngine::SizeVector &dims) {
    const std::tuple<bool, size_t> notDefault = std::make_tuple(false, 0);

    if (strides.size() != dims.size())
        return notDefault;

    // Running product of the dimensions already visited; it is the stride
    // expected at the current position.
    size_t expectedStride = 1;
    size_t pos = dims.size();

    while (pos != 0) {
        --pos;
        if (strides[pos] != expectedStride)
            return notDefault;
        expectedStride *= dims[pos];
    }

    // After the full scan the running product is the product of all dims.
    return std::make_tuple(true, expectedStride);
}
|
||||
|
||||
/**
 * @brief Decides whether two tensor descriptors satisfy the compatibility
 *        conditions listed below.
 *
 * The descriptors are considered compatible when all of the following hold:
 *  - their precisions compare equal;
 *  - isDefaultStrides() succeeds for both (strides vs. getDims());
 *  - the sizes reported by isDefaultStrides() are equal;
 *  - isDefaultOrder() holds for both blocking orders.
 *
 * @param lhs first tensor descriptor.
 * @param rhs second tensor descriptor.
 * @return true when every condition above is met, false otherwise.
 */
bool isCompatibleTensors(const InferenceEngine::TensorDesc &lhs, const InferenceEngine::TensorDesc &rhs) {
    const auto &lhsBlocking = lhs.getBlockingDesc();
    const auto &rhsBlocking = rhs.getBlockingDesc();

    // Each tuple holds <strides are default, total element count>.
    const std::tuple<bool, size_t> lhsLayout = isDefaultStrides(lhsBlocking.getStrides(), lhs.getDims());
    const std::tuple<bool, size_t> rhsLayout = isDefaultStrides(rhsBlocking.getStrides(), rhs.getDims());

    if (!(lhs.getPrecision() == rhs.getPrecision()))
        return false;

    if (!(std::get<1>(lhsLayout) == std::get<1>(rhsLayout)))
        return false;

    if (!std::get<0>(lhsLayout))
        return false;

    if (!std::get<0>(rhsLayout))
        return false;

    if (!isDefaultOrder(lhsBlocking.getOrder()))
        return false;

    return isDefaultOrder(rhsBlocking.getOrder());
}
|
||||
} // namespace
|
||||
|
||||
void MKLDNNInputNode::execute(mkldnn::stream strm) {
|
||||
if (!constBlob)
|
||||
return;
|
||||
@ -106,7 +149,8 @@ void MKLDNNInputNode::execute(mkldnn::stream strm) {
|
||||
THROW_IE_EXCEPTION << "Incorrect blob sizes for node " << getName();
|
||||
}
|
||||
|
||||
if (constBlob->getTensorDesc() == dstBlob->getTensorDesc()) {
|
||||
if (constBlob->getTensorDesc() == dstBlob->getTensorDesc()
|
||||
|| isCompatibleTensors(constBlob->getTensorDesc(), dstBlob->getTensorDesc())) {
|
||||
const int8_t *srcData = constBlob->cbuffer().as<int8_t *>();
|
||||
int8_t *dstData = dstBlob->buffer();
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user