[Snippets][CPU] MKLDNNSnippetNode adopts canBeInPlace logics from eltwise node (#10334)

This commit is contained in:
Ivan Novoselov 2022-02-15 13:13:35 +03:00 committed by GitHub
parent 788a5bb9f2
commit 68c390f679
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 29 deletions

View File

@ -201,39 +201,26 @@ bool MKLDNNSnippetNode::created() const {
return getType() == Subgraph;
}
// internal interface for subgraph execution
bool MKLDNNSnippetNode::canBeInPlace() const {
if (getParentEdgesAtPort(0)[0]->getParent()->getType() == Input) {
return false;
}
static size_t argmax_rank(const std::vector<MKLDNNEdgeWeakPtr> &childEdges) {
auto getOutBlockedDims = [childEdges](int i) {
if ( auto childEdge = childEdges[i].lock() )
return childEdge->getMemory().GetDescWithType<BlockedMemoryDesc>()->getBlockDims();
else
IE_THROW() << "Unable to lock childEdge weak_ptr";
return VectorDims{};
};
auto getOutRank = [getOutBlockedDims](int i) {
return getOutBlockedDims(i).size();
};
size_t max_rank_idx = 0;
size_t max_rank_val = getOutRank(0);
for (size_t i = 1; i < childEdges.size(); i++) {
const auto i_rank_val = getOutRank(i);
if (max_rank_val < i_rank_val) {
max_rank_idx = i;
max_rank_val = i_rank_val;
} else if (max_rank_val == i_rank_val) {
const auto max_rank_dims = getOutBlockedDims(max_rank_idx);
const auto i_dims = getOutBlockedDims(i);
for (size_t j = 0; j < max_rank_val; j++) {
if (i_dims[j] > max_rank_dims[j]) {
max_rank_idx = i;
max_rank_val = i_rank_val;
break;
for (auto& parentEdge : getParentEdges()) {
auto parent = parentEdge.lock()->getParent();
if (parent->getChildEdges().size() != 1)
return false;
// WA to prevent memory corruption caused by inplace feature
if (parent->getType() == Concatenation) {
for (auto& parentParentEdge : parent->getParentEdges()) {
auto parentParent = parentParentEdge.lock()->getParent();
if (parentParent->getChildEdges().size() != 1)
return false;
}
}
}
}
return max_rank_idx;
return getInputShapeAtPort(0) == getOutputShapeAtPort(0);
}
static void offset_calculation(std::vector<size_t>& offset, const std::vector<size_t>& dims_in, const std::vector<size_t>& dims_out) {

View File

@ -32,6 +32,7 @@ public:
// Here we convert to canonical for & jit everything
void createPrimitive() override;
bool canBeInPlace() const override;
bool created() const override;
// if generator is set, it would execute generated code otherwise it would fallback to nGraph reference