From 68c390f6796f6eb2beed4d4f2438d21466347936 Mon Sep 17 00:00:00 2001 From: Ivan Novoselov Date: Tue, 15 Feb 2022 13:13:35 +0300 Subject: [PATCH] [Snippets][CPU] MKLDNNSnippetNode adopts canBeInPlace logics from eltwise node (#10334) --- src/plugins/intel_cpu/src/nodes/subgraph.cpp | 45 +++++++------------- src/plugins/intel_cpu/src/nodes/subgraph.h | 1 + 2 files changed, 17 insertions(+), 29 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/subgraph.cpp b/src/plugins/intel_cpu/src/nodes/subgraph.cpp index 46c588d9688..dced30ba677 100644 --- a/src/plugins/intel_cpu/src/nodes/subgraph.cpp +++ b/src/plugins/intel_cpu/src/nodes/subgraph.cpp @@ -201,39 +201,26 @@ bool MKLDNNSnippetNode::created() const { return getType() == Subgraph; } -// internal interface for subgraph execution +bool MKLDNNSnippetNode::canBeInPlace() const { + if (getParentEdgesAtPort(0)[0]->getParent()->getType() == Input) { + return false; + } -static size_t argmax_rank(const std::vector &childEdges) { - auto getOutBlockedDims = [childEdges](int i) { - if ( auto childEdge = childEdges[i].lock() ) - return childEdge->getMemory().GetDescWithType()->getBlockDims(); - else - IE_THROW() << "Unable to lock childEdge weak_ptr"; - return VectorDims{}; - }; - auto getOutRank = [getOutBlockedDims](int i) { - return getOutBlockedDims(i).size(); - }; - size_t max_rank_idx = 0; - size_t max_rank_val = getOutRank(0); - for (size_t i = 1; i < childEdges.size(); i++) { - const auto i_rank_val = getOutRank(i); - if (max_rank_val < i_rank_val) { - max_rank_idx = i; - max_rank_val = i_rank_val; - } else if (max_rank_val == i_rank_val) { - const auto max_rank_dims = getOutBlockedDims(max_rank_idx); - const auto i_dims = getOutBlockedDims(i); - for (size_t j = 0; j < max_rank_val; j++) { - if (i_dims[j] > max_rank_dims[j]) { - max_rank_idx = i; - max_rank_val = i_rank_val; - break; - } + for (auto& parentEdge : getParentEdges()) { + auto parent = parentEdge.lock()->getParent(); + if (parent->getChildEdges().size() != 1) + return false; + + // WA to prevent memory corruption caused by inplace feature + if (parent->getType() == Concatenation) { + for (auto& parentParentEdge : parent->getParentEdges()) { + auto parentParent = parentParentEdge.lock()->getParent(); + if (parentParent->getChildEdges().size() != 1) + return false; } } } - return max_rank_idx; + return getInputShapeAtPort(0) == getOutputShapeAtPort(0); } static void offset_calculation(std::vector& offset, const std::vector& dims_in, const std::vector& dims_out) { diff --git a/src/plugins/intel_cpu/src/nodes/subgraph.h b/src/plugins/intel_cpu/src/nodes/subgraph.h index 2e1e95a2650..701f8f834c9 100644 --- a/src/plugins/intel_cpu/src/nodes/subgraph.h +++ b/src/plugins/intel_cpu/src/nodes/subgraph.h @@ -32,6 +32,7 @@ public: // Here we convert to canonical for & jit everything void createPrimitive() override; + bool canBeInPlace() const override; bool created() const override; // if generator is set, it would execute generated code otherwise it would fallback to nGraph reference