Fixed axis type in ReverseInputChannels. (#8708)

* Fixed axis type.

* Added checks, corrected get_fw_index.

* Corrected get_fw_index(), removed checks.
This commit is contained in:
Anastasia Popova 2021-11-23 15:22:42 +03:00 committed by GitHub
parent 8c55c761c4
commit 5104bf8ce8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 17 deletions

View File

@ -51,7 +51,7 @@ class InsertReverseChannels(BackReplacementPattern):
enabled = False
@staticmethod
def get_fw_index(node: Node, idx: int):
def get_fw_index(node: Node, idx: int) -> int:
if not node.has_valid('rt_info'):
return idx
@ -94,7 +94,7 @@ class InsertReverseChannels(BackReplacementPattern):
extra={'is_warning': True})
for name, parameter, _ in suitable_params:
reverse_index = self.get_fw_index(parameter, 1)
reverse_index = int64_array(self.get_fw_index(parameter, 1))
if parameter.out_port(0).disconnected():
continue
@ -137,6 +137,7 @@ class ReverseChannelsPropagationDown(BackReplacementPattern):
return False
order = node.in_port(1).data.get_value()
reverse_axis = reverse_channels.axis
data_rank = len(list(node.in_port(0).data.get_shape()))
if reverse_axis < 0:
@ -146,7 +147,7 @@ class ReverseChannelsPropagationDown(BackReplacementPattern):
if order is None:
return False
new_axis = list(order).index(reverse_axis)
reverse_channels.axis = new_axis
reverse_channels.axis = int64_array(new_axis)
return ReverseChannelsPropagationDown.pass_rc_through_zero_port_only(node, reverse_channels)
@staticmethod
@ -356,6 +357,7 @@ class ReverseChannelsPropagationUp(BackReplacementPattern):
return False
order = node.in_port(1).data.get_value()
reverse_axis = reverse_channels.axis
data_rank = len(list(node.in_port(0).data.get_shape()))
if reverse_axis < 0:
@ -365,7 +367,7 @@ class ReverseChannelsPropagationUp(BackReplacementPattern):
if order is None:
return False
new_axis = order[reverse_axis]
reverse_channels.axis = new_axis
reverse_channels.axis = int64_array(new_axis)
return ReverseChannelsPropagationUp.lift_up_through_zero_port_only(node, reverse_channels)
@staticmethod

View File

@ -3,6 +3,8 @@
import unittest
import numpy as np
from extensions.back.ReverseInputChannels import ReverseChannelsPropagationUp, ReverseChannelsPropagationDown, \
InsertReverseChannels
from mo.front.common.partial_infer.utils import int64_array, float32_array
@ -16,15 +18,14 @@ nodes = {
**regular_op_with_shaped_data('placeholder2', [1, 1, 1, 1], {'type': 'Parameter'}),
**regular_op_with_shaped_data('mul', [1, 3, 10, 10], {'type': 'Multiply'}),
**regular_op_with_shaped_data('reverse_channels', [1, 3, 10, 10], {'type': 'ReverseChannels', 'axis': 1}),
**regular_op_with_shaped_data('reverse_channels', [1, 3, 10, 10],
{'type': 'ReverseChannels', 'axis': int64_array(1)}),
**regular_op_with_shaped_data('pad', [1, 3, 10, 10], {'type': 'Pad'}),
**result('result'),
}
nodes2 = {
**regular_op_with_shaped_data('placeholder', [1, 3, 10, 10], {'type': 'Parameter'}),
@ -33,7 +34,8 @@ nodes2 = {
**valued_const_with_data('pad_const_1', int64_array([0, 0, 0, 0])),
**valued_const_with_data('pad_const_2', int64_array([0, 0, 1, 1])),
**regular_op_with_shaped_data('pad', [1, 3, 10, 10], {'type': 'Pad'}),
**regular_op_with_shaped_data('reverse_channels', [1, 3, 10, 10], {'type': 'ReverseChannels', 'axis': 1}),
**regular_op_with_shaped_data('reverse_channels', [1, 3, 10, 10],
{'type': 'ReverseChannels', 'axis': int64_array(1)}),
**result('result'),
**result('result2'),
}
@ -42,8 +44,10 @@ nodes3 = {
**regular_op_with_shaped_data('placeholder', [1, 3, 10, 10], {'type': 'Parameter'}),
**regular_op_with_shaped_data('transpose', [1, 3, 10, 10], {'type': 'Transpose'}),
**valued_const_with_data('transpose_order', int64_array([0, 3, 1, 2])),
**regular_op_with_shaped_data('reverse_channels_up', [1, 3, 10, 10], {'type': 'ReverseChannels', 'axis': 3}),
**regular_op_with_shaped_data('reverse_channels_down', [1, 3, 10, 10], {'type': 'ReverseChannels', 'axis': 1}),
**regular_op_with_shaped_data('reverse_channels_up', [1, 3, 10, 10],
{'type': 'ReverseChannels', 'axis': int64_array(3)}),
**regular_op_with_shaped_data('reverse_channels_down', [1, 3, 10, 10],
{'type': 'ReverseChannels', 'axis': int64_array(1)}),
**result('result'),
**result('result2'),
}
@ -89,7 +93,8 @@ class ReverseInputChannelsTest(unittest.TestCase):
node = Node(graph, 'pad')
reverse_channels = Node(graph, 'reverse_channels')
keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_zero_port_only(node, reverse_channels)
keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_zero_port_only(node,
reverse_channels)
self.assertTrue(keep_moving_up is True)
self.assertTrue(len(new_reverses) == 1)
self.check_graph_attrs(graph, ['placeholder'])
@ -98,20 +103,22 @@ class ReverseInputChannelsTest(unittest.TestCase):
graph = build_graph(nodes2, [*connect('placeholder', '0:mul'), *connect('mul_const', '1:mul'),
*connect('mul', '0:pad'), *connect('pad_const_1', '1:pad'),
*connect('pad_const_2', '2:pad'), *connect('pad', 'reverse_channels'),
*connect('reverse_channels:0', '0:result'), *connect('reverse_channels:0', '0:result2')])
*connect('reverse_channels:0', '0:result'),
*connect('reverse_channels:0', '0:result2')])
self.set_graph_attrs(graph, ['placeholder'])
node = Node(graph, 'pad')
reverse_channels = Node(graph, 'reverse_channels')
keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_zero_port_only(node, reverse_channels)
keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_zero_port_only(node,
reverse_channels)
self.assertTrue(keep_moving_up is True)
self.assertTrue(len(new_reverses) == 1)
self.check_graph_attrs(graph, ['placeholder'])
def test_pass_rc_through(self):
graph = build_graph(nodes2, [*connect('placeholder', '0:mul'), *connect('mul_const', '1:mul'),
*connect('mul', 'reverse_channels'), *connect('reverse_channels', '0:pad'),
*connect('mul', 'reverse_channels'), *connect('reverse_channels', '0:pad'),
*connect('pad_const_1', '1:pad'), *connect('pad_const_2', '2:pad'),
*connect('pad', 'result')])
self.set_graph_attrs(graph, ['placeholder'])
@ -144,6 +151,7 @@ class ReverseInputChannelsTest(unittest.TestCase):
reverse_channels = Node(graph, 'reverse_channels_down')
self.assertTrue(reverse_channels.axis == 3)
self.assertTrue(type(reverse_channels.axis) == np.ndarray)
def test_lift_down_through_transpose(self):
graph = build_graph(nodes3, [*connect('placeholder', 'reverse_channels_up'),
@ -168,6 +176,7 @@ class ReverseInputChannelsTest(unittest.TestCase):
reverse_channels = Node(graph, 'reverse_channels_down')
self.assertTrue(reverse_channels.axis == 1)
self.assertTrue(type(reverse_channels.axis) == np.ndarray)
def test_lift_up_through_transpose_negative_axis(self):
graph = build_graph(nodes3, [*connect('placeholder', '0:transpose'), *connect('transpose_order', '1:transpose'),
@ -181,7 +190,7 @@ class ReverseInputChannelsTest(unittest.TestCase):
node = Node(graph, 'transpose')
reverse_channels = Node(graph, 'reverse_channels_down')
reverse_channels.axis = -3
reverse_channels.axis = int64_array(-3)
keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_transpose(node, reverse_channels)
self.assertTrue(keep_moving_up is True)
@ -192,6 +201,7 @@ class ReverseInputChannelsTest(unittest.TestCase):
reverse_channels = Node(graph, 'reverse_channels_down')
self.assertTrue(reverse_channels.axis == 3)
self.assertTrue(type(reverse_channels.axis) == np.ndarray)
def test_lift_down_through_transpose_negative_axis(self):
graph = build_graph(nodes3, [*connect('placeholder', 'reverse_channels_up'),
@ -206,7 +216,7 @@ class ReverseInputChannelsTest(unittest.TestCase):
node = Node(graph, 'transpose')
reverse_channels = Node(graph, 'reverse_channels_up')
reverse_channels.axis = -1
reverse_channels.axis = int64_array(-1)
keep_moving_down = ReverseChannelsPropagationDown.pass_rc_through_transpose(node, reverse_channels)
@ -217,6 +227,7 @@ class ReverseInputChannelsTest(unittest.TestCase):
reverse_channels = Node(graph, 'reverse_channels_down')
self.assertTrue(reverse_channels.axis == 1)
self.assertTrue(type(reverse_channels.axis) == np.ndarray)
def test_get_fw_index(self):
graph = build_graph(nodes, [*connect('placeholder1', 'result')])
@ -229,4 +240,4 @@ class ReverseInputChannelsTest(unittest.TestCase):
self.assertTrue(InsertReverseChannels.get_fw_index(node, 2) == 1)
self.assertTrue(InsertReverseChannels.get_fw_index(node, 3) == 2)
self.assertTrue(InsertReverseChannels.get_fw_index(node, -2) == 1)
self.assertTrue(type(InsertReverseChannels.get_fw_index(node, 0)) == int)