Fixed axis type in ReverseInputChannels. (#8708)
* Fixed axis type. * Added checks, corrected get_fw_index. * Corrected get_fw_index(), removed checks.
This commit is contained in:
parent
8c55c761c4
commit
5104bf8ce8
@ -51,7 +51,7 @@ class InsertReverseChannels(BackReplacementPattern):
|
||||
enabled = False
|
||||
|
||||
@staticmethod
|
||||
def get_fw_index(node: Node, idx: int):
|
||||
def get_fw_index(node: Node, idx: int) -> int:
|
||||
if not node.has_valid('rt_info'):
|
||||
return idx
|
||||
|
||||
@ -94,7 +94,7 @@ class InsertReverseChannels(BackReplacementPattern):
|
||||
extra={'is_warning': True})
|
||||
|
||||
for name, parameter, _ in suitable_params:
|
||||
reverse_index = self.get_fw_index(parameter, 1)
|
||||
reverse_index = int64_array(self.get_fw_index(parameter, 1))
|
||||
|
||||
if parameter.out_port(0).disconnected():
|
||||
continue
|
||||
@ -137,6 +137,7 @@ class ReverseChannelsPropagationDown(BackReplacementPattern):
|
||||
return False
|
||||
order = node.in_port(1).data.get_value()
|
||||
reverse_axis = reverse_channels.axis
|
||||
|
||||
data_rank = len(list(node.in_port(0).data.get_shape()))
|
||||
|
||||
if reverse_axis < 0:
|
||||
@ -146,7 +147,7 @@ class ReverseChannelsPropagationDown(BackReplacementPattern):
|
||||
if order is None:
|
||||
return False
|
||||
new_axis = list(order).index(reverse_axis)
|
||||
reverse_channels.axis = new_axis
|
||||
reverse_channels.axis = int64_array(new_axis)
|
||||
return ReverseChannelsPropagationDown.pass_rc_through_zero_port_only(node, reverse_channels)
|
||||
|
||||
@staticmethod
|
||||
@ -356,6 +357,7 @@ class ReverseChannelsPropagationUp(BackReplacementPattern):
|
||||
return False
|
||||
order = node.in_port(1).data.get_value()
|
||||
reverse_axis = reverse_channels.axis
|
||||
|
||||
data_rank = len(list(node.in_port(0).data.get_shape()))
|
||||
|
||||
if reverse_axis < 0:
|
||||
@ -365,7 +367,7 @@ class ReverseChannelsPropagationUp(BackReplacementPattern):
|
||||
if order is None:
|
||||
return False
|
||||
new_axis = order[reverse_axis]
|
||||
reverse_channels.axis = new_axis
|
||||
reverse_channels.axis = int64_array(new_axis)
|
||||
return ReverseChannelsPropagationUp.lift_up_through_zero_port_only(node, reverse_channels)
|
||||
|
||||
@staticmethod
|
||||
|
@ -3,6 +3,8 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.back.ReverseInputChannels import ReverseChannelsPropagationUp, ReverseChannelsPropagationDown, \
|
||||
InsertReverseChannels
|
||||
from mo.front.common.partial_infer.utils import int64_array, float32_array
|
||||
@ -16,15 +18,14 @@ nodes = {
|
||||
**regular_op_with_shaped_data('placeholder2', [1, 1, 1, 1], {'type': 'Parameter'}),
|
||||
|
||||
**regular_op_with_shaped_data('mul', [1, 3, 10, 10], {'type': 'Multiply'}),
|
||||
**regular_op_with_shaped_data('reverse_channels', [1, 3, 10, 10], {'type': 'ReverseChannels', 'axis': 1}),
|
||||
|
||||
**regular_op_with_shaped_data('reverse_channels', [1, 3, 10, 10],
|
||||
{'type': 'ReverseChannels', 'axis': int64_array(1)}),
|
||||
|
||||
**regular_op_with_shaped_data('pad', [1, 3, 10, 10], {'type': 'Pad'}),
|
||||
|
||||
**result('result'),
|
||||
}
|
||||
|
||||
|
||||
nodes2 = {
|
||||
**regular_op_with_shaped_data('placeholder', [1, 3, 10, 10], {'type': 'Parameter'}),
|
||||
|
||||
@ -33,7 +34,8 @@ nodes2 = {
|
||||
**valued_const_with_data('pad_const_1', int64_array([0, 0, 0, 0])),
|
||||
**valued_const_with_data('pad_const_2', int64_array([0, 0, 1, 1])),
|
||||
**regular_op_with_shaped_data('pad', [1, 3, 10, 10], {'type': 'Pad'}),
|
||||
**regular_op_with_shaped_data('reverse_channels', [1, 3, 10, 10], {'type': 'ReverseChannels', 'axis': 1}),
|
||||
**regular_op_with_shaped_data('reverse_channels', [1, 3, 10, 10],
|
||||
{'type': 'ReverseChannels', 'axis': int64_array(1)}),
|
||||
**result('result'),
|
||||
**result('result2'),
|
||||
}
|
||||
@ -42,8 +44,10 @@ nodes3 = {
|
||||
**regular_op_with_shaped_data('placeholder', [1, 3, 10, 10], {'type': 'Parameter'}),
|
||||
**regular_op_with_shaped_data('transpose', [1, 3, 10, 10], {'type': 'Transpose'}),
|
||||
**valued_const_with_data('transpose_order', int64_array([0, 3, 1, 2])),
|
||||
**regular_op_with_shaped_data('reverse_channels_up', [1, 3, 10, 10], {'type': 'ReverseChannels', 'axis': 3}),
|
||||
**regular_op_with_shaped_data('reverse_channels_down', [1, 3, 10, 10], {'type': 'ReverseChannels', 'axis': 1}),
|
||||
**regular_op_with_shaped_data('reverse_channels_up', [1, 3, 10, 10],
|
||||
{'type': 'ReverseChannels', 'axis': int64_array(3)}),
|
||||
**regular_op_with_shaped_data('reverse_channels_down', [1, 3, 10, 10],
|
||||
{'type': 'ReverseChannels', 'axis': int64_array(1)}),
|
||||
**result('result'),
|
||||
**result('result2'),
|
||||
}
|
||||
@ -89,7 +93,8 @@ class ReverseInputChannelsTest(unittest.TestCase):
|
||||
node = Node(graph, 'pad')
|
||||
reverse_channels = Node(graph, 'reverse_channels')
|
||||
|
||||
keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_zero_port_only(node, reverse_channels)
|
||||
keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_zero_port_only(node,
|
||||
reverse_channels)
|
||||
self.assertTrue(keep_moving_up is True)
|
||||
self.assertTrue(len(new_reverses) == 1)
|
||||
self.check_graph_attrs(graph, ['placeholder'])
|
||||
@ -98,20 +103,22 @@ class ReverseInputChannelsTest(unittest.TestCase):
|
||||
graph = build_graph(nodes2, [*connect('placeholder', '0:mul'), *connect('mul_const', '1:mul'),
|
||||
*connect('mul', '0:pad'), *connect('pad_const_1', '1:pad'),
|
||||
*connect('pad_const_2', '2:pad'), *connect('pad', 'reverse_channels'),
|
||||
*connect('reverse_channels:0', '0:result'), *connect('reverse_channels:0', '0:result2')])
|
||||
*connect('reverse_channels:0', '0:result'),
|
||||
*connect('reverse_channels:0', '0:result2')])
|
||||
self.set_graph_attrs(graph, ['placeholder'])
|
||||
|
||||
node = Node(graph, 'pad')
|
||||
reverse_channels = Node(graph, 'reverse_channels')
|
||||
|
||||
keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_zero_port_only(node, reverse_channels)
|
||||
keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_zero_port_only(node,
|
||||
reverse_channels)
|
||||
self.assertTrue(keep_moving_up is True)
|
||||
self.assertTrue(len(new_reverses) == 1)
|
||||
self.check_graph_attrs(graph, ['placeholder'])
|
||||
|
||||
def test_pass_rc_through(self):
|
||||
graph = build_graph(nodes2, [*connect('placeholder', '0:mul'), *connect('mul_const', '1:mul'),
|
||||
*connect('mul', 'reverse_channels'), *connect('reverse_channels', '0:pad'),
|
||||
*connect('mul', 'reverse_channels'), *connect('reverse_channels', '0:pad'),
|
||||
*connect('pad_const_1', '1:pad'), *connect('pad_const_2', '2:pad'),
|
||||
*connect('pad', 'result')])
|
||||
self.set_graph_attrs(graph, ['placeholder'])
|
||||
@ -144,6 +151,7 @@ class ReverseInputChannelsTest(unittest.TestCase):
|
||||
|
||||
reverse_channels = Node(graph, 'reverse_channels_down')
|
||||
self.assertTrue(reverse_channels.axis == 3)
|
||||
self.assertTrue(type(reverse_channels.axis) == np.ndarray)
|
||||
|
||||
def test_lift_down_through_transpose(self):
|
||||
graph = build_graph(nodes3, [*connect('placeholder', 'reverse_channels_up'),
|
||||
@ -168,6 +176,7 @@ class ReverseInputChannelsTest(unittest.TestCase):
|
||||
|
||||
reverse_channels = Node(graph, 'reverse_channels_down')
|
||||
self.assertTrue(reverse_channels.axis == 1)
|
||||
self.assertTrue(type(reverse_channels.axis) == np.ndarray)
|
||||
|
||||
def test_lift_up_through_transpose_negative_axis(self):
|
||||
graph = build_graph(nodes3, [*connect('placeholder', '0:transpose'), *connect('transpose_order', '1:transpose'),
|
||||
@ -181,7 +190,7 @@ class ReverseInputChannelsTest(unittest.TestCase):
|
||||
|
||||
node = Node(graph, 'transpose')
|
||||
reverse_channels = Node(graph, 'reverse_channels_down')
|
||||
reverse_channels.axis = -3
|
||||
reverse_channels.axis = int64_array(-3)
|
||||
|
||||
keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_transpose(node, reverse_channels)
|
||||
self.assertTrue(keep_moving_up is True)
|
||||
@ -192,6 +201,7 @@ class ReverseInputChannelsTest(unittest.TestCase):
|
||||
|
||||
reverse_channels = Node(graph, 'reverse_channels_down')
|
||||
self.assertTrue(reverse_channels.axis == 3)
|
||||
self.assertTrue(type(reverse_channels.axis) == np.ndarray)
|
||||
|
||||
def test_lift_down_through_transpose_negative_axis(self):
|
||||
graph = build_graph(nodes3, [*connect('placeholder', 'reverse_channels_up'),
|
||||
@ -206,7 +216,7 @@ class ReverseInputChannelsTest(unittest.TestCase):
|
||||
|
||||
node = Node(graph, 'transpose')
|
||||
reverse_channels = Node(graph, 'reverse_channels_up')
|
||||
reverse_channels.axis = -1
|
||||
reverse_channels.axis = int64_array(-1)
|
||||
|
||||
keep_moving_down = ReverseChannelsPropagationDown.pass_rc_through_transpose(node, reverse_channels)
|
||||
|
||||
@ -217,6 +227,7 @@ class ReverseInputChannelsTest(unittest.TestCase):
|
||||
|
||||
reverse_channels = Node(graph, 'reverse_channels_down')
|
||||
self.assertTrue(reverse_channels.axis == 1)
|
||||
self.assertTrue(type(reverse_channels.axis) == np.ndarray)
|
||||
|
||||
def test_get_fw_index(self):
|
||||
graph = build_graph(nodes, [*connect('placeholder1', 'result')])
|
||||
@ -229,4 +240,4 @@ class ReverseInputChannelsTest(unittest.TestCase):
|
||||
self.assertTrue(InsertReverseChannels.get_fw_index(node, 2) == 1)
|
||||
self.assertTrue(InsertReverseChannels.get_fw_index(node, 3) == 2)
|
||||
self.assertTrue(InsertReverseChannels.get_fw_index(node, -2) == 1)
|
||||
|
||||
self.assertTrue(type(InsertReverseChannels.get_fw_index(node, 0)) == int)
|
||||
|
Loading…
Reference in New Issue
Block a user