Update replace_output_update_name method (#10000)
* Update replace_output_update_name method * Remove wrong legacy test
This commit is contained in:
@@ -744,43 +744,50 @@ bool ngraph::check_for_cycles(const ngraph::Function* func, ngraph::NodeVector&
|
||||
}
|
||||
|
||||
bool ov::replace_output_update_name(Output<Node> output, const Output<Node>& replacement) {
|
||||
bool has_result_output = false;
|
||||
for (auto& target_input : output.get_target_inputs()) {
|
||||
if (ov::is_type<op::v0::Result>(target_input.get_node())) {
|
||||
// ignore trivial elimination
|
||||
has_result_output = true;
|
||||
if (ov::is_type<ngraph::op::Parameter>(replacement.get_node())) {
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
// output port consumers can be reconnected to replacement port only when:
|
||||
// 1. output has no Result consumers (so we do not propagate node name)
|
||||
// 2. output has Result consumers and single output port and replacement doesn't have Results consumers
|
||||
// and has exactly one output port
|
||||
// In all other cases output name will be lost or changed, so we don't perform the replacement
|
||||
|
||||
auto has_result_consumers = [](const Output<Node>& port) {
|
||||
const auto& consumers = port.get_target_inputs();
|
||||
return std::any_of(consumers.cbegin(), consumers.cend(), [](const Input<Node>& consumer) {
|
||||
return ov::is_type<op::v0::Result>(consumer.get_node());
|
||||
});
|
||||
};
|
||||
|
||||
bool preserve_legacy_output_name = false;
|
||||
if (has_result_consumers(output)) {
|
||||
preserve_legacy_output_name = true;
|
||||
if (output.get_node()->get_output_size() != 1 || replacement.get_node()->get_output_size() != 1 ||
|
||||
is_type<ngraph::op::Parameter>(replacement.get_node()) || has_result_consumers(replacement)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (!has_result_output || replacement.get_node()->get_users().size() == 1) {
|
||||
if (has_result_output && !ov::is_type<ngraph::op::Parameter>(replacement.get_node())) {
|
||||
replacement.get_node()->set_friendly_name(output.get_node()->get_friendly_name());
|
||||
// Update output tensor name
|
||||
const auto output_tensor_name = output.get_tensor().get_name();
|
||||
if (!output_tensor_name.empty()) {
|
||||
replacement.get_tensor().set_name(output_tensor_name);
|
||||
} else {
|
||||
replacement.get_tensor().set_name(output.get_node()->get_friendly_name());
|
||||
}
|
||||
|
||||
if (preserve_legacy_output_name) {
|
||||
replacement.get_node()->set_friendly_name(output.get_node()->get_friendly_name());
|
||||
// Update output tensor name
|
||||
const auto& output_tensor_name = output.get_tensor().get_name();
|
||||
if (!output_tensor_name.empty()) {
|
||||
replacement.get_tensor().set_name(output_tensor_name);
|
||||
} else {
|
||||
replacement.get_tensor().set_name(output.get_node()->get_friendly_name());
|
||||
}
|
||||
|
||||
// Save replacement tensor names before replacement as they will be overriden by the output tensor names
|
||||
const auto output_names = replacement.get_tensor_ptr()->get_names();
|
||||
const auto tensor_name = replacement.get_tensor().get_name();
|
||||
output.replace(replacement);
|
||||
|
||||
// Restore back original replacement tensor names
|
||||
replacement.get_tensor().add_names(output_names);
|
||||
replacement.get_tensor().set_name(tensor_name);
|
||||
|
||||
copy_runtime_info({replacement.get_node_shared_ptr(), output.get_node_shared_ptr()},
|
||||
replacement.get_node_shared_ptr());
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
|
||||
// Save replacement tensor name before replacement as they will be overridden by the output tensor name
|
||||
const auto tensor_name = replacement.get_tensor().get_name();
|
||||
|
||||
output.replace(replacement);
|
||||
|
||||
// Restore back original replacement tensor name
|
||||
replacement.get_tensor().set_name(tensor_name);
|
||||
|
||||
copy_runtime_info({replacement.get_node_shared_ptr(), output.get_node_shared_ptr()},
|
||||
replacement.get_node_shared_ptr());
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ov::replace_node_update_name(const std::shared_ptr<Node>& target, const std::shared_ptr<Node>& replacement) {
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "ngraph/log.hpp"
|
||||
#include "ngraph/variant.hpp"
|
||||
#include "openvino/core/node.hpp"
|
||||
#include "openvino/op/parameter.hpp"
|
||||
|
||||
namespace ov {
|
||||
Output<Node>::Output(Node* node, size_t index) : m_node(node->shared_from_this()), m_index(index) {}
|
||||
@@ -64,9 +65,24 @@ void Output<Node>::replace(const Output<Node>& replacement) {
|
||||
for (auto& input : get_target_inputs()) {
|
||||
input.replace_source_output(replacement);
|
||||
}
|
||||
replacement.get_tensor_ptr()->set_names(get_tensor_ptr()->get_names());
|
||||
replacement.get_tensor_ptr()->add_names(get_tensor_ptr()->get_names());
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
replacement.get_tensor_ptr()->set_name(get_tensor_ptr()->get_name());
|
||||
// In legacy API we rely on output port tensor name and use it as an input or output name for the model
|
||||
// Due to m_name is just a string, and we can't store multiple aliases for single output port we have to
|
||||
// handle two situations during replacement:
|
||||
// 1. When we replace consumers to Parameter output port we can't change its name, so we skip this part
|
||||
// 2. In other cases when we replace consumers to another output port we should set name. For example:
|
||||
// if we eliminate Node2 from Node1->Node2->Result we have to set Node2 output port name to Node1
|
||||
// output port name, so the output name for model won't be changed.
|
||||
// But there are some cases when output name can not be preserved, so the replacement shouldn't be used:
|
||||
// 1. Parameter->Node->Result - if we eliminate Node we will lose output name
|
||||
// 2. Node1-->Node2->Result - if we eliminate Node2 we will lose Result output name
|
||||
// `->Result
|
||||
// In both of these cases please use replace_output_update_name() method which automatically prevents the
|
||||
// replacement for cases when we can not preserve input/output names of model.
|
||||
if (!is_type<ov::op::v0::Parameter>(replacement.get_node())) {
|
||||
replacement.get_tensor_ptr()->set_name(get_tensor_ptr()->get_name());
|
||||
}
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
|
||||
@@ -117,7 +117,7 @@ TEST(replace_node, simple_node_replacement) {
|
||||
new_relu->output(0).get_tensor().set_names({"f"});
|
||||
replace_node(relu, new_relu);
|
||||
|
||||
ASSERT_EQ(new_relu->output(0).get_tensor().get_names(), std::unordered_set<std::string>({"c", "d"}));
|
||||
ASSERT_EQ(new_relu->output(0).get_tensor().get_names(), std::unordered_set<std::string>({"c", "d", "f"}));
|
||||
}
|
||||
|
||||
TEST(replace_node, node_elimination) {
|
||||
@@ -133,6 +133,52 @@ TEST(replace_node, node_elimination) {
|
||||
ASSERT_EQ(param->output(0).get_tensor().get_names(), std::unordered_set<std::string>({"a", "b"}));
|
||||
}
|
||||
|
||||
TEST(replace_node, node_elimination_1) {
|
||||
auto param = std::make_shared<op::Parameter>(element::i64, Shape{3, 64});
|
||||
auto split = std::make_shared<op::v1::Split>(param, op::Constant::create(element::i64, Shape{}, {0}), 3);
|
||||
auto relu1 = std::make_shared<op::Relu>(split->output(2));
|
||||
auto relu2 = std::make_shared<op::Relu>(relu1);
|
||||
auto result2 = std::make_shared<op::Result>(relu2);
|
||||
|
||||
// relu1 can be removed because we don't have to preserve name
|
||||
ASSERT_TRUE(replace_output_update_name(relu1->output(0), relu1->input_value(0)));
|
||||
|
||||
// relu2 can't be removed because we have to preserve name and Split has more than one output port
|
||||
ASSERT_FALSE(replace_output_update_name(relu2->output(0), relu2->input_value(0)));
|
||||
}
|
||||
|
||||
TEST(replace_node, node_elimination_2) {
|
||||
auto param = std::make_shared<op::Parameter>(element::i64, Shape{3, 64});
|
||||
auto relu1 = std::make_shared<op::Relu>(param);
|
||||
auto result1 = std::make_shared<op::Result>(relu1);
|
||||
auto relu2 = std::make_shared<op::Relu>(relu1);
|
||||
auto result2 = std::make_shared<op::Result>(relu2);
|
||||
|
||||
// relu2 can't be removed because relu1 has Result as consumer
|
||||
ASSERT_FALSE(replace_output_update_name(relu2->output(0), relu2->input_value(0)));
|
||||
}
|
||||
|
||||
TEST(replace_node, node_elimination_3) {
|
||||
auto param = std::make_shared<op::Parameter>(element::i64, Shape{3, 64});
|
||||
auto relu1 = std::make_shared<op::Relu>(param);
|
||||
auto relu2 = std::make_shared<op::Relu>(relu1);
|
||||
auto relu3 = std::make_shared<op::Relu>(relu1);
|
||||
auto result2 = std::make_shared<op::Result>(relu3);
|
||||
|
||||
// relu3 can be removed because relu1 has no Result as consumer
|
||||
ASSERT_TRUE(replace_output_update_name(relu3->output(0), relu3->input_value(0)));
|
||||
}
|
||||
|
||||
TEST(replace_node, node_elimination_4) {
|
||||
auto param = std::make_shared<op::Parameter>(element::i64, Shape{3, 64});
|
||||
auto relu1 = std::make_shared<op::Relu>(param);
|
||||
auto split = std::make_shared<op::v1::Split>(relu1, op::Constant::create(element::i64, Shape{}, {0}), 3);
|
||||
auto relu2 = std::make_shared<op::Relu>(split->output(2));
|
||||
auto result2 = std::make_shared<op::Result>(relu2);
|
||||
|
||||
ASSERT_TRUE(replace_output_update_name(split->output(2), split->input_value(0)));
|
||||
}
|
||||
|
||||
TEST(replace_node, output_replacement) {
|
||||
auto param = std::make_shared<op::Parameter>(element::i64, Shape{1, 64});
|
||||
param->output(0).get_tensor().set_names({"a", "b"});
|
||||
@@ -144,7 +190,7 @@ TEST(replace_node, output_replacement) {
|
||||
|
||||
relu->output(0).replace(new_relu->output(0));
|
||||
|
||||
ASSERT_EQ(new_relu->output(0).get_tensor().get_names(), std::unordered_set<std::string>({"c", "d"}));
|
||||
ASSERT_EQ(new_relu->output(0).get_tensor().get_names(), std::unordered_set<std::string>({"c", "d", "f"}));
|
||||
}
|
||||
|
||||
TEST(replace_node, source_replacement) {
|
||||
|
||||
@@ -1788,222 +1788,6 @@ TEST_F(NGraphReaderTests, RemoveAdd2) {
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(NGraphReaderTests, RemoveAdd3) {
|
||||
std::string model = R"V0G0N(
|
||||
<net name="Network" version="10">
|
||||
<layers>
|
||||
<layer id="0" name="data" type="Parameter" version="opset1">
|
||||
<data element_type="f32" shape="1,64,112,112"/>
|
||||
<output>
|
||||
<port id="0" precision="FP32">
|
||||
<dim>1</dim>
|
||||
<dim>64</dim>
|
||||
<dim>112</dim>
|
||||
<dim>112</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
<layer id="4" name="relu" type="ReLU" version="opset1">
|
||||
<input>
|
||||
<port id="0" precision="FP32">
|
||||
<dim>1</dim>
|
||||
<dim>64</dim>
|
||||
<dim>112</dim>
|
||||
<dim>112</dim>
|
||||
</port>
|
||||
</input>
|
||||
<output>
|
||||
<port id="1" precision="FP32">
|
||||
<dim>1</dim>
|
||||
<dim>64</dim>
|
||||
<dim>112</dim>
|
||||
<dim>112</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
<layer id="1" name="broadcast1_data" type="Const" version="opset1">
|
||||
<data element_type="f32" offset="0" shape="1,1,1" size="4"/>
|
||||
<output>
|
||||
<port id="0" precision="FP32">
|
||||
<dim>1</dim>
|
||||
<dim>1</dim>
|
||||
<dim>1</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
<layer id="2" name="add" type="Add" version="opset1">
|
||||
<input>
|
||||
<port id="0" precision="FP32">
|
||||
<dim>1</dim>
|
||||
<dim>64</dim>
|
||||
<dim>112</dim>
|
||||
<dim>112</dim>
|
||||
</port>
|
||||
<port id="1" precision="FP32">
|
||||
<dim>1</dim>
|
||||
<dim>1</dim>
|
||||
<dim>1</dim>
|
||||
</port>
|
||||
</input>
|
||||
<output>
|
||||
<port id="2" precision="FP32">
|
||||
<dim>1</dim>
|
||||
<dim>64</dim>
|
||||
<dim>112</dim>
|
||||
<dim>112</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
<layer id="6" name="broadcast2_data" type="Const" version="opset1">
|
||||
<data element_type="f32" offset="4" shape="1,1,1" size="4"/>
|
||||
<output>
|
||||
<port id="0" precision="FP32">
|
||||
<dim>1</dim>
|
||||
<dim>1</dim>
|
||||
<dim>1</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
<layer id="5" name="add2" type="Add" version="opset1">
|
||||
<input>
|
||||
<port id="0" precision="FP32">
|
||||
<dim>1</dim>
|
||||
<dim>64</dim>
|
||||
<dim>112</dim>
|
||||
<dim>112</dim>
|
||||
</port>
|
||||
<port id="1" precision="FP32">
|
||||
<dim>1</dim>
|
||||
<dim>1</dim>
|
||||
<dim>1</dim>
|
||||
</port>
|
||||
</input>
|
||||
<output>
|
||||
<port id="2" precision="FP32">
|
||||
<dim>1</dim>
|
||||
<dim>64</dim>
|
||||
<dim>112</dim>
|
||||
<dim>112</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
<layer id="3" name="output1" type="Result" version="opset1">
|
||||
<input>
|
||||
<port id="0" precision="FP32">
|
||||
<dim>1</dim>
|
||||
<dim>64</dim>
|
||||
<dim>112</dim>
|
||||
<dim>112</dim>
|
||||
</port>
|
||||
</input>
|
||||
</layer>
|
||||
<layer id="7" name="output2" type="Result" version="opset1">
|
||||
<input>
|
||||
<port id="0" precision="FP32">
|
||||
<dim>1</dim>
|
||||
<dim>64</dim>
|
||||
<dim>112</dim>
|
||||
<dim>112</dim>
|
||||
</port>
|
||||
</input>
|
||||
</layer>
|
||||
</layers>
|
||||
<edges>
|
||||
<edge from-layer="0" from-port="0" to-layer="4" to-port="0"/>
|
||||
<edge from-layer="4" from-port="1" to-layer="2" to-port="0"/>
|
||||
<edge from-layer="1" from-port="0" to-layer="2" to-port="1"/>
|
||||
<edge from-layer="4" from-port="1" to-layer="5" to-port="0"/>
|
||||
<edge from-layer="6" from-port="0" to-layer="5" to-port="1"/>
|
||||
<edge from-layer="5" from-port="2" to-layer="3" to-port="0"/>
|
||||
<edge from-layer="2" from-port="2" to-layer="7" to-port="0"/>
|
||||
</edges>
|
||||
</net>
|
||||
)V0G0N";
|
||||
std::string modelV5 = R"V0G0N(
|
||||
<net name="Network" version="5" precision="FP32" batch="1">
|
||||
<layers>
|
||||
<layer id="0" name="data" precision="FP32" type="Input">
|
||||
<output>
|
||||
<port id="0">
|
||||
<dim>1</dim>
|
||||
<dim>64</dim>
|
||||
<dim>112</dim>
|
||||
<dim>112</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
<layer id="3" name="relu" precision="FP32" type="ReLU">
|
||||
<input>
|
||||
<port id="0">
|
||||
<dim>1</dim>
|
||||
<dim>64</dim>
|
||||
<dim>112</dim>
|
||||
<dim>112</dim>
|
||||
</port>
|
||||
</input>
|
||||
<output>
|
||||
<port id="1">
|
||||
<dim>1</dim>
|
||||
<dim>64</dim>
|
||||
<dim>112</dim>
|
||||
<dim>112</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
<layer id="4" name="add" precision="FP32" type="Power">
|
||||
<data power="1.000000" scale="1.000000" shift="0.000000"/>
|
||||
<input>
|
||||
<port id="0">
|
||||
<dim>1</dim>
|
||||
<dim>64</dim>
|
||||
<dim>112</dim>
|
||||
<dim>112</dim>
|
||||
</port>
|
||||
</input>
|
||||
<output>
|
||||
<port id="3">
|
||||
<dim>1</dim>
|
||||
<dim>64</dim>
|
||||
<dim>112</dim>
|
||||
<dim>112</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
<layer id="5" name="add2" precision="FP32" type="Power">
|
||||
<data power="1.000000" scale="1.000000" shift="0.000000"/>
|
||||
<input>
|
||||
<port id="0">
|
||||
<dim>1</dim>
|
||||
<dim>64</dim>
|
||||
<dim>112</dim>
|
||||
<dim>112</dim>
|
||||
</port>
|
||||
</input>
|
||||
<output>
|
||||
<port id="3">
|
||||
<dim>1</dim>
|
||||
<dim>64</dim>
|
||||
<dim>112</dim>
|
||||
<dim>112</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
</layers>
|
||||
<edges>
|
||||
<edge from-layer="0" from-port="0" to-layer="3" to-port="0"/>
|
||||
<edge from-layer="3" from-port="1" to-layer="4" to-port="0"/>
|
||||
<edge from-layer="3" from-port="1" to-layer="5" to-port="0"/>
|
||||
</edges>
|
||||
</net>
|
||||
)V0G0N";
|
||||
compareIRs(model, modelV5, 10, [](Blob::Ptr& weights) {
|
||||
// Set scale/shift constants
|
||||
auto* scale = reinterpret_cast<float *>(weights->buffer().as<int8_t*>() + 0);
|
||||
scale[0] = 0;
|
||||
scale[1] = 0;
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(NGraphReaderTests, ConvertAddToEltwise2) {
|
||||
std::string model = R"V0G0N(
|
||||
<net name="Network" version="10">
|
||||
|
||||
Reference in New Issue
Block a user