Changed add_output API (#8833)
* Changed add_output API * Added C++ tests * Added bython tests * Try to fix python tests * Fixed python tests * Try to fix build * Cannot compare pointers
This commit is contained in:
parent
909dea8b5d
commit
7cb378c92b
@ -115,9 +115,9 @@ public:
|
||||
ov::Output<const ov::Node> input(size_t i) const;
|
||||
ov::Output<const ov::Node> input(const std::string& tensor_name) const;
|
||||
|
||||
void add_output(const std::string& tensor_name);
|
||||
void add_output(const std::string& op_name, size_t output_idx);
|
||||
void add_output(const ov::Output<ov::Node>& port);
|
||||
ov::Output<ov::Node> add_output(const std::string& tensor_name);
|
||||
ov::Output<ov::Node> add_output(const std::string& op_name, size_t output_idx);
|
||||
ov::Output<ov::Node> add_output(const ov::Output<ov::Node>& port);
|
||||
|
||||
void reshape(const std::map<std::string, ov::PartialShape>& partial_shapes);
|
||||
void reshape(const std::map<ov::Output<ov::Node>, ov::PartialShape>& partial_shapes);
|
||||
|
@ -868,22 +868,21 @@ void ov::Function::reshape(const std::map<ov::Output<ov::Node>, ov::PartialShape
|
||||
}
|
||||
}
|
||||
|
||||
void ov::Function::add_output(const std::string& tensor_name) {
|
||||
ov::Output<ov::Node> ov::Function::add_output(const std::string& tensor_name) {
|
||||
for (const auto& op : get_ops()) {
|
||||
if (ov::op::util::is_output(op))
|
||||
continue;
|
||||
for (const auto& output : op->outputs()) {
|
||||
const auto& names = output.get_tensor().get_names();
|
||||
if (names.find(tensor_name) != names.end()) {
|
||||
add_output(output);
|
||||
return;
|
||||
return add_output(output);
|
||||
}
|
||||
}
|
||||
}
|
||||
throw ov::Exception("Tensor name " + tensor_name + " was not found.");
|
||||
}
|
||||
|
||||
void ov::Function::add_output(const std::string& op_name, size_t output_idx) {
|
||||
ov::Output<ov::Node> ov::Function::add_output(const std::string& op_name, size_t output_idx) {
|
||||
for (const auto& op : get_ops()) {
|
||||
if (op->get_friendly_name() == op_name) {
|
||||
OPENVINO_ASSERT(output_idx < op->get_output_size(),
|
||||
@ -894,23 +893,23 @@ void ov::Function::add_output(const std::string& op_name, size_t output_idx) {
|
||||
" has only ",
|
||||
std::to_string(op->get_output_size()),
|
||||
" outputs.");
|
||||
add_output(op->output(output_idx));
|
||||
return;
|
||||
return add_output(op->output(output_idx));
|
||||
}
|
||||
}
|
||||
throw ov::Exception("Port " + std::to_string(output_idx) + " for operation with name " + op_name +
|
||||
" was not found.");
|
||||
}
|
||||
|
||||
void ov::Function::add_output(const ov::Output<ov::Node>& port) {
|
||||
ov::Output<ov::Node> ov::Function::add_output(const ov::Output<ov::Node>& port) {
|
||||
if (ov::op::util::is_output(port.get_node()))
|
||||
return;
|
||||
return port;
|
||||
for (const auto& input : port.get_target_inputs()) {
|
||||
// Do not add result if port is already connected with result
|
||||
if (ov::op::util::is_output(input.get_node())) {
|
||||
return;
|
||||
return input.get_node()->output(0);
|
||||
}
|
||||
}
|
||||
auto result = std::make_shared<ov::op::v0::Result>(port);
|
||||
add_results({result});
|
||||
return result->output(0);
|
||||
}
|
||||
|
@ -830,11 +830,14 @@ TEST(function, add_output_tensor_name) {
|
||||
|
||||
EXPECT_EQ(f->get_results().size(), 1);
|
||||
|
||||
EXPECT_NO_THROW(f->add_output("relu_t1"));
|
||||
ov::Output<ov::Node> out1, out2;
|
||||
EXPECT_NO_THROW(out1 = f->add_output("relu_t1"));
|
||||
EXPECT_EQ(f->get_results().size(), 2);
|
||||
EXPECT_NO_THROW(f->add_output("relu_t1"));
|
||||
EXPECT_NO_THROW(out2 = f->add_output("relu_t1"));
|
||||
EXPECT_EQ(f->get_results().size(), 2);
|
||||
EXPECT_EQ(f->get_results()[1]->input_value(0).get_node(), relu1.get());
|
||||
EXPECT_EQ(out1, out2);
|
||||
EXPECT_EQ(out1.get_node(), f->get_results()[1].get());
|
||||
}
|
||||
|
||||
TEST(function, add_output_op_name) {
|
||||
@ -880,7 +883,9 @@ TEST(function, add_output_port) {
|
||||
|
||||
EXPECT_EQ(f->get_results().size(), 1);
|
||||
|
||||
EXPECT_NO_THROW(f->add_output(relu1->output(0)));
|
||||
ov::Output<ov::Node> out;
|
||||
EXPECT_NO_THROW(out = f->add_output(relu1->output(0)));
|
||||
EXPECT_EQ(out.get_node(), f->get_results()[1].get());
|
||||
EXPECT_EQ(f->get_results().size(), 2);
|
||||
EXPECT_EQ(f->get_results()[1]->input_value(0).get_node(), relu1.get());
|
||||
}
|
||||
@ -966,8 +971,10 @@ TEST(function, add_output_port_to_result) {
|
||||
|
||||
EXPECT_EQ(f->get_results().size(), 1);
|
||||
|
||||
EXPECT_NO_THROW(f->add_output(result->output(0)));
|
||||
ov::Output<ov::Node> out;
|
||||
EXPECT_NO_THROW(out = f->add_output(result->output(0)));
|
||||
EXPECT_EQ(f->get_results().size(), 1);
|
||||
EXPECT_EQ(out, result->output(0));
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -589,6 +589,7 @@ void regclass_graph_Function(py::module m) {
|
||||
"add_outputs",
|
||||
[](ov::Function& self, py::handle& outputs) {
|
||||
int i = 0;
|
||||
std::vector<ov::Output<ov::Node>> new_outputs;
|
||||
py::list _outputs;
|
||||
if (!py::isinstance<py::list>(outputs)) {
|
||||
if (py::isinstance<py::str>(outputs)) {
|
||||
@ -605,19 +606,22 @@ void regclass_graph_Function(py::module m) {
|
||||
}
|
||||
|
||||
for (py::handle output : _outputs) {
|
||||
ov::Output<ov::Node> out;
|
||||
if (py::isinstance<py::str>(_outputs[i])) {
|
||||
self.add_output(output.cast<std::string>());
|
||||
out = self.add_output(output.cast<std::string>());
|
||||
} else if (py::isinstance<py::tuple>(output)) {
|
||||
py::tuple output_tuple = output.cast<py::tuple>();
|
||||
self.add_output(output_tuple[0].cast<std::string>(), output_tuple[1].cast<int>());
|
||||
out = self.add_output(output_tuple[0].cast<std::string>(), output_tuple[1].cast<int>());
|
||||
} else if (py::isinstance<ov::Output<ov::Node>>(_outputs[i])) {
|
||||
self.add_output(output.cast<ov::Output<ov::Node>>());
|
||||
out = self.add_output(output.cast<ov::Output<ov::Node>>());
|
||||
} else {
|
||||
throw py::type_error("Incorrect type of a value to add as output at index " + std::to_string(i) +
|
||||
".");
|
||||
}
|
||||
new_outputs.emplace_back(out);
|
||||
i++;
|
||||
}
|
||||
return new_outputs;
|
||||
},
|
||||
py::arg("outputs"));
|
||||
|
||||
|
@ -21,10 +21,13 @@ def test_function_add_outputs_tensor_name():
|
||||
relu2 = ops.relu(relu1, name="relu2")
|
||||
function = Function(relu2, [param], "TestFunction")
|
||||
assert len(function.get_results()) == 1
|
||||
function.add_outputs("relu_t1")
|
||||
new_outs = function.add_outputs("relu_t1")
|
||||
assert len(function.get_results()) == 2
|
||||
assert isinstance(function.outputs[1].get_tensor(), DescriptorTensor)
|
||||
assert "relu_t1" in function.outputs[1].get_tensor().names
|
||||
assert len(new_outs) == 1
|
||||
assert new_outs[0].get_node() == function.outputs[1].get_node()
|
||||
assert new_outs[0].get_index() == function.outputs[1].get_index()
|
||||
|
||||
|
||||
def test_function_add_outputs_op_name():
|
||||
@ -35,8 +38,11 @@ def test_function_add_outputs_op_name():
|
||||
relu2 = ops.relu(relu1, name="relu2")
|
||||
function = Function(relu2, [param], "TestFunction")
|
||||
assert len(function.get_results()) == 1
|
||||
function.add_outputs(("relu1", 0))
|
||||
new_outs = function.add_outputs(("relu1", 0))
|
||||
assert len(function.get_results()) == 2
|
||||
assert len(new_outs) == 1
|
||||
assert new_outs[0].get_node() == function.outputs[1].get_node()
|
||||
assert new_outs[0].get_index() == function.outputs[1].get_index()
|
||||
|
||||
|
||||
def test_function_add_output_port():
|
||||
@ -47,8 +53,11 @@ def test_function_add_output_port():
|
||||
relu2 = ops.relu(relu1, name="relu2")
|
||||
function = Function(relu2, [param], "TestFunction")
|
||||
assert len(function.get_results()) == 1
|
||||
function.add_outputs(relu1.output(0))
|
||||
new_outs = function.add_outputs(relu1.output(0))
|
||||
assert len(function.get_results()) == 2
|
||||
assert len(new_outs) == 1
|
||||
assert new_outs[0].get_node() == function.outputs[1].get_node()
|
||||
assert new_outs[0].get_index() == function.outputs[1].get_index()
|
||||
|
||||
|
||||
def test_function_add_output_incorrect_tensor_name():
|
||||
@ -102,8 +111,13 @@ def test_add_outputs_several_tensors():
|
||||
relu3 = ops.relu(relu2, name="relu3")
|
||||
function = Function(relu3, [param], "TestFunction")
|
||||
assert len(function.get_results()) == 1
|
||||
function.add_outputs(["relu_t1", "relu_t2"])
|
||||
new_outs = function.add_outputs(["relu_t1", "relu_t2"])
|
||||
assert len(function.get_results()) == 3
|
||||
assert len(new_outs) == 2
|
||||
assert new_outs[0].get_node() == function.outputs[1].get_node()
|
||||
assert new_outs[0].get_index() == function.outputs[1].get_index()
|
||||
assert new_outs[1].get_node() == function.outputs[2].get_node()
|
||||
assert new_outs[1].get_index() == function.outputs[2].get_index()
|
||||
|
||||
|
||||
def test_add_outputs_several_ports():
|
||||
@ -116,8 +130,13 @@ def test_add_outputs_several_ports():
|
||||
relu3 = ops.relu(relu2, name="relu3")
|
||||
function = Function(relu3, [param], "TestFunction")
|
||||
assert len(function.get_results()) == 1
|
||||
function.add_outputs([("relu1", 0), ("relu2", 0)])
|
||||
new_outs = function.add_outputs([("relu1", 0), ("relu2", 0)])
|
||||
assert len(function.get_results()) == 3
|
||||
assert len(new_outs) == 2
|
||||
assert new_outs[0].get_node() == function.outputs[1].get_node()
|
||||
assert new_outs[0].get_index() == function.outputs[1].get_index()
|
||||
assert new_outs[1].get_node() == function.outputs[2].get_node()
|
||||
assert new_outs[1].get_index() == function.outputs[2].get_index()
|
||||
|
||||
|
||||
def test_add_outputs_incorrect_value():
|
||||
|
Loading…
Reference in New Issue
Block a user