Fix incorrect edge removal in the disconnect() method (#4582)
* added condition to disconnect method * add unittest, rewrite the fix * revert the second implementation, update test Co-authored-by: yegor.kruglov <ykruglov@nnlvdp-mkaglins.inn.intel.com>
This commit is contained in:
parent
3656e1c564
commit
84cd802ca5
@ -320,7 +320,12 @@ class Port:
|
||||
self.node.graph.remove_edge(self.node.id, port.node.id)
|
||||
else:
|
||||
for port in consumer_ports:
|
||||
self.node.graph.remove_edge(port.node.in_node(port.idx).id, port.node.id)
|
||||
src_node = port.node.in_node(port.idx).id
|
||||
dst_node = port.node.id
|
||||
for key, val in self.node.graph.get_edge_data(src_node, dst_node).items():
|
||||
if val['in'] == port.idx:
|
||||
self.node.graph.remove_edge(src_node, dst_node, key=key)
|
||||
break
|
||||
else:
|
||||
source_port = self.get_source()
|
||||
if source_port is None:
|
||||
|
@ -127,3 +127,15 @@ class TestsGetTensorNames(unittest.TestCase):
|
||||
|
||||
self.assertTrue(input_node_out_port.get_tensor_names() == [])
|
||||
self.assertTrue(op3_node.out_port(0).get_tensor_names() == ['Op3', 'input', 'Op1\\,Op2'])
|
||||
|
||||
|
||||
class TestPortMethods(unittest.TestCase):
|
||||
|
||||
def test_middle_disconnect_several_edges_between_two_nodes(self):
|
||||
graph = build_graph(nodes, [('input', 'input_data'), ('input_data', 'Op1'),
|
||||
('Op1', 'Op1_data'), ('Op1_data', 'Op2', {'in': 0}), ('Op1_data', 'Op2', {'in': 1}),
|
||||
('Op1_data', 'Op2', {'in': 2})],
|
||||
nodes_with_edges_only=True)
|
||||
op1_node = Node(graph, 'Op1')
|
||||
op1_node.out_port(0).disconnect()
|
||||
self.assertTrue(op1_node.out_port(0).disconnected())
|
||||
|
Loading…
Reference in New Issue
Block a user