Don't add a new Result operation if output port is already connected to Result (#3934)
This commit is contained in:
parent
187813a3f6
commit
88b200ea5b
@ -225,6 +225,11 @@ StatusCode CNNNetworkNGraphImpl::addOutput(const std::string& layerName, size_t
|
||||
try {
|
||||
for (const auto & layer : _ngraph_function->get_ops()) {
|
||||
if (layer->get_friendly_name() == layerName) {
|
||||
// Check that we don't have a result for the output port
|
||||
for (const auto& port : layer->output(outputIndex).get_target_inputs()) {
|
||||
if (dynamic_cast<ngraph::op::Result*>(port.get_node()))
|
||||
return OK;
|
||||
}
|
||||
auto result = make_shared<::ngraph::op::Result>(layer->output(outputIndex));
|
||||
_ngraph_function->add_results({result});
|
||||
|
||||
|
@ -313,6 +313,41 @@ TEST(CNNNGraphImplTests, TestAddOutput) {
|
||||
ASSERT_TRUE(outputs.find(testLayerName) != outputs.end());
|
||||
}
|
||||
|
||||
TEST(CNNNGraphImplTests, TestAddOutputTwoTimes) {
|
||||
const std::string testLayerName = "testReLU";
|
||||
std::shared_ptr<ngraph::Function> ngraph;
|
||||
{
|
||||
ngraph::PartialShape shape({1, 3, 22, 22});
|
||||
ngraph::element::Type type(ngraph::element::Type_t::f32);
|
||||
auto param = std::make_shared<ngraph::op::Parameter>(type, shape);
|
||||
auto relu = std::make_shared<ngraph::op::Relu>(param);
|
||||
relu->set_friendly_name(testLayerName);
|
||||
auto relu2 = std::make_shared<ngraph::op::Relu>(relu);
|
||||
relu2->set_friendly_name("relu2");
|
||||
auto result = std::make_shared<ngraph::op::Result>(relu2);
|
||||
|
||||
ngraph::ParameterVector params = {param};
|
||||
ngraph::ResultVector results = {result};
|
||||
|
||||
ngraph = std::make_shared<ngraph::Function>(results, params);
|
||||
}
|
||||
|
||||
InferenceEngine::CNNNetwork cnnNet(ngraph);
|
||||
ASSERT_NE(nullptr, cnnNet.getFunction());
|
||||
ASSERT_EQ(4, cnnNet.layerCount());
|
||||
|
||||
cnnNet.addOutput(testLayerName);
|
||||
ASSERT_NE(nullptr, cnnNet.getFunction());
|
||||
ASSERT_EQ(5, cnnNet.layerCount());
|
||||
cnnNet.addOutput(testLayerName);
|
||||
ASSERT_NE(nullptr, cnnNet.getFunction());
|
||||
ASSERT_EQ(5, cnnNet.layerCount());
|
||||
auto outputs = cnnNet.getOutputsInfo();
|
||||
ASSERT_EQ(2, outputs.size());
|
||||
ASSERT_TRUE(outputs.find("relu2") != outputs.end());
|
||||
ASSERT_TRUE(outputs.find(testLayerName) != outputs.end());
|
||||
}
|
||||
|
||||
TEST(CNNNGraphImplTests, TestAddOutputFromConvertedNetwork) {
|
||||
const std::string testLayerName = "testReLU";
|
||||
std::shared_ptr<ngraph::Function> ngraph;
|
||||
|
Loading…
Reference in New Issue
Block a user