Changed add_output API (#8833)

* Changed add_output API

* Added C++ tests

* Added bython tests

* Try to fix python tests

* Fixed python tests

* Try to fix build

* Cannot compare pointers
This commit is contained in:
Ilya Churaev 2021-11-27 11:08:48 +03:00 committed by GitHub
parent 909dea8b5d
commit 7cb378c92b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 53 additions and 24 deletions

View File

@ -115,9 +115,9 @@ public:
ov::Output<const ov::Node> input(size_t i) const;
ov::Output<const ov::Node> input(const std::string& tensor_name) const;
void add_output(const std::string& tensor_name);
void add_output(const std::string& op_name, size_t output_idx);
void add_output(const ov::Output<ov::Node>& port);
ov::Output<ov::Node> add_output(const std::string& tensor_name);
ov::Output<ov::Node> add_output(const std::string& op_name, size_t output_idx);
ov::Output<ov::Node> add_output(const ov::Output<ov::Node>& port);
void reshape(const std::map<std::string, ov::PartialShape>& partial_shapes);
void reshape(const std::map<ov::Output<ov::Node>, ov::PartialShape>& partial_shapes);

View File

@ -868,22 +868,21 @@ void ov::Function::reshape(const std::map<ov::Output<ov::Node>, ov::PartialShape
}
}
void ov::Function::add_output(const std::string& tensor_name) {
ov::Output<ov::Node> ov::Function::add_output(const std::string& tensor_name) {
for (const auto& op : get_ops()) {
if (ov::op::util::is_output(op))
continue;
for (const auto& output : op->outputs()) {
const auto& names = output.get_tensor().get_names();
if (names.find(tensor_name) != names.end()) {
add_output(output);
return;
return add_output(output);
}
}
}
throw ov::Exception("Tensor name " + tensor_name + " was not found.");
}
void ov::Function::add_output(const std::string& op_name, size_t output_idx) {
ov::Output<ov::Node> ov::Function::add_output(const std::string& op_name, size_t output_idx) {
for (const auto& op : get_ops()) {
if (op->get_friendly_name() == op_name) {
OPENVINO_ASSERT(output_idx < op->get_output_size(),
@ -894,23 +893,23 @@ void ov::Function::add_output(const std::string& op_name, size_t output_idx) {
" has only ",
std::to_string(op->get_output_size()),
" outputs.");
add_output(op->output(output_idx));
return;
return add_output(op->output(output_idx));
}
}
throw ov::Exception("Port " + std::to_string(output_idx) + " for operation with name " + op_name +
" was not found.");
}
void ov::Function::add_output(const ov::Output<ov::Node>& port) {
ov::Output<ov::Node> ov::Function::add_output(const ov::Output<ov::Node>& port) {
if (ov::op::util::is_output(port.get_node()))
return;
return port;
for (const auto& input : port.get_target_inputs()) {
// Do not add result if port is already connected with result
if (ov::op::util::is_output(input.get_node())) {
return;
return input.get_node()->output(0);
}
}
auto result = std::make_shared<ov::op::v0::Result>(port);
add_results({result});
return result->output(0);
}

View File

@ -830,11 +830,14 @@ TEST(function, add_output_tensor_name) {
EXPECT_EQ(f->get_results().size(), 1);
EXPECT_NO_THROW(f->add_output("relu_t1"));
ov::Output<ov::Node> out1, out2;
EXPECT_NO_THROW(out1 = f->add_output("relu_t1"));
EXPECT_EQ(f->get_results().size(), 2);
EXPECT_NO_THROW(f->add_output("relu_t1"));
EXPECT_NO_THROW(out2 = f->add_output("relu_t1"));
EXPECT_EQ(f->get_results().size(), 2);
EXPECT_EQ(f->get_results()[1]->input_value(0).get_node(), relu1.get());
EXPECT_EQ(out1, out2);
EXPECT_EQ(out1.get_node(), f->get_results()[1].get());
}
TEST(function, add_output_op_name) {
@ -880,7 +883,9 @@ TEST(function, add_output_port) {
EXPECT_EQ(f->get_results().size(), 1);
EXPECT_NO_THROW(f->add_output(relu1->output(0)));
ov::Output<ov::Node> out;
EXPECT_NO_THROW(out = f->add_output(relu1->output(0)));
EXPECT_EQ(out.get_node(), f->get_results()[1].get());
EXPECT_EQ(f->get_results().size(), 2);
EXPECT_EQ(f->get_results()[1]->input_value(0).get_node(), relu1.get());
}
@ -966,8 +971,10 @@ TEST(function, add_output_port_to_result) {
EXPECT_EQ(f->get_results().size(), 1);
EXPECT_NO_THROW(f->add_output(result->output(0)));
ov::Output<ov::Node> out;
EXPECT_NO_THROW(out = f->add_output(result->output(0)));
EXPECT_EQ(f->get_results().size(), 1);
EXPECT_EQ(out, result->output(0));
}
namespace {

View File

@ -589,6 +589,7 @@ void regclass_graph_Function(py::module m) {
"add_outputs",
[](ov::Function& self, py::handle& outputs) {
int i = 0;
std::vector<ov::Output<ov::Node>> new_outputs;
py::list _outputs;
if (!py::isinstance<py::list>(outputs)) {
if (py::isinstance<py::str>(outputs)) {
@ -605,19 +606,22 @@ void regclass_graph_Function(py::module m) {
}
for (py::handle output : _outputs) {
ov::Output<ov::Node> out;
if (py::isinstance<py::str>(_outputs[i])) {
self.add_output(output.cast<std::string>());
out = self.add_output(output.cast<std::string>());
} else if (py::isinstance<py::tuple>(output)) {
py::tuple output_tuple = output.cast<py::tuple>();
self.add_output(output_tuple[0].cast<std::string>(), output_tuple[1].cast<int>());
out = self.add_output(output_tuple[0].cast<std::string>(), output_tuple[1].cast<int>());
} else if (py::isinstance<ov::Output<ov::Node>>(_outputs[i])) {
self.add_output(output.cast<ov::Output<ov::Node>>());
out = self.add_output(output.cast<ov::Output<ov::Node>>());
} else {
throw py::type_error("Incorrect type of a value to add as output at index " + std::to_string(i) +
".");
}
new_outputs.emplace_back(out);
i++;
}
return new_outputs;
},
py::arg("outputs"));

View File

@ -21,10 +21,13 @@ def test_function_add_outputs_tensor_name():
relu2 = ops.relu(relu1, name="relu2")
function = Function(relu2, [param], "TestFunction")
assert len(function.get_results()) == 1
function.add_outputs("relu_t1")
new_outs = function.add_outputs("relu_t1")
assert len(function.get_results()) == 2
assert isinstance(function.outputs[1].get_tensor(), DescriptorTensor)
assert "relu_t1" in function.outputs[1].get_tensor().names
assert len(new_outs) == 1
assert new_outs[0].get_node() == function.outputs[1].get_node()
assert new_outs[0].get_index() == function.outputs[1].get_index()
def test_function_add_outputs_op_name():
@ -35,8 +38,11 @@ def test_function_add_outputs_op_name():
relu2 = ops.relu(relu1, name="relu2")
function = Function(relu2, [param], "TestFunction")
assert len(function.get_results()) == 1
function.add_outputs(("relu1", 0))
new_outs = function.add_outputs(("relu1", 0))
assert len(function.get_results()) == 2
assert len(new_outs) == 1
assert new_outs[0].get_node() == function.outputs[1].get_node()
assert new_outs[0].get_index() == function.outputs[1].get_index()
def test_function_add_output_port():
@ -47,8 +53,11 @@ def test_function_add_output_port():
relu2 = ops.relu(relu1, name="relu2")
function = Function(relu2, [param], "TestFunction")
assert len(function.get_results()) == 1
function.add_outputs(relu1.output(0))
new_outs = function.add_outputs(relu1.output(0))
assert len(function.get_results()) == 2
assert len(new_outs) == 1
assert new_outs[0].get_node() == function.outputs[1].get_node()
assert new_outs[0].get_index() == function.outputs[1].get_index()
def test_function_add_output_incorrect_tensor_name():
@ -102,8 +111,13 @@ def test_add_outputs_several_tensors():
relu3 = ops.relu(relu2, name="relu3")
function = Function(relu3, [param], "TestFunction")
assert len(function.get_results()) == 1
function.add_outputs(["relu_t1", "relu_t2"])
new_outs = function.add_outputs(["relu_t1", "relu_t2"])
assert len(function.get_results()) == 3
assert len(new_outs) == 2
assert new_outs[0].get_node() == function.outputs[1].get_node()
assert new_outs[0].get_index() == function.outputs[1].get_index()
assert new_outs[1].get_node() == function.outputs[2].get_node()
assert new_outs[1].get_index() == function.outputs[2].get_index()
def test_add_outputs_several_ports():
@ -116,8 +130,13 @@ def test_add_outputs_several_ports():
relu3 = ops.relu(relu2, name="relu3")
function = Function(relu3, [param], "TestFunction")
assert len(function.get_results()) == 1
function.add_outputs([("relu1", 0), ("relu2", 0)])
new_outs = function.add_outputs([("relu1", 0), ("relu2", 0)])
assert len(function.get_results()) == 3
assert len(new_outs) == 2
assert new_outs[0].get_node() == function.outputs[1].get_node()
assert new_outs[0].get_index() == function.outputs[1].get_index()
assert new_outs[1].get_node() == function.outputs[2].get_node()
assert new_outs[1].get_index() == function.outputs[2].get_index()
def test_add_outputs_incorrect_value():