[Snippets][CPU] MKLDNNSnippetNode adopts canBeInPlace logics from eltwise node (#10334)

This commit is contained in:
Ivan Novoselov 2022-02-15 13:13:35 +03:00 committed by GitHub
parent 788a5bb9f2
commit 68c390f679
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 29 deletions

View File

@ -201,39 +201,26 @@ bool MKLDNNSnippetNode::created() const {
return getType() == Subgraph; return getType() == Subgraph;
} }
// internal interface for subgraph execution bool MKLDNNSnippetNode::canBeInPlace() const {
if (getParentEdgesAtPort(0)[0]->getParent()->getType() == Input) {
return false;
}
static size_t argmax_rank(const std::vector<MKLDNNEdgeWeakPtr> &childEdges) { for (auto& parentEdge : getParentEdges()) {
auto getOutBlockedDims = [childEdges](int i) { auto parent = parentEdge.lock()->getParent();
if ( auto childEdge = childEdges[i].lock() ) if (parent->getChildEdges().size() != 1)
return childEdge->getMemory().GetDescWithType<BlockedMemoryDesc>()->getBlockDims(); return false;
else
IE_THROW() << "Unable to lock childEdge weak_ptr"; // WA to prevent memory corruption caused by inplace feature
return VectorDims{}; if (parent->getType() == Concatenation) {
}; for (auto& parentParentEdge : parent->getParentEdges()) {
auto getOutRank = [getOutBlockedDims](int i) { auto parentParent = parentParentEdge.lock()->getParent();
return getOutBlockedDims(i).size(); if (parentParent->getChildEdges().size() != 1)
}; return false;
size_t max_rank_idx = 0;
size_t max_rank_val = getOutRank(0);
for (size_t i = 1; i < childEdges.size(); i++) {
const auto i_rank_val = getOutRank(i);
if (max_rank_val < i_rank_val) {
max_rank_idx = i;
max_rank_val = i_rank_val;
} else if (max_rank_val == i_rank_val) {
const auto max_rank_dims = getOutBlockedDims(max_rank_idx);
const auto i_dims = getOutBlockedDims(i);
for (size_t j = 0; j < max_rank_val; j++) {
if (i_dims[j] > max_rank_dims[j]) {
max_rank_idx = i;
max_rank_val = i_rank_val;
break;
}
} }
} }
} }
return max_rank_idx; return getInputShapeAtPort(0) == getOutputShapeAtPort(0);
} }
static void offset_calculation(std::vector<size_t>& offset, const std::vector<size_t>& dims_in, const std::vector<size_t>& dims_out) { static void offset_calculation(std::vector<size_t>& offset, const std::vector<size_t>& dims_in, const std::vector<size_t>& dims_out) {

View File

@ -32,6 +32,7 @@ public:
// Here we convert to canonical for & jit everything // Here we convert to canonical for & jit everything
void createPrimitive() override; void createPrimitive() override;
bool canBeInPlace() const override;
bool created() const override; bool created() const override;
// if generator is set, it would execute generated code otherwise it would fallback to nGraph reference // if generator is set, it would execute generated code otherwise it would fallback to nGraph reference