Fixed collisions in friendly names for internal operations (#9965)
* Fixed collisions in friendly names for internal operations * Fixed renaming * Added comments * Renamed test file * Fix behavior for outputs * Fixed logic * Fixed comments
This commit is contained in:
@@ -12,16 +12,16 @@ namespace pass {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* @ingroup ie_transformation_common_api
|
* @ingroup ie_transformation_common_api
|
||||||
* @brief ResolveGeneratedNameCollisions transformation helps to fix names collisions
|
* @brief ResolveNameCollisions transformation helps to fix names collisions
|
||||||
* if some autogenerated name has a conflict with the name from the original graph
|
* if some internal nodes or nodes with autogenerated names have conflicts with other nodes from the original graph
|
||||||
*
|
*
|
||||||
* Every transformation call can change the graph structure and create some additional operations,
|
* Every transformation call can change the graph structure and create some additional operations,
|
||||||
* autogenerated name is used if new operation doesn't have friendly name.
|
* autogenerated name is used if new operation doesn't have friendly name.
|
||||||
* This transformations should be called after the transformation pipeline in order to fix names collisions.
|
* This transformations should be called after the transformation pipeline in order to fix names collisions.
|
||||||
*/
|
*/
|
||||||
class TRANSFORMATIONS_API ResolveGeneratedNameCollisions : public ModelPass {
|
class TRANSFORMATIONS_API ResolveNameCollisions : public ModelPass {
|
||||||
public:
|
public:
|
||||||
OPENVINO_RTTI("ResolveGeneratedNameCollisions", "0");
|
OPENVINO_RTTI("ResolveNameCollisions", "0");
|
||||||
bool run_on_model(const std::shared_ptr<ov::Model>& model) override;
|
bool run_on_model(const std::shared_ptr<ov::Model>& model) override;
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -1,48 +0,0 @@
|
|||||||
// Copyright (C) 2018-2022 Intel Corporation
|
|
||||||
// SPDX-License-Identifier: Apache-2.0
|
|
||||||
//
|
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
#include <memory>
|
|
||||||
#include <numeric>
|
|
||||||
|
|
||||||
#include "transformations/resolve_gen_names_collisions.hpp"
|
|
||||||
|
|
||||||
bool ov::pass::ResolveGeneratedNameCollisions::run_on_model(const std::shared_ptr<ov::Model>& model) {
|
|
||||||
// Next containers are used to fix collisions in autogenerated names
|
|
||||||
// Collect all unique friendly names
|
|
||||||
std::unordered_set<std::string> unique_friendly_names;
|
|
||||||
// Save nodes with autogenerated names but without conflicts in the candidate list
|
|
||||||
std::unordered_map<std::string, Node*> nodes_with_possible_conflicts;
|
|
||||||
// The final list of nodes with collisions
|
|
||||||
std::vector<Node*> nodes_with_conflicts;
|
|
||||||
|
|
||||||
for (auto& node : model->get_ordered_ops()) {
|
|
||||||
// Detect names collisions only for nodes with autogenerated names
|
|
||||||
const auto friendly_name = node->get_friendly_name();
|
|
||||||
if (unique_friendly_names.find(friendly_name) == unique_friendly_names.end()) {
|
|
||||||
unique_friendly_names.insert(friendly_name);
|
|
||||||
if (node->m_friendly_name.empty())
|
|
||||||
nodes_with_possible_conflicts[friendly_name] = node.get();
|
|
||||||
} else if (node->m_friendly_name.empty()) {
|
|
||||||
// We have a conflict with autogenerated name
|
|
||||||
nodes_with_conflicts.emplace_back(node.get());
|
|
||||||
} else if (nodes_with_possible_conflicts.find(friendly_name) != nodes_with_possible_conflicts.end()) {
|
|
||||||
// We have a conflict with autogenerated name
|
|
||||||
nodes_with_conflicts.emplace_back(nodes_with_possible_conflicts[friendly_name]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Resolve names collisions
|
|
||||||
for (const auto& node : nodes_with_conflicts) {
|
|
||||||
size_t idx = 2;
|
|
||||||
const auto friendly_name = node->get_friendly_name();
|
|
||||||
while (unique_friendly_names.find(friendly_name + "_" + std::to_string(idx)) != unique_friendly_names.end())
|
|
||||||
idx++;
|
|
||||||
const auto new_friendly_name = friendly_name + "_" + std::to_string(idx);
|
|
||||||
node->set_friendly_name(new_friendly_name);
|
|
||||||
unique_friendly_names.insert(new_friendly_name);
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
@@ -0,0 +1,81 @@
|
|||||||
|
// Copyright (C) 2018-2022 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <memory>
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
|
#include "transformations/resolve_names_collisions.hpp"
|
||||||
|
#include "openvino/op/parameter.hpp"
|
||||||
|
#include "openvino/op/result.hpp"
|
||||||
|
#include "openvino/op/sink.hpp"
|
||||||
|
|
||||||
|
bool ov::pass::ResolveNameCollisions::run_on_model(const std::shared_ptr<ov::Model>& model) {
|
||||||
|
// Next containers are used to fix collisions in autogenerated names
|
||||||
|
// The final list of nodes with collisions
|
||||||
|
std::vector<Node*> nodes_with_conflicts;
|
||||||
|
std::unordered_map<std::string, std::list<Node*>> visited_nodes;
|
||||||
|
|
||||||
|
for (const auto& node : model->get_ordered_ops()) {
|
||||||
|
// Detect names collisions only for nodes with autogenerated names
|
||||||
|
const auto& friendly_name = node->get_friendly_name();
|
||||||
|
visited_nodes[friendly_name].emplace_back(node.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const auto& l_nodes : visited_nodes) {
|
||||||
|
if (l_nodes.second.size() == 1)
|
||||||
|
continue;
|
||||||
|
const size_t nodes_size = l_nodes.second.size();
|
||||||
|
bool has_public_node = false; // Parameter, Result ans Sinks
|
||||||
|
size_t i(0);
|
||||||
|
for (auto* node : l_nodes.second) {
|
||||||
|
i++;
|
||||||
|
// Skip the last node if we don't have public nodes with collisions
|
||||||
|
if (i == nodes_size && !has_public_node)
|
||||||
|
break;
|
||||||
|
if (dynamic_cast<const ov::op::v0::Result*>(node)) {
|
||||||
|
// Result is a service node
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (dynamic_cast<const ov::op::Sink*>(node) ||
|
||||||
|
dynamic_cast<const ov::op::v0::Parameter*>(node)) {
|
||||||
|
// Resolve names for public ops with autogenerated name
|
||||||
|
if (node->m_friendly_name.empty())
|
||||||
|
nodes_with_conflicts.emplace_back(node);
|
||||||
|
has_public_node = true;
|
||||||
|
continue;
|
||||||
|
} else {
|
||||||
|
// For result we need to avoid changes in previous operation
|
||||||
|
bool is_public = false;
|
||||||
|
for (const auto& output : node->outputs()) {
|
||||||
|
for (const auto input : output.get_target_inputs()) {
|
||||||
|
if (dynamic_cast<const ov::op::v0::Result*>(input.get_node())) {
|
||||||
|
has_public_node = true;
|
||||||
|
is_public = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (is_public)
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (is_public)
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
nodes_with_conflicts.emplace_back(node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resolve names collisions
|
||||||
|
for (auto* node : nodes_with_conflicts) {
|
||||||
|
size_t idx = 2;
|
||||||
|
const auto friendly_name = node->get_friendly_name();
|
||||||
|
while (visited_nodes.find(friendly_name + "_" + std::to_string(idx)) != visited_nodes.end())
|
||||||
|
idx++;
|
||||||
|
const auto new_friendly_name = friendly_name + "_" + std::to_string(idx);
|
||||||
|
node->set_friendly_name(new_friendly_name);
|
||||||
|
visited_nodes[new_friendly_name].emplace_back(node);
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
@@ -56,7 +56,7 @@ struct AutoBroadcastSpec;
|
|||||||
} // namespace op
|
} // namespace op
|
||||||
namespace pass {
|
namespace pass {
|
||||||
|
|
||||||
class ResolveGeneratedNameCollisions;
|
class ResolveNameCollisions;
|
||||||
|
|
||||||
namespace pattern {
|
namespace pattern {
|
||||||
class Matcher;
|
class Matcher;
|
||||||
@@ -127,7 +127,7 @@ class OPENVINO_API Node : public std::enable_shared_from_this<Node> {
|
|||||||
|
|
||||||
friend class Model;
|
friend class Model;
|
||||||
// To fix collisions in generated friendly name
|
// To fix collisions in generated friendly name
|
||||||
friend class pass::ResolveGeneratedNameCollisions;
|
friend class pass::ResolveNameCollisions;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
descriptor::Input& get_input_descriptor(size_t position);
|
descriptor::Input& get_input_descriptor(size_t position);
|
||||||
|
|||||||
@@ -29,7 +29,7 @@
|
|||||||
#include <transformations/opset_conversions/convert_opset2_to_opset1.hpp>
|
#include <transformations/opset_conversions/convert_opset2_to_opset1.hpp>
|
||||||
|
|
||||||
#include <transformations/control_flow/unroll_tensor_iterator.hpp>
|
#include <transformations/control_flow/unroll_tensor_iterator.hpp>
|
||||||
#include "transformations/resolve_gen_names_collisions.hpp"
|
#include "transformations/resolve_names_collisions.hpp"
|
||||||
|
|
||||||
#include <transformations/common_optimizations/common_optimizations.hpp>
|
#include <transformations/common_optimizations/common_optimizations.hpp>
|
||||||
#include <transformations/common_optimizations/lin_op_sequence_fusion.hpp>
|
#include <transformations/common_optimizations/lin_op_sequence_fusion.hpp>
|
||||||
@@ -437,7 +437,7 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
|
|||||||
}
|
}
|
||||||
return !config.enable_loop_unrolling;
|
return !config.enable_loop_unrolling;
|
||||||
});
|
});
|
||||||
manager.register_pass<ov::pass::ResolveGeneratedNameCollisions>();
|
manager.register_pass<ov::pass::ResolveNameCollisions>();
|
||||||
|
|
||||||
manager.run_passes(func);
|
manager.run_passes(func);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,13 +2,13 @@
|
|||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
//
|
//
|
||||||
|
|
||||||
#include "transformations/resolve_gen_names_collisions.hpp"
|
#include "transformations/resolve_names_collisions.hpp"
|
||||||
|
|
||||||
#include "gtest/gtest.h"
|
#include "gtest/gtest.h"
|
||||||
#include "openvino/opsets/opset8.hpp"
|
#include "openvino/opsets/opset8.hpp"
|
||||||
#include "openvino/pass/manager.hpp"
|
#include "openvino/pass/manager.hpp"
|
||||||
|
|
||||||
TEST(ResolveGeneratedNameCollisionsTest, FixGeneratedNames) {
|
TEST(ResolveNameCollisionsTest, FixGeneratedNames) {
|
||||||
auto arg0 = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::PartialShape{1, 3, 3, 3});
|
auto arg0 = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::PartialShape{1, 3, 3, 3});
|
||||||
const auto gen_friendly_name = arg0->get_friendly_name();
|
const auto gen_friendly_name = arg0->get_friendly_name();
|
||||||
|
|
||||||
@@ -31,14 +31,14 @@ TEST(ResolveGeneratedNameCollisionsTest, FixGeneratedNames) {
|
|||||||
EXPECT_NE(arg1->get_friendly_name(), arg0->get_friendly_name() + "_2");
|
EXPECT_NE(arg1->get_friendly_name(), arg0->get_friendly_name() + "_2");
|
||||||
|
|
||||||
ov::pass::Manager pass_manager;
|
ov::pass::Manager pass_manager;
|
||||||
pass_manager.register_pass<ov::pass::ResolveGeneratedNameCollisions>();
|
pass_manager.register_pass<ov::pass::ResolveNameCollisions>();
|
||||||
pass_manager.run_passes(model);
|
pass_manager.run_passes(model);
|
||||||
EXPECT_EQ(name, arg0->get_friendly_name());
|
EXPECT_EQ(name, arg0->get_friendly_name());
|
||||||
EXPECT_NE(arg1->get_friendly_name(), arg0->get_friendly_name());
|
EXPECT_NE(arg1->get_friendly_name(), arg0->get_friendly_name());
|
||||||
EXPECT_EQ(arg1->get_friendly_name(), arg0->get_friendly_name() + "_2");
|
EXPECT_EQ(arg1->get_friendly_name(), arg0->get_friendly_name() + "_2");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ResolveGeneratedNameCollisionsTest, DoNotFixFriendlyNames) {
|
TEST(ResolveNameCollisionsTest, DoNotFixFriendlyNamesForParameters) {
|
||||||
auto arg0 = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::PartialShape{1, 3, 3, 3});
|
auto arg0 = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::PartialShape{1, 3, 3, 3});
|
||||||
const auto gen_friendly_name = arg0->get_friendly_name();
|
const auto gen_friendly_name = arg0->get_friendly_name();
|
||||||
|
|
||||||
@@ -57,9 +57,32 @@ TEST(ResolveGeneratedNameCollisionsTest, DoNotFixFriendlyNames) {
|
|||||||
EXPECT_NE(arg1->get_friendly_name(), arg0->get_friendly_name() + "_2");
|
EXPECT_NE(arg1->get_friendly_name(), arg0->get_friendly_name() + "_2");
|
||||||
|
|
||||||
ov::pass::Manager pass_manager;
|
ov::pass::Manager pass_manager;
|
||||||
pass_manager.register_pass<ov::pass::ResolveGeneratedNameCollisions>();
|
pass_manager.register_pass<ov::pass::ResolveNameCollisions>();
|
||||||
pass_manager.run_passes(model);
|
pass_manager.run_passes(model);
|
||||||
EXPECT_EQ(gen_friendly_name, arg0->get_friendly_name());
|
EXPECT_EQ(gen_friendly_name, arg0->get_friendly_name());
|
||||||
EXPECT_EQ(arg1->get_friendly_name(), arg0->get_friendly_name());
|
EXPECT_EQ(arg1->get_friendly_name(), arg0->get_friendly_name());
|
||||||
EXPECT_NE(arg1->get_friendly_name(), arg0->get_friendly_name() + "_2");
|
EXPECT_NE(arg1->get_friendly_name(), arg0->get_friendly_name() + "_2");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(ResolveNameCollisionsTest, FixFriendlyNamesForInternalOperations) {
|
||||||
|
auto arg0 = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::PartialShape{1, 3, 3, 3});
|
||||||
|
const auto gen_friendly_name = arg0->get_friendly_name();
|
||||||
|
|
||||||
|
|
||||||
|
auto arg1 = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::PartialShape{1, 2, 3, 3});
|
||||||
|
|
||||||
|
auto concat1 = std::make_shared<ov::opset8::Concat>(ov::NodeVector{arg0, arg1}, 1);
|
||||||
|
concat1->set_friendly_name("concat");
|
||||||
|
auto concat = std::make_shared<ov::opset8::Concat>(ov::NodeVector{concat1, arg1}, 1);
|
||||||
|
concat->set_friendly_name("concat");
|
||||||
|
auto result1 = std::make_shared<ov::opset8::Result>(concat);
|
||||||
|
|
||||||
|
auto model = std::make_shared<ov::Model>(ov::ResultVector{result1}, ov::ParameterVector{arg0, arg1});
|
||||||
|
|
||||||
|
EXPECT_EQ(concat->get_friendly_name(), concat1->get_friendly_name());
|
||||||
|
|
||||||
|
ov::pass::Manager pass_manager;
|
||||||
|
pass_manager.register_pass<ov::pass::ResolveNameCollisions>();
|
||||||
|
pass_manager.run_passes(model);
|
||||||
|
EXPECT_NE(concat->get_friendly_name(), concat1->get_friendly_name());
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user