Fix return values for lift_up_through func (#7323)

* Fix return valuese for lift_up_through func

* Update unit test

* Refactoring code according to code review

* Fix revers outputs

* Fix unit test

* Fix comment

* Add multioutput support

* Add unit test for cace with several output from ReverseChannel op

* Fix distinantion connect
This commit is contained in:
iliya mironov
2021-09-08 14:48:59 +03:00
committed by GitHub
parent 8bd41a1f45
commit 60714ce40a
2 changed files with 36 additions and 14 deletions

View File

@@ -114,7 +114,7 @@ class ReverseChannelsPropagationDown(BackReplacementPattern):
returns boolean value whatever we should continue propagating current ReverseChannels operation down or not
"""
# detaching reverse_channels node from the graph
if reverse_channels.is_in_port_connected(0) and reverse_channels.is_out_port_connected(0)\
if reverse_channels.is_in_port_connected(0) and reverse_channels.is_out_port_connected(0) \
and node.is_out_port_connected(0):
reverse_channels.out_port(0).get_connection().set_source(
reverse_channels.in_port(0).get_connection().get_source())
@@ -137,7 +137,7 @@ class ReverseChannelsPropagationDown(BackReplacementPattern):
ReverseChannels weights previous_op ReverseChannels
\ / \ /
Conv Conv
For grouped convolution:
BEFORE AFTER
@@ -295,12 +295,11 @@ class ReverseChannelsPropagationUp(BackReplacementPattern):
'Subtract': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
'Pow': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
'Convert': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
'Pad': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through(node, rc),
'Pad': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_pad(node, rc),
}
@staticmethod
def lift_up_through(node: Node, reverse_channels: Node):
def lift_up_through_pad(node: Node, reverse_channels: Node):
r"""
BEFORE AFTER
@@ -308,25 +307,29 @@ class ReverseChannelsPropagationUp(BackReplacementPattern):
\
previous_op previous_op ReverseChannels previous_op
\ / \ /
Node Node
Pad Pad
| |
ReverseChannels next_op
|
next_op
returns boolean value whatever we should continue propagating current ReverseChannels operation up or not
returns two objects:
first - boolean value whatever we should continue propagating current ReverseChannels operation up or not
second - list of ReverseChannels operations that were produced while propagating reverse_channels up
"""
if node.is_in_port_connected(0):
node_input_port_0 = node.in_port(0)
reverse_channels_out_npde = reverse_channels.out_port(0).get_connection().get_destination().node
reverse_channels_out_nodes = reverse_channels.out_port(0).get_connection().get_destinations()
reverse_channels.out_port(0).disconnect()
reverse_channels.in_port(0).disconnect()
src = node_input_port_0.get_connection().get_source()
node_input_port_0.get_connection().set_source(reverse_channels.out_port(0))
src.connect(reverse_channels.in_port(0))
node.out_port(0).get_connection().set_destination(reverse_channels_out_npde.in_port(0))
return True
return False
for reverse_channels_destination in reverse_channels_out_nodes:
node.out_port(0).get_connection().add_destination(reverse_channels_destination)
return True, [reverse_channels]
return False, []
@staticmethod
def lift_up_through_eltwise(node: Node, reverse_channels: Node):

View File

@@ -32,6 +32,7 @@ nodes2 = {
**regular_op_with_shaped_data('pad', [1, 3, 10, 10], {'type': 'Pad'}),
**regular_op_with_shaped_data('reverse_channels', [1, 3, 10, 10], {'type': 'ReverseChannels', 'axis': 1}),
**result('result'),
**result('result2'),
}
class ReverseInputChannelsTest(unittest.TestCase):
@@ -64,7 +65,7 @@ class ReverseInputChannelsTest(unittest.TestCase):
ReverseChannelsPropagationUp.lift_up_through_eltwise(node, reverse_channels)
self.check_graph_attrs(graph, ['placeholder1', 'placeholder2'])
def test_lift_up_through(self):
def test_lift_up_through_pad(self):
graph = build_graph(nodes2, [*connect('placeholder', '0:mul'), *connect('mul_const', '1:mul'),
*connect('mul', '0:pad'), *connect('pad_const_1', '1:pad'),
*connect('pad_const_2', '2:pad'), *connect('pad', 'reverse_channels'),
@@ -74,7 +75,25 @@ class ReverseInputChannelsTest(unittest.TestCase):
node = Node(graph, 'pad')
reverse_channels = Node(graph, 'reverse_channels')
ReverseChannelsPropagationUp.lift_up_through(node, reverse_channels)
keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_pad(node, reverse_channels)
self.assertTrue(keep_moving_up is True)
self.assertTrue(len(new_reverses) == 1)
self.check_graph_attrs(graph, ['placeholder'])
def test_lift_up_through_pad2(self):
graph = build_graph(nodes2, [*connect('placeholder', '0:mul'), *connect('mul_const', '1:mul'),
*connect('mul', '0:pad'), *connect('pad_const_1', '1:pad'),
*connect('pad_const_2', '2:pad'), *connect('pad', 'reverse_channels'),
*connect('reverse_channels:0', '0:result'), *connect('reverse_channels:0', '0:result2')])
self.set_graph_attrs(graph, ['placeholder'])
node = Node(graph, 'pad')
reverse_channels = Node(graph, 'reverse_channels')
keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_pad(node, reverse_channels)
self.assertTrue(keep_moving_up is True)
self.assertTrue(len(new_reverses) == 1)
self.check_graph_attrs(graph, ['placeholder'])