[Snippets][CPU] MKLDNNSnippetNode adopts canBeInPlace logics from eltwise node (#10334)
This commit is contained in:
parent
788a5bb9f2
commit
68c390f679
@ -201,39 +201,26 @@ bool MKLDNNSnippetNode::created() const {
|
|||||||
return getType() == Subgraph;
|
return getType() == Subgraph;
|
||||||
}
|
}
|
||||||
|
|
||||||
// internal interface for subgraph execution
|
bool MKLDNNSnippetNode::canBeInPlace() const {
|
||||||
|
if (getParentEdgesAtPort(0)[0]->getParent()->getType() == Input) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
static size_t argmax_rank(const std::vector<MKLDNNEdgeWeakPtr> &childEdges) {
|
for (auto& parentEdge : getParentEdges()) {
|
||||||
auto getOutBlockedDims = [childEdges](int i) {
|
auto parent = parentEdge.lock()->getParent();
|
||||||
if ( auto childEdge = childEdges[i].lock() )
|
if (parent->getChildEdges().size() != 1)
|
||||||
return childEdge->getMemory().GetDescWithType<BlockedMemoryDesc>()->getBlockDims();
|
return false;
|
||||||
else
|
|
||||||
IE_THROW() << "Unable to lock childEdge weak_ptr";
|
// WA to prevent memory corruption caused by inplace feature
|
||||||
return VectorDims{};
|
if (parent->getType() == Concatenation) {
|
||||||
};
|
for (auto& parentParentEdge : parent->getParentEdges()) {
|
||||||
auto getOutRank = [getOutBlockedDims](int i) {
|
auto parentParent = parentParentEdge.lock()->getParent();
|
||||||
return getOutBlockedDims(i).size();
|
if (parentParent->getChildEdges().size() != 1)
|
||||||
};
|
return false;
|
||||||
size_t max_rank_idx = 0;
|
|
||||||
size_t max_rank_val = getOutRank(0);
|
|
||||||
for (size_t i = 1; i < childEdges.size(); i++) {
|
|
||||||
const auto i_rank_val = getOutRank(i);
|
|
||||||
if (max_rank_val < i_rank_val) {
|
|
||||||
max_rank_idx = i;
|
|
||||||
max_rank_val = i_rank_val;
|
|
||||||
} else if (max_rank_val == i_rank_val) {
|
|
||||||
const auto max_rank_dims = getOutBlockedDims(max_rank_idx);
|
|
||||||
const auto i_dims = getOutBlockedDims(i);
|
|
||||||
for (size_t j = 0; j < max_rank_val; j++) {
|
|
||||||
if (i_dims[j] > max_rank_dims[j]) {
|
|
||||||
max_rank_idx = i;
|
|
||||||
max_rank_val = i_rank_val;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return max_rank_idx;
|
return getInputShapeAtPort(0) == getOutputShapeAtPort(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void offset_calculation(std::vector<size_t>& offset, const std::vector<size_t>& dims_in, const std::vector<size_t>& dims_out) {
|
static void offset_calculation(std::vector<size_t>& offset, const std::vector<size_t>& dims_in, const std::vector<size_t>& dims_out) {
|
||||||
|
@ -32,6 +32,7 @@ public:
|
|||||||
// Here we convert to canonical for & jit everything
|
// Here we convert to canonical for & jit everything
|
||||||
void createPrimitive() override;
|
void createPrimitive() override;
|
||||||
|
|
||||||
|
bool canBeInPlace() const override;
|
||||||
bool created() const override;
|
bool created() const override;
|
||||||
|
|
||||||
// if generator is set, it would execute generated code otherwise it would fallback to nGraph reference
|
// if generator is set, it would execute generated code otherwise it would fallback to nGraph reference
|
||||||
|
Loading…
Reference in New Issue
Block a user