diff --git a/ngraph/core/include/openvino/core/function.hpp b/ngraph/core/include/openvino/core/function.hpp index c14904c0ed4..9fb61bbc853 100644 --- a/ngraph/core/include/openvino/core/function.hpp +++ b/ngraph/core/include/openvino/core/function.hpp @@ -115,9 +115,9 @@ public: ov::Output input(size_t i) const; ov::Output input(const std::string& tensor_name) const; - void add_output(const std::string& tensor_name); - void add_output(const std::string& op_name, size_t output_idx); - void add_output(const ov::Output& port); + ov::Output add_output(const std::string& tensor_name); + ov::Output add_output(const std::string& op_name, size_t output_idx); + ov::Output add_output(const ov::Output& port); void reshape(const std::map& partial_shapes); void reshape(const std::map, ov::PartialShape>& partial_shapes); diff --git a/ngraph/core/src/function.cpp b/ngraph/core/src/function.cpp index 2e8c35683cd..88b4355ea6c 100644 --- a/ngraph/core/src/function.cpp +++ b/ngraph/core/src/function.cpp @@ -868,22 +868,21 @@ void ov::Function::reshape(const std::map, ov::PartialShape } } -void ov::Function::add_output(const std::string& tensor_name) { +ov::Output ov::Function::add_output(const std::string& tensor_name) { for (const auto& op : get_ops()) { if (ov::op::util::is_output(op)) continue; for (const auto& output : op->outputs()) { const auto& names = output.get_tensor().get_names(); if (names.find(tensor_name) != names.end()) { - add_output(output); - return; + return add_output(output); } } } throw ov::Exception("Tensor name " + tensor_name + " was not found."); } -void ov::Function::add_output(const std::string& op_name, size_t output_idx) { +ov::Output ov::Function::add_output(const std::string& op_name, size_t output_idx) { for (const auto& op : get_ops()) { if (op->get_friendly_name() == op_name) { OPENVINO_ASSERT(output_idx < op->get_output_size(), @@ -894,23 +893,23 @@ void ov::Function::add_output(const std::string& op_name, size_t output_idx) { " has only ", std::to_string(op->get_output_size()), " outputs."); - add_output(op->output(output_idx)); - return; + return add_output(op->output(output_idx)); } } throw ov::Exception("Port " + std::to_string(output_idx) + " for operation with name " + op_name + " was not found."); } -void ov::Function::add_output(const ov::Output& port) { +ov::Output ov::Function::add_output(const ov::Output& port) { if (ov::op::util::is_output(port.get_node())) - return; + return port; for (const auto& input : port.get_target_inputs()) { // Do not add result if port is already connected with result if (ov::op::util::is_output(input.get_node())) { - return; + return input.get_node()->output(0); } } auto result = std::make_shared(port); add_results({result}); + return result->output(0); } diff --git a/ngraph/test/function.cpp b/ngraph/test/function.cpp index 213e027dba0..6d417359600 100644 --- a/ngraph/test/function.cpp +++ b/ngraph/test/function.cpp @@ -830,11 +830,14 @@ TEST(function, add_output_tensor_name) { EXPECT_EQ(f->get_results().size(), 1); - EXPECT_NO_THROW(f->add_output("relu_t1")); + ov::Output out1, out2; + EXPECT_NO_THROW(out1 = f->add_output("relu_t1")); EXPECT_EQ(f->get_results().size(), 2); - EXPECT_NO_THROW(f->add_output("relu_t1")); + EXPECT_NO_THROW(out2 = f->add_output("relu_t1")); EXPECT_EQ(f->get_results().size(), 2); EXPECT_EQ(f->get_results()[1]->input_value(0).get_node(), relu1.get()); + EXPECT_EQ(out1, out2); + EXPECT_EQ(out1.get_node(), f->get_results()[1].get()); } TEST(function, add_output_op_name) { @@ -880,7 +883,9 @@ TEST(function, add_output_port) { EXPECT_EQ(f->get_results().size(), 1); - EXPECT_NO_THROW(f->add_output(relu1->output(0))); + ov::Output out; + EXPECT_NO_THROW(out = f->add_output(relu1->output(0))); + EXPECT_EQ(out.get_node(), f->get_results()[1].get()); EXPECT_EQ(f->get_results().size(), 2); EXPECT_EQ(f->get_results()[1]->input_value(0).get_node(), relu1.get()); } @@ -966,8 +971,10 @@ TEST(function, add_output_port_to_result) { EXPECT_EQ(f->get_results().size(), 1); - EXPECT_NO_THROW(f->add_output(result->output(0))); + ov::Output out; + EXPECT_NO_THROW(out = f->add_output(result->output(0))); EXPECT_EQ(f->get_results().size(), 1); + EXPECT_EQ(out, result->output(0)); } namespace { diff --git a/runtime/bindings/python/src/pyopenvino/graph/function.cpp b/runtime/bindings/python/src/pyopenvino/graph/function.cpp index 1a692f5cdd1..60443aceff5 100644 --- a/runtime/bindings/python/src/pyopenvino/graph/function.cpp +++ b/runtime/bindings/python/src/pyopenvino/graph/function.cpp @@ -589,6 +589,7 @@ void regclass_graph_Function(py::module m) { "add_outputs", [](ov::Function& self, py::handle& outputs) { int i = 0; + std::vector> new_outputs; py::list _outputs; if (!py::isinstance(outputs)) { if (py::isinstance(outputs)) { @@ -605,19 +606,22 @@ void regclass_graph_Function(py::module m) { } for (py::handle output : _outputs) { + ov::Output out; if (py::isinstance(_outputs[i])) { - self.add_output(output.cast()); + out = self.add_output(output.cast()); } else if (py::isinstance(output)) { py::tuple output_tuple = output.cast(); - self.add_output(output_tuple[0].cast(), output_tuple[1].cast()); + out = self.add_output(output_tuple[0].cast(), output_tuple[1].cast()); } else if (py::isinstance>(_outputs[i])) { - self.add_output(output.cast>()); + out = self.add_output(output.cast>()); } else { throw py::type_error("Incorrect type of a value to add as output at index " + std::to_string(i) + "."); } + new_outputs.emplace_back(out); i++; } + return new_outputs; }, py::arg("outputs")); diff --git a/runtime/bindings/python/tests/test_inference_engine/test_function.py b/runtime/bindings/python/tests/test_inference_engine/test_function.py index 3d2bca9180e..6929ca56c47 100644 --- a/runtime/bindings/python/tests/test_inference_engine/test_function.py +++ b/runtime/bindings/python/tests/test_inference_engine/test_function.py @@ -21,10 +21,13 @@ def test_function_add_outputs_tensor_name(): relu2 = ops.relu(relu1, name="relu2") function = Function(relu2, [param], "TestFunction") assert len(function.get_results()) == 1 - function.add_outputs("relu_t1") + new_outs = function.add_outputs("relu_t1") assert len(function.get_results()) == 2 assert isinstance(function.outputs[1].get_tensor(), DescriptorTensor) assert "relu_t1" in function.outputs[1].get_tensor().names + assert len(new_outs) == 1 + assert new_outs[0].get_node() == function.outputs[1].get_node() + assert new_outs[0].get_index() == function.outputs[1].get_index() def test_function_add_outputs_op_name(): @@ -35,8 +38,11 @@ def test_function_add_outputs_op_name(): relu2 = ops.relu(relu1, name="relu2") function = Function(relu2, [param], "TestFunction") assert len(function.get_results()) == 1 - function.add_outputs(("relu1", 0)) + new_outs = function.add_outputs(("relu1", 0)) assert len(function.get_results()) == 2 + assert len(new_outs) == 1 + assert new_outs[0].get_node() == function.outputs[1].get_node() + assert new_outs[0].get_index() == function.outputs[1].get_index() def test_function_add_output_port(): @@ -47,8 +53,11 @@ def test_function_add_output_port(): relu2 = ops.relu(relu1, name="relu2") function = Function(relu2, [param], "TestFunction") assert len(function.get_results()) == 1 - function.add_outputs(relu1.output(0)) + new_outs = function.add_outputs(relu1.output(0)) assert len(function.get_results()) == 2 + assert len(new_outs) == 1 + assert new_outs[0].get_node() == function.outputs[1].get_node() + assert new_outs[0].get_index() == function.outputs[1].get_index() def test_function_add_output_incorrect_tensor_name(): @@ -102,8 +111,13 @@ def test_add_outputs_several_tensors(): relu3 = ops.relu(relu2, name="relu3") function = Function(relu3, [param], "TestFunction") assert len(function.get_results()) == 1 - function.add_outputs(["relu_t1", "relu_t2"]) + new_outs = function.add_outputs(["relu_t1", "relu_t2"]) assert len(function.get_results()) == 3 + assert len(new_outs) == 2 + assert new_outs[0].get_node() == function.outputs[1].get_node() + assert new_outs[0].get_index() == function.outputs[1].get_index() + assert new_outs[1].get_node() == function.outputs[2].get_node() + assert new_outs[1].get_index() == function.outputs[2].get_index() def test_add_outputs_several_ports(): @@ -116,8 +130,13 @@ def test_add_outputs_several_ports(): relu3 = ops.relu(relu2, name="relu3") function = Function(relu3, [param], "TestFunction") assert len(function.get_results()) == 1 - function.add_outputs([("relu1", 0), ("relu2", 0)]) + new_outs = function.add_outputs([("relu1", 0), ("relu2", 0)]) assert len(function.get_results()) == 3 + assert len(new_outs) == 2 + assert new_outs[0].get_node() == function.outputs[1].get_node() + assert new_outs[0].get_index() == function.outputs[1].get_index() + assert new_outs[1].get_node() == function.outputs[2].get_node() + assert new_outs[1].get_index() == function.outputs[2].get_index() def test_add_outputs_incorrect_value():