Fixed add_output for new subgraph (#16726)
This commit is contained in:
parent
c474f564a9
commit
f2d4c96032
@ -10,6 +10,7 @@
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "openvino/core/any.hpp"
|
||||
@ -554,6 +555,7 @@ private:
|
||||
// of weak_ptr not to increase node ref counter to prevent the situation when
|
||||
// node has no consumers but still exists in a graph.
|
||||
mutable std::vector<std::weak_ptr<Node>> m_cached_ordered_ops;
|
||||
mutable std::unordered_set<Node*> m_cached_ops;
|
||||
|
||||
mutable std::unordered_map<std::string, Output<Node>> m_cached_output_names;
|
||||
mutable std::unordered_map<std::string, std::weak_ptr<Node>> m_cached_op_names;
|
||||
|
@ -318,6 +318,7 @@ std::vector<shared_ptr<ov::Node>> ov::Model::get_ordered_ops() const {
|
||||
m_cached_ordered_ops.clear();
|
||||
for_each(order.cbegin(), order.cend(), [this](const shared_ptr<Node>& node) {
|
||||
m_cached_ordered_ops.push_back(node);
|
||||
m_cached_ops.insert(node.get());
|
||||
node->insert_info(m_shared_rt_info);
|
||||
});
|
||||
m_cached_output_names.clear();
|
||||
@ -923,6 +924,9 @@ ov::Output<ov::Node> ov::Model::add_output(const std::string& op_name, size_t ou
|
||||
}
|
||||
|
||||
ov::Output<ov::Node> ov::Model::add_output(const ov::Output<ov::Node>& port) {
|
||||
auto cache_valid = [&]() {
|
||||
return m_cached_ops.count(port.get_node());
|
||||
};
|
||||
if (ov::op::util::is_output(port.get_node()))
|
||||
return port;
|
||||
for (const auto& input : port.get_target_inputs()) {
|
||||
@ -934,9 +938,14 @@ ov::Output<ov::Node> ov::Model::add_output(const ov::Output<ov::Node>& port) {
|
||||
auto result = std::make_shared<ov::op::v0::Result>(port);
|
||||
m_results.push_back(result);
|
||||
if (m_shared_rt_info->get_use_topological_cache()) {
|
||||
// Full update of topological cache is not needed, 'result' can be just inserted to the end
|
||||
m_cached_ordered_ops.push_back(result);
|
||||
result->insert_info(m_shared_rt_info); // Just for consistency, not required for Result nodes
|
||||
if (cache_valid()) {
|
||||
// Full update of topological cache is not needed, 'result' can be just inserted to the end
|
||||
m_cached_ordered_ops.push_back(result);
|
||||
m_cached_ops.insert(result.get());
|
||||
result->insert_info(m_shared_rt_info); // Just for consistency, not required for Result nodes
|
||||
} else {
|
||||
m_shared_rt_info->set_use_topological_cache(false);
|
||||
}
|
||||
}
|
||||
return result->output(0);
|
||||
}
|
||||
|
@ -1034,6 +1034,31 @@ TEST(model, add_output_port) {
|
||||
EXPECT_EQ(f->get_results()[1]->input_value(0).get_node(), relu1.get());
|
||||
}
|
||||
|
||||
TEST(model, add_output_to_new_subgraph) {
|
||||
auto arg0 = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::PartialShape{1});
|
||||
arg0->set_friendly_name("data");
|
||||
arg0->get_output_tensor(0).set_names({"input"});
|
||||
|
||||
auto relu1 = std::make_shared<ov::opset8::Relu>(arg0);
|
||||
relu1->set_friendly_name("relu1");
|
||||
relu1->get_output_tensor(0).set_names({"relu_t1"});
|
||||
|
||||
auto relu2 = std::make_shared<ov::opset8::Relu>(relu1);
|
||||
relu2->set_friendly_name("relu2");
|
||||
relu2->get_output_tensor(0).set_names({"relu_t2"});
|
||||
auto f = std::make_shared<ov::Model>(relu2, ov::ParameterVector{arg0});
|
||||
f->validate_nodes_and_infer_types();
|
||||
|
||||
EXPECT_EQ(f->get_results().size(), 1);
|
||||
|
||||
ov::Output<ov::Node> out;
|
||||
EXPECT_NO_THROW(
|
||||
out = f->add_output(ov::opset8::Constant::create(ov::element::i32, {1}, std::vector<int32_t>{1})->output(0)));
|
||||
EXPECT_NO_THROW(f->get_ordered_ops());
|
||||
EXPECT_EQ(out.get_node(), f->get_results()[1].get());
|
||||
EXPECT_EQ(f->get_results().size(), 2);
|
||||
}
|
||||
|
||||
TEST(model, add_output_incorrect_tensor_name) {
|
||||
auto arg0 = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::PartialShape{1});
|
||||
arg0->set_friendly_name("data");
|
||||
|
Loading…
Reference in New Issue
Block a user