[Snippets][CPU] MKLDNNSnippetNode adopts canBeInPlace logics from eltwise node (#10334)
This commit is contained in:
parent
788a5bb9f2
commit
68c390f679
@ -201,39 +201,26 @@ bool MKLDNNSnippetNode::created() const {
|
||||
return getType() == Subgraph;
|
||||
}
|
||||
|
||||
// internal interface for subgraph execution
|
||||
bool MKLDNNSnippetNode::canBeInPlace() const {
|
||||
if (getParentEdgesAtPort(0)[0]->getParent()->getType() == Input) {
|
||||
return false;
|
||||
}
|
||||
|
||||
static size_t argmax_rank(const std::vector<MKLDNNEdgeWeakPtr> &childEdges) {
|
||||
auto getOutBlockedDims = [childEdges](int i) {
|
||||
if ( auto childEdge = childEdges[i].lock() )
|
||||
return childEdge->getMemory().GetDescWithType<BlockedMemoryDesc>()->getBlockDims();
|
||||
else
|
||||
IE_THROW() << "Unable to lock childEdge weak_ptr";
|
||||
return VectorDims{};
|
||||
};
|
||||
auto getOutRank = [getOutBlockedDims](int i) {
|
||||
return getOutBlockedDims(i).size();
|
||||
};
|
||||
size_t max_rank_idx = 0;
|
||||
size_t max_rank_val = getOutRank(0);
|
||||
for (size_t i = 1; i < childEdges.size(); i++) {
|
||||
const auto i_rank_val = getOutRank(i);
|
||||
if (max_rank_val < i_rank_val) {
|
||||
max_rank_idx = i;
|
||||
max_rank_val = i_rank_val;
|
||||
} else if (max_rank_val == i_rank_val) {
|
||||
const auto max_rank_dims = getOutBlockedDims(max_rank_idx);
|
||||
const auto i_dims = getOutBlockedDims(i);
|
||||
for (size_t j = 0; j < max_rank_val; j++) {
|
||||
if (i_dims[j] > max_rank_dims[j]) {
|
||||
max_rank_idx = i;
|
||||
max_rank_val = i_rank_val;
|
||||
break;
|
||||
for (auto& parentEdge : getParentEdges()) {
|
||||
auto parent = parentEdge.lock()->getParent();
|
||||
if (parent->getChildEdges().size() != 1)
|
||||
return false;
|
||||
|
||||
// WA to prevent memory corruption caused by inplace feature
|
||||
if (parent->getType() == Concatenation) {
|
||||
for (auto& parentParentEdge : parent->getParentEdges()) {
|
||||
auto parentParent = parentParentEdge.lock()->getParent();
|
||||
if (parentParent->getChildEdges().size() != 1)
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return max_rank_idx;
|
||||
return getInputShapeAtPort(0) == getOutputShapeAtPort(0);
|
||||
}
|
||||
|
||||
static void offset_calculation(std::vector<size_t>& offset, const std::vector<size_t>& dims_in, const std::vector<size_t>& dims_out) {
|
||||
|
@ -32,6 +32,7 @@ public:
|
||||
// Here we convert to canonical for & jit everything
|
||||
void createPrimitive() override;
|
||||
|
||||
bool canBeInPlace() const override;
|
||||
bool created() const override;
|
||||
|
||||
// if generator is set, it would execute generated code otherwise it would fallback to nGraph reference
|
||||
|
Loading…
Reference in New Issue
Block a user