Friendly names fix for ONNX models (#12412)

This commit is contained in:
Tomasz Dołbniak
2022-08-08 21:56:03 +02:00
committed by GitHub
parent 3068b3823c
commit a1bd02e633
3 changed files with 10 additions and 6 deletions

View File

@@ -312,7 +312,11 @@ std::shared_ptr<Function> Graph::create_function() {
auto function = std::make_shared<Function>(get_ng_outputs(), m_parameters, get_name());
const auto& onnx_outputs = m_model->get_graph().output();
for (std::size_t i{0}; i < function->get_output_size(); ++i) {
function->get_output_op(i)->set_friendly_name(onnx_outputs.Get(i).name() + "/sink_port_0");
const auto& result_node = function->get_output_op(i);
const std::string onnx_output_name = onnx_outputs.Get(i).name();
result_node->set_friendly_name(onnx_output_name + "/sink_port_0");
const auto& previous_operation = result_node->get_input_node_shared_ptr(0);
previous_operation->set_friendly_name(onnx_output_name);
}
return function;
}

View File

@@ -96,7 +96,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_node_names_check) {
EXPECT_EQ(additions.size(), 2);
EXPECT_EQ(additions.at(0)->get_friendly_name(), "add_node1");
EXPECT_EQ(additions.at(0)->get_output_tensor(0).get_names(), std::unordered_set<std::string>{"X"});
EXPECT_EQ(additions.at(1)->get_friendly_name(), "add_node2");
EXPECT_EQ(additions.at(1)->get_friendly_name(), "Y");
EXPECT_EQ(additions.at(1)->get_output_tensor(0).get_names(), std::unordered_set<std::string>{"Y"});
}

View File

@@ -117,11 +117,11 @@ NGRAPH_TEST(onnx_tensor_names, simple_multiout_named_operator) {
// in this case both Results are connected directly to the MaxPool node
const auto result1 = find_by_friendly_name<op::Result>(ops, "y/sink_port_0");
EXPECT_NE(result1, nullptr);
EXPECT_EQ(result1->input(0).get_source_output().get_node_shared_ptr()->get_friendly_name(), "max_pool_node");
EXPECT_EQ(result1->input(0).get_source_output().get_node_shared_ptr()->get_friendly_name(), "z");
const auto result2 = find_by_friendly_name<op::Result>(ops, "z/sink_port_0");
EXPECT_NE(result2, nullptr);
EXPECT_EQ(result2->input(0).get_source_output().get_node_shared_ptr()->get_friendly_name(), "max_pool_node");
EXPECT_EQ(result2->input(0).get_source_output().get_node_shared_ptr()->get_friendly_name(), "z");
}
NGRAPH_TEST(onnx_tensor_names, subgraph_with_multiple_nodes_named) {
@@ -132,11 +132,11 @@ NGRAPH_TEST(onnx_tensor_names, subgraph_with_multiple_nodes_named) {
const auto result1 = find_by_friendly_name<op::Result>(ops, "y/sink_port_0");
EXPECT_NE(result1, nullptr);
EXPECT_EQ(result1->input(0).get_source_output().get_node_shared_ptr()->get_friendly_name(), "max_pool_node_y");
EXPECT_EQ(result1->input(0).get_source_output().get_node_shared_ptr()->get_friendly_name(), "y");
const auto result2 = find_by_friendly_name<op::Result>(ops, "z/sink_port_0");
EXPECT_NE(result2, nullptr);
EXPECT_EQ(result2->input(0).get_source_output().get_node_shared_ptr()->get_friendly_name(), "max_pool_node_z");
EXPECT_EQ(result2->input(0).get_source_output().get_node_shared_ptr()->get_friendly_name(), "z");
}
NGRAPH_TEST(onnx_tensor_names, subgraph_conv_with_bias) {