diff --git a/src/core/src/graph_util.cpp b/src/core/src/graph_util.cpp index 96b870bd45a..25f6ca04241 100644 --- a/src/core/src/graph_util.cpp +++ b/src/core/src/graph_util.cpp @@ -744,43 +744,50 @@ bool ngraph::check_for_cycles(const ngraph::Function* func, ngraph::NodeVector& } bool ov::replace_output_update_name(Output output, const Output& replacement) { - bool has_result_output = false; - for (auto& target_input : output.get_target_inputs()) { - if (ov::is_type(target_input.get_node())) { - // ignore trivial elimination - has_result_output = true; - if (ov::is_type(replacement.get_node())) { - return false; - } - break; + // output port consumers can be reconnected to replacement port only when: + // 1. output has no Result consumers (so we do not propagate node name) + // 2. output has Result consumers and single output port and replacement doesn't have Results consumers + // and has exactly one output port + // In all other cases output name will be lost or changed, so we don't perform the replacement + + auto has_result_consumers = [](const Output& port) { + const auto& consumers = port.get_target_inputs(); + return std::any_of(consumers.cbegin(), consumers.cend(), [](const Input& consumer) { + return ov::is_type(consumer.get_node()); + }); + }; + + bool preserve_legacy_output_name = false; + if (has_result_consumers(output)) { + preserve_legacy_output_name = true; + if (output.get_node()->get_output_size() != 1 || replacement.get_node()->get_output_size() != 1 || + is_type(replacement.get_node()) || has_result_consumers(replacement)) { + return false; } } - if (!has_result_output || replacement.get_node()->get_users().size() == 1) { - if (has_result_output && !ov::is_type(replacement.get_node())) { - replacement.get_node()->set_friendly_name(output.get_node()->get_friendly_name()); - // Update output tensor name - const auto output_tensor_name = output.get_tensor().get_name(); - if (!output_tensor_name.empty()) { - replacement.get_tensor().set_name(output_tensor_name); - } else { - replacement.get_tensor().set_name(output.get_node()->get_friendly_name()); - } + + if (preserve_legacy_output_name) { + replacement.get_node()->set_friendly_name(output.get_node()->get_friendly_name()); + // Update output tensor name + const auto& output_tensor_name = output.get_tensor().get_name(); + if (!output_tensor_name.empty()) { + replacement.get_tensor().set_name(output_tensor_name); + } else { + replacement.get_tensor().set_name(output.get_node()->get_friendly_name()); } - - // Save replacement tensor names before replacement as they will be overriden by the output tensor names - const auto output_names = replacement.get_tensor_ptr()->get_names(); - const auto tensor_name = replacement.get_tensor().get_name(); - output.replace(replacement); - - // Restore back original replacement tensor names - replacement.get_tensor().add_names(output_names); - replacement.get_tensor().set_name(tensor_name); - - copy_runtime_info({replacement.get_node_shared_ptr(), output.get_node_shared_ptr()}, - replacement.get_node_shared_ptr()); - return true; } - return false; + + // Save replacement tensor name before replacement as they will be overridden by the output tensor name + const auto tensor_name = replacement.get_tensor().get_name(); + + output.replace(replacement); + + // Restore back original replacement tensor name + replacement.get_tensor().set_name(tensor_name); + + copy_runtime_info({replacement.get_node_shared_ptr(), output.get_node_shared_ptr()}, + replacement.get_node_shared_ptr()); + return true; } bool ov::replace_node_update_name(const std::shared_ptr& target, const std::shared_ptr& replacement) { diff --git a/src/core/src/node_output.cpp b/src/core/src/node_output.cpp index e85531c7fce..0e986f50cea 100644 --- a/src/core/src/node_output.cpp +++ b/src/core/src/node_output.cpp @@ -7,6 +7,7 @@ #include "ngraph/log.hpp" #include "ngraph/variant.hpp" #include "openvino/core/node.hpp" +#include "openvino/op/parameter.hpp" namespace ov { Output::Output(Node* node, size_t index) : m_node(node->shared_from_this()), m_index(index) {} @@ -64,9 +65,24 @@ void Output::replace(const Output& replacement) { for (auto& input : get_target_inputs()) { input.replace_source_output(replacement); } - replacement.get_tensor_ptr()->set_names(get_tensor_ptr()->get_names()); + replacement.get_tensor_ptr()->add_names(get_tensor_ptr()->get_names()); NGRAPH_SUPPRESS_DEPRECATED_START - replacement.get_tensor_ptr()->set_name(get_tensor_ptr()->get_name()); + // In legacy API we rely on output port tensor name and use it as an input or output name for the model + // Due to m_name is just a string, and we can't store multiple aliases for single output port we have to + // handle two situations during replacement: + // 1. When we replace consumers to Parameter output port we can't change its name, so we skip this part + // 2. In other cases when we replace consumers to another output port we should set name. For example: + // if we eliminate Node2 from Node1->Node2->Result we have to set Node2 output port name to Node1 + // output port name, so the output name for model won't be changed. + // But there are some cases when output name can not be preserved, so the replacement shouldn't be used: + // 1. Parameter->Node->Result - if we eliminate Node we will lose output name + // 2. Node1-->Node2->Result - if we eliminate Node2 we will lose Result output name + // `->Result + // In both of these cases please use replace_output_update_name() method which automatically prevents the + // replacement for cases when we can not preserve input/output names of model. + if (!is_type(replacement.get_node())) { + replacement.get_tensor_ptr()->set_name(get_tensor_ptr()->get_name()); + } NGRAPH_SUPPRESS_DEPRECATED_END } diff --git a/src/core/tests/replace_node.cpp b/src/core/tests/replace_node.cpp index 0c5c1ad43d7..414a54fc727 100644 --- a/src/core/tests/replace_node.cpp +++ b/src/core/tests/replace_node.cpp @@ -117,7 +117,7 @@ TEST(replace_node, simple_node_replacement) { new_relu->output(0).get_tensor().set_names({"f"}); replace_node(relu, new_relu); - ASSERT_EQ(new_relu->output(0).get_tensor().get_names(), std::unordered_set({"c", "d"})); + ASSERT_EQ(new_relu->output(0).get_tensor().get_names(), std::unordered_set({"c", "d", "f"})); } TEST(replace_node, node_elimination) { @@ -133,6 +133,52 @@ TEST(replace_node, node_elimination) { ASSERT_EQ(param->output(0).get_tensor().get_names(), std::unordered_set({"a", "b"})); } +TEST(replace_node, node_elimination_1) { + auto param = std::make_shared(element::i64, Shape{3, 64}); + auto split = std::make_shared(param, op::Constant::create(element::i64, Shape{}, {0}), 3); + auto relu1 = std::make_shared(split->output(2)); + auto relu2 = std::make_shared(relu1); + auto result2 = std::make_shared(relu2); + + // relu1 can be removed because we don't have to preserve name + ASSERT_TRUE(replace_output_update_name(relu1->output(0), relu1->input_value(0))); + + // relu2 can't be removed because we have to preserve name and Split has more than one output port + ASSERT_FALSE(replace_output_update_name(relu2->output(0), relu2->input_value(0))); +} + +TEST(replace_node, node_elimination_2) { + auto param = std::make_shared(element::i64, Shape{3, 64}); + auto relu1 = std::make_shared(param); + auto result1 = std::make_shared(relu1); + auto relu2 = std::make_shared(relu1); + auto result2 = std::make_shared(relu2); + + // relu2 can't be removed because relu1 has Result as consumer + ASSERT_FALSE(replace_output_update_name(relu2->output(0), relu2->input_value(0))); +} + +TEST(replace_node, node_elimination_3) { + auto param = std::make_shared(element::i64, Shape{3, 64}); + auto relu1 = std::make_shared(param); + auto relu2 = std::make_shared(relu1); + auto relu3 = std::make_shared(relu1); + auto result2 = std::make_shared(relu3); + + // relu3 can be removed because relu1 has no Result as consumer + ASSERT_TRUE(replace_output_update_name(relu3->output(0), relu3->input_value(0))); +} + +TEST(replace_node, node_elimination_4) { + auto param = std::make_shared(element::i64, Shape{3, 64}); + auto relu1 = std::make_shared(param); + auto split = std::make_shared(relu1, op::Constant::create(element::i64, Shape{}, {0}), 3); + auto relu2 = std::make_shared(split->output(2)); + auto result2 = std::make_shared(relu2); + + ASSERT_TRUE(replace_output_update_name(split->output(2), split->input_value(0))); +} + TEST(replace_node, output_replacement) { auto param = std::make_shared(element::i64, Shape{1, 64}); param->output(0).get_tensor().set_names({"a", "b"}); @@ -144,7 +190,7 @@ TEST(replace_node, output_replacement) { relu->output(0).replace(new_relu->output(0)); - ASSERT_EQ(new_relu->output(0).get_tensor().get_names(), std::unordered_set({"c", "d"})); + ASSERT_EQ(new_relu->output(0).get_tensor().get_names(), std::unordered_set({"c", "d", "f"})); } TEST(replace_node, source_replacement) { diff --git a/src/tests/functional/inference_engine/ngraph_reader/linear_ops_tests.cpp b/src/tests/functional/inference_engine/ngraph_reader/linear_ops_tests.cpp index 1b43d55a088..4b072313c6e 100644 --- a/src/tests/functional/inference_engine/ngraph_reader/linear_ops_tests.cpp +++ b/src/tests/functional/inference_engine/ngraph_reader/linear_ops_tests.cpp @@ -1788,222 +1788,6 @@ TEST_F(NGraphReaderTests, RemoveAdd2) { }); } -TEST_F(NGraphReaderTests, RemoveAdd3) { - std::string model = R"V0G0N( - - - - - - - 1 - 64 - 112 - 112 - - - - - - - 1 - 64 - 112 - 112 - - - - - 1 - 64 - 112 - 112 - - - - - - - - 1 - 1 - 1 - - - - - - - 1 - 64 - 112 - 112 - - - 1 - 1 - 1 - - - - - 1 - 64 - 112 - 112 - - - - - - - - 1 - 1 - 1 - - - - - - - 1 - 64 - 112 - 112 - - - 1 - 1 - 1 - - - - - 1 - 64 - 112 - 112 - - - - - - - 1 - 64 - 112 - 112 - - - - - - - 1 - 64 - 112 - 112 - - - - - - - - - - - - - - -)V0G0N"; - std::string modelV5 = R"V0G0N( - - - - - - 1 - 64 - 112 - 112 - - - - - - - 1 - 64 - 112 - 112 - - - - - 1 - 64 - 112 - 112 - - - - - - - - 1 - 64 - 112 - 112 - - - - - 1 - 64 - 112 - 112 - - - - - - - - 1 - 64 - 112 - 112 - - - - - 1 - 64 - 112 - 112 - - - - - - - - - - -)V0G0N"; - compareIRs(model, modelV5, 10, [](Blob::Ptr& weights) { - // Set scale/shift constants - auto* scale = reinterpret_cast(weights->buffer().as() + 0); - scale[0] = 0; - scale[1] = 0; - }); -} - TEST_F(NGraphReaderTests, ConvertAddToEltwise2) { std::string model = R"V0G0N(