Fixed add_output for new subgraph (#16726)

This commit is contained in:
Ilya Churaev 2023-04-05 11:39:47 +04:00 committed by GitHub
parent c474f564a9
commit f2d4c96032
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 39 additions and 3 deletions

View File

@ -10,6 +10,7 @@
#include <memory> #include <memory>
#include <mutex> #include <mutex>
#include <string> #include <string>
#include <unordered_set>
#include <vector> #include <vector>
#include "openvino/core/any.hpp" #include "openvino/core/any.hpp"
@ -554,6 +555,7 @@ private:
// of weak_ptr not to increase node ref counter to prevent the situation when // of weak_ptr not to increase node ref counter to prevent the situation when
// node has no consumers but still exists in a graph. // node has no consumers but still exists in a graph.
mutable std::vector<std::weak_ptr<Node>> m_cached_ordered_ops; mutable std::vector<std::weak_ptr<Node>> m_cached_ordered_ops;
mutable std::unordered_set<Node*> m_cached_ops;
mutable std::unordered_map<std::string, Output<Node>> m_cached_output_names; mutable std::unordered_map<std::string, Output<Node>> m_cached_output_names;
mutable std::unordered_map<std::string, std::weak_ptr<Node>> m_cached_op_names; mutable std::unordered_map<std::string, std::weak_ptr<Node>> m_cached_op_names;

View File

@ -318,6 +318,7 @@ std::vector<shared_ptr<ov::Node>> ov::Model::get_ordered_ops() const {
m_cached_ordered_ops.clear(); m_cached_ordered_ops.clear();
for_each(order.cbegin(), order.cend(), [this](const shared_ptr<Node>& node) { for_each(order.cbegin(), order.cend(), [this](const shared_ptr<Node>& node) {
m_cached_ordered_ops.push_back(node); m_cached_ordered_ops.push_back(node);
m_cached_ops.insert(node.get());
node->insert_info(m_shared_rt_info); node->insert_info(m_shared_rt_info);
}); });
m_cached_output_names.clear(); m_cached_output_names.clear();
@ -923,6 +924,9 @@ ov::Output<ov::Node> ov::Model::add_output(const std::string& op_name, size_t ou
} }
ov::Output<ov::Node> ov::Model::add_output(const ov::Output<ov::Node>& port) { ov::Output<ov::Node> ov::Model::add_output(const ov::Output<ov::Node>& port) {
auto cache_valid = [&]() {
return m_cached_ops.count(port.get_node());
};
if (ov::op::util::is_output(port.get_node())) if (ov::op::util::is_output(port.get_node()))
return port; return port;
for (const auto& input : port.get_target_inputs()) { for (const auto& input : port.get_target_inputs()) {
@ -934,9 +938,14 @@ ov::Output<ov::Node> ov::Model::add_output(const ov::Output<ov::Node>& port) {
auto result = std::make_shared<ov::op::v0::Result>(port); auto result = std::make_shared<ov::op::v0::Result>(port);
m_results.push_back(result); m_results.push_back(result);
if (m_shared_rt_info->get_use_topological_cache()) { if (m_shared_rt_info->get_use_topological_cache()) {
if (cache_valid()) {
// Full update of topological cache is not needed, 'result' can be just inserted to the end // Full update of topological cache is not needed, 'result' can be just inserted to the end
m_cached_ordered_ops.push_back(result); m_cached_ordered_ops.push_back(result);
m_cached_ops.insert(result.get());
result->insert_info(m_shared_rt_info); // Just for consistency, not required for Result nodes result->insert_info(m_shared_rt_info); // Just for consistency, not required for Result nodes
} else {
m_shared_rt_info->set_use_topological_cache(false);
}
} }
return result->output(0); return result->output(0);
} }

View File

@ -1034,6 +1034,31 @@ TEST(model, add_output_port) {
EXPECT_EQ(f->get_results()[1]->input_value(0).get_node(), relu1.get()); EXPECT_EQ(f->get_results()[1]->input_value(0).get_node(), relu1.get());
} }
TEST(model, add_output_to_new_subgraph) {
auto arg0 = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::PartialShape{1});
arg0->set_friendly_name("data");
arg0->get_output_tensor(0).set_names({"input"});
auto relu1 = std::make_shared<ov::opset8::Relu>(arg0);
relu1->set_friendly_name("relu1");
relu1->get_output_tensor(0).set_names({"relu_t1"});
auto relu2 = std::make_shared<ov::opset8::Relu>(relu1);
relu2->set_friendly_name("relu2");
relu2->get_output_tensor(0).set_names({"relu_t2"});
auto f = std::make_shared<ov::Model>(relu2, ov::ParameterVector{arg0});
f->validate_nodes_and_infer_types();
EXPECT_EQ(f->get_results().size(), 1);
ov::Output<ov::Node> out;
EXPECT_NO_THROW(
out = f->add_output(ov::opset8::Constant::create(ov::element::i32, {1}, std::vector<int32_t>{1})->output(0)));
EXPECT_NO_THROW(f->get_ordered_ops());
EXPECT_EQ(out.get_node(), f->get_results()[1].get());
EXPECT_EQ(f->get_results().size(), 2);
}
TEST(model, add_output_incorrect_tensor_name) { TEST(model, add_output_incorrect_tensor_name) {
auto arg0 = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::PartialShape{1}); auto arg0 = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::PartialShape{1});
arg0->set_friendly_name("data"); arg0->set_friendly_name("data");