Improvements for MKLDNN Input node execution (#1913)
This commit is contained in:
parent
12197a4800
commit
07ebc2d3bd
@ -5,6 +5,8 @@
|
|||||||
#include "mkldnn_input_node.h"
|
#include "mkldnn_input_node.h"
|
||||||
#include "../mkldnn_extension_utils.h"
|
#include "../mkldnn_extension_utils.h"
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <tuple>
|
||||||
|
#include <algorithm>
|
||||||
#include "caseless.hpp"
|
#include "caseless.hpp"
|
||||||
#include "ie_memcpy.h"
|
#include "ie_memcpy.h"
|
||||||
|
|
||||||
@ -97,6 +99,47 @@ bool MKLDNNInputNode::created() const {
|
|||||||
return getType() == Input || getType() == Output;
|
return getType() == Input || getType() == Output;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
/**
 * @brief Tells whether an axis order is a run of consecutive ascending indices.
 *
 * Every element must be exactly one greater than its predecessor
 * (e.g. {0, 1, 2, 3}). Empty and single-element orders are accepted.
 *
 * @param order axis order to inspect
 * @return true when no neighbouring pair breaks the "next == previous + 1" rule
 */
bool isDefaultOrder(const InferenceEngine::SizeVector &order) {
    // std::is_sorted cannot be used with an "a + 1 == b" predicate: it is not a
    // strict weak ordering, and is_sorted would only reject orders that contain
    // an adjacent pair descending by one, letting permutations such as
    // {0, 2, 3, 1} through. Look for the first pair that is NOT consecutive instead.
    return std::adjacent_find(order.begin(), order.end(),
                              [](size_t prev, size_t next) { return prev + 1 != next; }) == order.end();
}
|
||||||
|
|
||||||
|
/**
 * @brief Checks that @p strides are the dense strides implied by @p dims.
 *
 * Walking from the last axis to the first, each stride has to equal the
 * product of all dims that follow it (1 for the last axis).
 *
 * @param strides strides to verify
 * @param dims    dimensions the strides are checked against
 * @return (true, product of all dims) when the strides match;
 *         (false, 0) when the ranks differ or any stride does not match
 */
std::tuple<bool, size_t> isDefaultStrides(const InferenceEngine::SizeVector &strides,
                                          const InferenceEngine::SizeVector &dims) {
    using Result = std::tuple<bool, size_t>;

    if (strides.size() != dims.size())
        return Result(false, 0);

    // Running product of the dims already visited; it is the stride expected
    // for the axis currently under inspection.
    size_t expectedStride = 1;

    auto strideIt = strides.rbegin();
    for (auto dimIt = dims.rbegin(); dimIt != dims.rend(); ++dimIt, ++strideIt) {
        if (*strideIt != expectedStride)
            return Result(false, 0);
        expectedStride *= *dimIt;
    }

    // After the loop the running product covers every dim.
    return Result(true, expectedStride);
}
|
||||||
|
|
||||||
|
/**
 * @brief Decides whether two tensor descriptors can be treated as compatible.
 *
 * Both descriptors must have the same precision, both must pass
 * isDefaultStrides() against their own dims with the same resulting size,
 * and both blocking orders must pass isDefaultOrder().
 *
 * @param lhs first tensor descriptor
 * @param rhs second tensor descriptor
 * @return true when all of the conditions above hold
 */
bool isCompatibleTensors(const InferenceEngine::TensorDesc &lhs, const InferenceEngine::TensorDesc &rhs) {
    const auto &lhsLayout = lhs.getBlockingDesc();
    const auto &rhsLayout = rhs.getBlockingDesc();

    // Each result is (strides are default, product of dims).
    const std::tuple<bool, size_t> lhsCheck = isDefaultStrides(lhsLayout.getStrides(), lhs.getDims());
    const std::tuple<bool, size_t> rhsCheck = isDefaultStrides(rhsLayout.getStrides(), rhs.getDims());

    if (!(lhs.getPrecision() == rhs.getPrecision()))
        return false;

    if (!(std::get<1>(lhsCheck) == std::get<1>(rhsCheck)))
        return false;

    if (!std::get<0>(lhsCheck))
        return false;

    if (!std::get<0>(rhsCheck))
        return false;

    if (!isDefaultOrder(lhsLayout.getOrder()))
        return false;

    return isDefaultOrder(rhsLayout.getOrder());
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
void MKLDNNInputNode::execute(mkldnn::stream strm) {
|
void MKLDNNInputNode::execute(mkldnn::stream strm) {
|
||||||
if (!constBlob)
|
if (!constBlob)
|
||||||
return;
|
return;
|
||||||
@ -106,7 +149,8 @@ void MKLDNNInputNode::execute(mkldnn::stream strm) {
|
|||||||
THROW_IE_EXCEPTION << "Incorrect blob sizes for node " << getName();
|
THROW_IE_EXCEPTION << "Incorrect blob sizes for node " << getName();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (constBlob->getTensorDesc() == dstBlob->getTensorDesc()) {
|
if (constBlob->getTensorDesc() == dstBlob->getTensorDesc()
|
||||||
|
|| isCompatibleTensors(constBlob->getTensorDesc(), dstBlob->getTensorDesc())) {
|
||||||
const int8_t *srcData = constBlob->cbuffer().as<int8_t *>();
|
const int8_t *srcData = constBlob->cbuffer().as<int8_t *>();
|
||||||
int8_t *dstData = dstBlob->buffer();
|
int8_t *dstData = dstBlob->buffer();
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user