Fix return values for lift_up_through func (#7323)
* Fix return values for lift_up_through func * Update unit test * Refactor code according to code review * Fix reverse outputs * Fix unit test * Fix comment * Add multi-output support * Add unit test for case with several outputs from ReverseChannel op * Fix destination connect
This commit is contained in:
@@ -114,7 +114,7 @@ class ReverseChannelsPropagationDown(BackReplacementPattern):
|
||||
returns boolean value whatever we should continue propagating current ReverseChannels operation down or not
|
||||
"""
|
||||
# detaching reverse_channels node from the graph
|
||||
if reverse_channels.is_in_port_connected(0) and reverse_channels.is_out_port_connected(0)\
|
||||
if reverse_channels.is_in_port_connected(0) and reverse_channels.is_out_port_connected(0) \
|
||||
and node.is_out_port_connected(0):
|
||||
reverse_channels.out_port(0).get_connection().set_source(
|
||||
reverse_channels.in_port(0).get_connection().get_source())
|
||||
@@ -137,7 +137,7 @@ class ReverseChannelsPropagationDown(BackReplacementPattern):
|
||||
ReverseChannels weights previous_op ReverseChannels
|
||||
\ / \ /
|
||||
Conv Conv
|
||||
|
||||
|
||||
For grouped convolution:
|
||||
BEFORE AFTER
|
||||
|
||||
@@ -295,12 +295,11 @@ class ReverseChannelsPropagationUp(BackReplacementPattern):
|
||||
'Subtract': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
|
||||
'Pow': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
|
||||
'Convert': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
|
||||
|
||||
'Pad': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through(node, rc),
|
||||
'Pad': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_pad(node, rc),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def lift_up_through(node: Node, reverse_channels: Node):
|
||||
def lift_up_through_pad(node: Node, reverse_channels: Node):
|
||||
r"""
|
||||
BEFORE AFTER
|
||||
|
||||
@@ -308,25 +307,29 @@ class ReverseChannelsPropagationUp(BackReplacementPattern):
|
||||
\
|
||||
previous_op previous_op ReverseChannels previous_op
|
||||
\ / \ /
|
||||
Node Node
|
||||
Pad Pad
|
||||
| |
|
||||
ReverseChannels next_op
|
||||
|
|
||||
next_op
|
||||
|
||||
returns boolean value whatever we should continue propagating current ReverseChannels operation up or not
|
||||
returns two objects:
|
||||
first - boolean value whatever we should continue propagating current ReverseChannels operation up or not
|
||||
second - list of ReverseChannels operations that were produced while propagating reverse_channels up
|
||||
"""
|
||||
if node.is_in_port_connected(0):
|
||||
node_input_port_0 = node.in_port(0)
|
||||
reverse_channels_out_npde = reverse_channels.out_port(0).get_connection().get_destination().node
|
||||
reverse_channels_out_nodes = reverse_channels.out_port(0).get_connection().get_destinations()
|
||||
reverse_channels.out_port(0).disconnect()
|
||||
|
||||
reverse_channels.in_port(0).disconnect()
|
||||
src = node_input_port_0.get_connection().get_source()
|
||||
node_input_port_0.get_connection().set_source(reverse_channels.out_port(0))
|
||||
src.connect(reverse_channels.in_port(0))
|
||||
node.out_port(0).get_connection().set_destination(reverse_channels_out_npde.in_port(0))
|
||||
return True
|
||||
return False
|
||||
for reverse_channels_destination in reverse_channels_out_nodes:
|
||||
node.out_port(0).get_connection().add_destination(reverse_channels_destination)
|
||||
|
||||
return True, [reverse_channels]
|
||||
return False, []
|
||||
|
||||
@staticmethod
|
||||
def lift_up_through_eltwise(node: Node, reverse_channels: Node):
|
||||
|
||||
@@ -32,6 +32,7 @@ nodes2 = {
|
||||
**regular_op_with_shaped_data('pad', [1, 3, 10, 10], {'type': 'Pad'}),
|
||||
**regular_op_with_shaped_data('reverse_channels', [1, 3, 10, 10], {'type': 'ReverseChannels', 'axis': 1}),
|
||||
**result('result'),
|
||||
**result('result2'),
|
||||
}
|
||||
|
||||
class ReverseInputChannelsTest(unittest.TestCase):
|
||||
@@ -64,7 +65,7 @@ class ReverseInputChannelsTest(unittest.TestCase):
|
||||
ReverseChannelsPropagationUp.lift_up_through_eltwise(node, reverse_channels)
|
||||
self.check_graph_attrs(graph, ['placeholder1', 'placeholder2'])
|
||||
|
||||
def test_lift_up_through(self):
|
||||
def test_lift_up_through_pad(self):
|
||||
graph = build_graph(nodes2, [*connect('placeholder', '0:mul'), *connect('mul_const', '1:mul'),
|
||||
*connect('mul', '0:pad'), *connect('pad_const_1', '1:pad'),
|
||||
*connect('pad_const_2', '2:pad'), *connect('pad', 'reverse_channels'),
|
||||
@@ -74,7 +75,25 @@ class ReverseInputChannelsTest(unittest.TestCase):
|
||||
node = Node(graph, 'pad')
|
||||
reverse_channels = Node(graph, 'reverse_channels')
|
||||
|
||||
ReverseChannelsPropagationUp.lift_up_through(node, reverse_channels)
|
||||
keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_pad(node, reverse_channels)
|
||||
self.assertTrue(keep_moving_up is True)
|
||||
self.assertTrue(len(new_reverses) == 1)
|
||||
self.check_graph_attrs(graph, ['placeholder'])
|
||||
|
||||
|
||||
def test_lift_up_through_pad2(self):
|
||||
graph = build_graph(nodes2, [*connect('placeholder', '0:mul'), *connect('mul_const', '1:mul'),
|
||||
*connect('mul', '0:pad'), *connect('pad_const_1', '1:pad'),
|
||||
*connect('pad_const_2', '2:pad'), *connect('pad', 'reverse_channels'),
|
||||
*connect('reverse_channels:0', '0:result'), *connect('reverse_channels:0', '0:result2')])
|
||||
self.set_graph_attrs(graph, ['placeholder'])
|
||||
|
||||
node = Node(graph, 'pad')
|
||||
reverse_channels = Node(graph, 'reverse_channels')
|
||||
|
||||
keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_pad(node, reverse_channels)
|
||||
self.assertTrue(keep_moving_up is True)
|
||||
self.assertTrue(len(new_reverses) == 1)
|
||||
self.check_graph_attrs(graph, ['placeholder'])
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user