Fix incorrect edge removal in the disconnect() method (#4582)

* added condition to disconnect method

* add unittest, rewrite the fix

* revert the second implementation, update test

Co-authored-by: yegor.kruglov <ykruglov@nnlvdp-mkaglins.inn.intel.com>
This commit is contained in:
Yegor Kruglov 2021-03-05 21:57:28 +03:00 committed by GitHub
parent 3656e1c564
commit 84cd802ca5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 18 additions and 1 deletions

View File

@ -320,7 +320,12 @@ class Port:
self.node.graph.remove_edge(self.node.id, port.node.id)
else:
for port in consumer_ports:
self.node.graph.remove_edge(port.node.in_node(port.idx).id, port.node.id)
src_node = port.node.in_node(port.idx).id
dst_node = port.node.id
for key, val in self.node.graph.get_edge_data(src_node, dst_node).items():
if val['in'] == port.idx:
self.node.graph.remove_edge(src_node, dst_node, key=key)
break
else:
source_port = self.get_source()
if source_port is None:

View File

@ -127,3 +127,15 @@ class TestsGetTensorNames(unittest.TestCase):
self.assertTrue(input_node_out_port.get_tensor_names() == [])
self.assertTrue(op3_node.out_port(0).get_tensor_names() == ['Op3', 'input', 'Op1\\,Op2'])
class TestPortMethods(unittest.TestCase):
def test_middle_disconnect_several_edges_between_two_nodes(self):
graph = build_graph(nodes, [('input', 'input_data'), ('input_data', 'Op1'),
('Op1', 'Op1_data'), ('Op1_data', 'Op2', {'in': 0}), ('Op1_data', 'Op2', {'in': 1}),
('Op1_data', 'Op2', {'in': 2})],
nodes_with_edges_only=True)
op1_node = Node(graph, 'Op1')
op1_node.out_port(0).disconnect()
self.assertTrue(op1_node.out_port(0).disconnected())