Fixed collisions in friendly names for internal operations (#9965)
* Fixed collisions in friendly names for internal operations * Fixed renaming * Added comments * Renamed test file * Fix behavior for outputs * Fixed logic * Fixed comments
This commit is contained in:
@@ -12,16 +12,16 @@ namespace pass {
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief ResolveGeneratedNameCollisions transformation helps to fix names collisions
|
||||
* if some autogenerated name has a conflict with the name from the original graph
|
||||
* @brief ResolveNameCollisions transformation helps to fix names collisions
|
||||
* if some internal nodes or nodes with autogenerated names have conflicts with other nodes from the original graph
|
||||
*
|
||||
* Every transformation call can change the graph structure and create some additional operations,
|
||||
* autogenerated name is used if new operation doesn't have friendly name.
|
||||
* This transformations should be called after the transformation pipeline in order to fix names collisions.
|
||||
*/
|
||||
class TRANSFORMATIONS_API ResolveGeneratedNameCollisions : public ModelPass {
|
||||
class TRANSFORMATIONS_API ResolveNameCollisions : public ModelPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ResolveGeneratedNameCollisions", "0");
|
||||
OPENVINO_RTTI("ResolveNameCollisions", "0");
|
||||
bool run_on_model(const std::shared_ptr<ov::Model>& model) override;
|
||||
};
|
||||
|
||||
@@ -1,48 +0,0 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
|
||||
#include "transformations/resolve_gen_names_collisions.hpp"
|
||||
|
||||
bool ov::pass::ResolveGeneratedNameCollisions::run_on_model(const std::shared_ptr<ov::Model>& model) {
|
||||
// Next containers are used to fix collisions in autogenerated names
|
||||
// Collect all unique friendly names
|
||||
std::unordered_set<std::string> unique_friendly_names;
|
||||
// Save nodes with autogenerated names but without conflicts in the candidate list
|
||||
std::unordered_map<std::string, Node*> nodes_with_possible_conflicts;
|
||||
// The final list of nodes with collisions
|
||||
std::vector<Node*> nodes_with_conflicts;
|
||||
|
||||
for (auto& node : model->get_ordered_ops()) {
|
||||
// Detect names collisions only for nodes with autogenerated names
|
||||
const auto friendly_name = node->get_friendly_name();
|
||||
if (unique_friendly_names.find(friendly_name) == unique_friendly_names.end()) {
|
||||
unique_friendly_names.insert(friendly_name);
|
||||
if (node->m_friendly_name.empty())
|
||||
nodes_with_possible_conflicts[friendly_name] = node.get();
|
||||
} else if (node->m_friendly_name.empty()) {
|
||||
// We have a conflict with autogenerated name
|
||||
nodes_with_conflicts.emplace_back(node.get());
|
||||
} else if (nodes_with_possible_conflicts.find(friendly_name) != nodes_with_possible_conflicts.end()) {
|
||||
// We have a conflict with autogenerated name
|
||||
nodes_with_conflicts.emplace_back(nodes_with_possible_conflicts[friendly_name]);
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve names collisions
|
||||
for (const auto& node : nodes_with_conflicts) {
|
||||
size_t idx = 2;
|
||||
const auto friendly_name = node->get_friendly_name();
|
||||
while (unique_friendly_names.find(friendly_name + "_" + std::to_string(idx)) != unique_friendly_names.end())
|
||||
idx++;
|
||||
const auto new_friendly_name = friendly_name + "_" + std::to_string(idx);
|
||||
node->set_friendly_name(new_friendly_name);
|
||||
unique_friendly_names.insert(new_friendly_name);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
|
||||
#include "transformations/resolve_names_collisions.hpp"
|
||||
#include "openvino/op/parameter.hpp"
|
||||
#include "openvino/op/result.hpp"
|
||||
#include "openvino/op/sink.hpp"
|
||||
|
||||
bool ov::pass::ResolveNameCollisions::run_on_model(const std::shared_ptr<ov::Model>& model) {
|
||||
// Next containers are used to fix collisions in autogenerated names
|
||||
// The final list of nodes with collisions
|
||||
std::vector<Node*> nodes_with_conflicts;
|
||||
std::unordered_map<std::string, std::list<Node*>> visited_nodes;
|
||||
|
||||
for (const auto& node : model->get_ordered_ops()) {
|
||||
// Detect names collisions only for nodes with autogenerated names
|
||||
const auto& friendly_name = node->get_friendly_name();
|
||||
visited_nodes[friendly_name].emplace_back(node.get());
|
||||
}
|
||||
|
||||
for (const auto& l_nodes : visited_nodes) {
|
||||
if (l_nodes.second.size() == 1)
|
||||
continue;
|
||||
const size_t nodes_size = l_nodes.second.size();
|
||||
bool has_public_node = false; // Parameter, Result ans Sinks
|
||||
size_t i(0);
|
||||
for (auto* node : l_nodes.second) {
|
||||
i++;
|
||||
// Skip the last node if we don't have public nodes with collisions
|
||||
if (i == nodes_size && !has_public_node)
|
||||
break;
|
||||
if (dynamic_cast<const ov::op::v0::Result*>(node)) {
|
||||
// Result is a service node
|
||||
continue;
|
||||
}
|
||||
if (dynamic_cast<const ov::op::Sink*>(node) ||
|
||||
dynamic_cast<const ov::op::v0::Parameter*>(node)) {
|
||||
// Resolve names for public ops with autogenerated name
|
||||
if (node->m_friendly_name.empty())
|
||||
nodes_with_conflicts.emplace_back(node);
|
||||
has_public_node = true;
|
||||
continue;
|
||||
} else {
|
||||
// For result we need to avoid changes in previous operation
|
||||
bool is_public = false;
|
||||
for (const auto& output : node->outputs()) {
|
||||
for (const auto input : output.get_target_inputs()) {
|
||||
if (dynamic_cast<const ov::op::v0::Result*>(input.get_node())) {
|
||||
has_public_node = true;
|
||||
is_public = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (is_public)
|
||||
break;
|
||||
}
|
||||
if (is_public)
|
||||
continue;
|
||||
}
|
||||
nodes_with_conflicts.emplace_back(node);
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve names collisions
|
||||
for (auto* node : nodes_with_conflicts) {
|
||||
size_t idx = 2;
|
||||
const auto friendly_name = node->get_friendly_name();
|
||||
while (visited_nodes.find(friendly_name + "_" + std::to_string(idx)) != visited_nodes.end())
|
||||
idx++;
|
||||
const auto new_friendly_name = friendly_name + "_" + std::to_string(idx);
|
||||
node->set_friendly_name(new_friendly_name);
|
||||
visited_nodes[new_friendly_name].emplace_back(node);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -56,7 +56,7 @@ struct AutoBroadcastSpec;
|
||||
} // namespace op
|
||||
namespace pass {
|
||||
|
||||
class ResolveGeneratedNameCollisions;
|
||||
class ResolveNameCollisions;
|
||||
|
||||
namespace pattern {
|
||||
class Matcher;
|
||||
@@ -127,7 +127,7 @@ class OPENVINO_API Node : public std::enable_shared_from_this<Node> {
|
||||
|
||||
friend class Model;
|
||||
// To fix collisions in generated friendly name
|
||||
friend class pass::ResolveGeneratedNameCollisions;
|
||||
friend class pass::ResolveNameCollisions;
|
||||
|
||||
protected:
|
||||
descriptor::Input& get_input_descriptor(size_t position);
|
||||
|
||||
@@ -29,7 +29,7 @@
|
||||
#include <transformations/opset_conversions/convert_opset2_to_opset1.hpp>
|
||||
|
||||
#include <transformations/control_flow/unroll_tensor_iterator.hpp>
|
||||
#include "transformations/resolve_gen_names_collisions.hpp"
|
||||
#include "transformations/resolve_names_collisions.hpp"
|
||||
|
||||
#include <transformations/common_optimizations/common_optimizations.hpp>
|
||||
#include <transformations/common_optimizations/lin_op_sequence_fusion.hpp>
|
||||
@@ -437,7 +437,7 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
|
||||
}
|
||||
return !config.enable_loop_unrolling;
|
||||
});
|
||||
manager.register_pass<ov::pass::ResolveGeneratedNameCollisions>();
|
||||
manager.register_pass<ov::pass::ResolveNameCollisions>();
|
||||
|
||||
manager.run_passes(func);
|
||||
}
|
||||
|
||||
@@ -2,13 +2,13 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/resolve_gen_names_collisions.hpp"
|
||||
#include "transformations/resolve_names_collisions.hpp"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "openvino/opsets/opset8.hpp"
|
||||
#include "openvino/pass/manager.hpp"
|
||||
|
||||
TEST(ResolveGeneratedNameCollisionsTest, FixGeneratedNames) {
|
||||
TEST(ResolveNameCollisionsTest, FixGeneratedNames) {
|
||||
auto arg0 = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::PartialShape{1, 3, 3, 3});
|
||||
const auto gen_friendly_name = arg0->get_friendly_name();
|
||||
|
||||
@@ -31,14 +31,14 @@ TEST(ResolveGeneratedNameCollisionsTest, FixGeneratedNames) {
|
||||
EXPECT_NE(arg1->get_friendly_name(), arg0->get_friendly_name() + "_2");
|
||||
|
||||
ov::pass::Manager pass_manager;
|
||||
pass_manager.register_pass<ov::pass::ResolveGeneratedNameCollisions>();
|
||||
pass_manager.register_pass<ov::pass::ResolveNameCollisions>();
|
||||
pass_manager.run_passes(model);
|
||||
EXPECT_EQ(name, arg0->get_friendly_name());
|
||||
EXPECT_NE(arg1->get_friendly_name(), arg0->get_friendly_name());
|
||||
EXPECT_EQ(arg1->get_friendly_name(), arg0->get_friendly_name() + "_2");
|
||||
}
|
||||
|
||||
TEST(ResolveGeneratedNameCollisionsTest, DoNotFixFriendlyNames) {
|
||||
TEST(ResolveNameCollisionsTest, DoNotFixFriendlyNamesForParameters) {
|
||||
auto arg0 = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::PartialShape{1, 3, 3, 3});
|
||||
const auto gen_friendly_name = arg0->get_friendly_name();
|
||||
|
||||
@@ -57,9 +57,32 @@ TEST(ResolveGeneratedNameCollisionsTest, DoNotFixFriendlyNames) {
|
||||
EXPECT_NE(arg1->get_friendly_name(), arg0->get_friendly_name() + "_2");
|
||||
|
||||
ov::pass::Manager pass_manager;
|
||||
pass_manager.register_pass<ov::pass::ResolveGeneratedNameCollisions>();
|
||||
pass_manager.register_pass<ov::pass::ResolveNameCollisions>();
|
||||
pass_manager.run_passes(model);
|
||||
EXPECT_EQ(gen_friendly_name, arg0->get_friendly_name());
|
||||
EXPECT_EQ(arg1->get_friendly_name(), arg0->get_friendly_name());
|
||||
EXPECT_NE(arg1->get_friendly_name(), arg0->get_friendly_name() + "_2");
|
||||
}
|
||||
|
||||
TEST(ResolveNameCollisionsTest, FixFriendlyNamesForInternalOperations) {
|
||||
auto arg0 = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::PartialShape{1, 3, 3, 3});
|
||||
const auto gen_friendly_name = arg0->get_friendly_name();
|
||||
|
||||
|
||||
auto arg1 = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::PartialShape{1, 2, 3, 3});
|
||||
|
||||
auto concat1 = std::make_shared<ov::opset8::Concat>(ov::NodeVector{arg0, arg1}, 1);
|
||||
concat1->set_friendly_name("concat");
|
||||
auto concat = std::make_shared<ov::opset8::Concat>(ov::NodeVector{concat1, arg1}, 1);
|
||||
concat->set_friendly_name("concat");
|
||||
auto result1 = std::make_shared<ov::opset8::Result>(concat);
|
||||
|
||||
auto model = std::make_shared<ov::Model>(ov::ResultVector{result1}, ov::ParameterVector{arg0, arg1});
|
||||
|
||||
EXPECT_EQ(concat->get_friendly_name(), concat1->get_friendly_name());
|
||||
|
||||
ov::pass::Manager pass_manager;
|
||||
pass_manager.register_pass<ov::pass::ResolveNameCollisions>();
|
||||
pass_manager.run_passes(model);
|
||||
EXPECT_NE(concat->get_friendly_name(), concat1->get_friendly_name());
|
||||
}
|
||||
Reference in New Issue
Block a user