Don't add a new Result operation if output port is already connected to Result (#3934)

This commit is contained in:
Ilya Churaev 2021-01-21 13:42:32 +03:00 committed by GitHub
parent 187813a3f6
commit 88b200ea5b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 40 additions and 0 deletions

View File

@ -225,6 +225,11 @@ StatusCode CNNNetworkNGraphImpl::addOutput(const std::string& layerName, size_t
try { try {
for (const auto & layer : _ngraph_function->get_ops()) { for (const auto & layer : _ngraph_function->get_ops()) {
if (layer->get_friendly_name() == layerName) { if (layer->get_friendly_name() == layerName) {
// Check that we don't have a result for the output port
for (const auto& port : layer->output(outputIndex).get_target_inputs()) {
if (dynamic_cast<ngraph::op::Result*>(port.get_node()))
return OK;
}
auto result = make_shared<::ngraph::op::Result>(layer->output(outputIndex)); auto result = make_shared<::ngraph::op::Result>(layer->output(outputIndex));
_ngraph_function->add_results({result}); _ngraph_function->add_results({result});

View File

@ -313,6 +313,41 @@ TEST(CNNNGraphImplTests, TestAddOutput) {
ASSERT_TRUE(outputs.find(testLayerName) != outputs.end()); ASSERT_TRUE(outputs.find(testLayerName) != outputs.end());
} }
TEST(CNNNGraphImplTests, TestAddOutputTwoTimes) {
const std::string testLayerName = "testReLU";
std::shared_ptr<ngraph::Function> ngraph;
{
ngraph::PartialShape shape({1, 3, 22, 22});
ngraph::element::Type type(ngraph::element::Type_t::f32);
auto param = std::make_shared<ngraph::op::Parameter>(type, shape);
auto relu = std::make_shared<ngraph::op::Relu>(param);
relu->set_friendly_name(testLayerName);
auto relu2 = std::make_shared<ngraph::op::Relu>(relu);
relu2->set_friendly_name("relu2");
auto result = std::make_shared<ngraph::op::Result>(relu2);
ngraph::ParameterVector params = {param};
ngraph::ResultVector results = {result};
ngraph = std::make_shared<ngraph::Function>(results, params);
}
InferenceEngine::CNNNetwork cnnNet(ngraph);
ASSERT_NE(nullptr, cnnNet.getFunction());
ASSERT_EQ(4, cnnNet.layerCount());
cnnNet.addOutput(testLayerName);
ASSERT_NE(nullptr, cnnNet.getFunction());
ASSERT_EQ(5, cnnNet.layerCount());
cnnNet.addOutput(testLayerName);
ASSERT_NE(nullptr, cnnNet.getFunction());
ASSERT_EQ(5, cnnNet.layerCount());
auto outputs = cnnNet.getOutputsInfo();
ASSERT_EQ(2, outputs.size());
ASSERT_TRUE(outputs.find("relu2") != outputs.end());
ASSERT_TRUE(outputs.find(testLayerName) != outputs.end());
}
TEST(CNNNGraphImplTests, TestAddOutputFromConvertedNetwork) { TEST(CNNNGraphImplTests, TestAddOutputFromConvertedNetwork) {
const std::string testLayerName = "testReLU"; const std::string testLayerName = "testReLU";
std::shared_ptr<ngraph::Function> ngraph; std::shared_ptr<ngraph::Function> ngraph;