[CPU] Fix edge memory share issue (#16202)

This commit is contained in:
Mang Guo 2023-03-27 05:20:51 -04:00 committed by GitHub
parent 6b70c449ba
commit 5e835e327b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 71 additions and 1 deletions

View File

@ -596,8 +596,11 @@ EdgePtr Edge::getBaseEdge(int look) {
for (auto &ch_edge : ch_edges) {
auto &chch_conf = ch_edge->getChild()->getSelectedPrimitiveDescriptor()->getConfig();
if (chch_conf.inConfs[ch_edge->getOutputNum()].inPlace() >= 0)
if (chch_conf.inConfs[ch_edge->getOutputNum()].inPlace() >= 0) {
next_ch_edge = ch_edge;
// To align with upstream-inplace, we stop searching once found the first inplace consumer
break;
}
}
return next_ch_edge->getBaseEdge(LOOK_DOWN);
} else if (parentConfig.outConfs[inputNum].inPlace() >= 0 && (look & LOOK_UP)) {
@ -614,6 +617,7 @@ EdgePtr Edge::getBaseEdge(int look) {
for (auto edge : edges_for_same_port) {
if (edge.get() != this) {
auto base = edge->getBaseEdge(LOOK_BOTH | LOOK_NO_RECURRENT);
// Return once found the first inplace consumer
if (base != edge && base != edges_for_same_port[0]) return base;
}
}

View File

@ -0,0 +1,66 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <ngraph/opsets/opset8.hpp>
#include "ngraph_functions/builders.hpp"
#include "ngraph_functions/utils/ngraph_helpers.hpp"
#include "shared_test_classes/base/layer_test_utils.hpp"
#include "test_utils/cpu_test_utils.hpp"
using namespace CPUTestUtils;
using namespace ngraph;
namespace SubgraphTestsDefinitions {
// Subgraph:
/*
* paramter1 parameter2
* \ /
* \ /
* Concat (inPlace)
* / | \
* / | \
* Reorder Reorder Reorder (the reorder nodes are optimized and use inplace memory mode)
* / | \
* / | \
* Multiply Multiply Multiply
* / | \
* / | \
* Result Result Result
*/
class ConcatReorderInPlaceTest : virtual public LayerTestsUtils::LayerTestsCommon {
public:
void SetUp() override {
const std::vector<size_t> inputShape = {1, 100, 1, 1};
auto inputParams = ngraph::builder::makeParams(ngraph::element::f32, {inputShape, inputShape});
auto concat = ngraph::builder::makeConcat(ngraph::OutputVector{inputParams[0], inputParams[1]}, 1);
const auto targetFormat = nhwc;
auto mul1 = std::make_shared<ngraph::opset8::Multiply>(
concat,
ngraph::builder::makeConstant(ngraph::element::f32, Shape{1}, std::vector<float>{4}));
mul1->get_rt_info() = CPUTestsBase::makeCPUInfo({targetFormat}, {targetFormat}, {});
auto mul2 = std::make_shared<ngraph::opset8::Multiply>(
concat,
ngraph::builder::makeConstant(ngraph::element::f32, Shape{1}, std::vector<float>{5}));
mul2->get_rt_info() = CPUTestsBase::makeCPUInfo({targetFormat}, {targetFormat}, {});
auto mul3 = std::make_shared<ngraph::opset8::Multiply>(
concat,
ngraph::builder::makeConstant(ngraph::element::f32, Shape{1}, std::vector<float>{6}));
mul3->get_rt_info() = CPUTestsBase::makeCPUInfo({targetFormat}, {targetFormat}, {});
ngraph::ResultVector results{std::make_shared<ngraph::opset8::Result>(mul1),
std::make_shared<ngraph::opset8::Result>(mul2),
std::make_shared<ngraph::opset8::Result>(mul3)};
function = std::make_shared<ngraph::Function>(results, inputParams, "ConcatReorderInPlace");
targetDevice = CommonTestUtils::DEVICE_CPU;
}
};
namespace {
TEST_F(ConcatReorderInPlaceTest, smoke_ConcatReorderInPlace_CPU) {
Run();
}
} // namespace
} // namespace SubgraphTestsDefinitions