Fix of ReverseInputChannels for NHWC layout. (#8504)

* Fixed ReverseInputChannels for new api.

* Renamed function, added node name in assert, code refactoring.

* Added ReverseChannels propagation through transposes.

* Refactored get_fw_index() method.

* Fixed wrong name.

* Added negative axis support.
This commit is contained in:
Anastasia Popova 2021-11-15 15:24:46 +03:00 committed by GitHub
parent 925e24525a
commit f800993e6f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 231 additions and 20 deletions

View File

@ -50,10 +50,38 @@ class InsertReverseChannels(BackReplacementPattern):
"""
enabled = False
@staticmethod
def get_fw_index(node: Node, idx: int):
if not node.has_valid('rt_info'):
return idx
rt_info = node.rt_info
if not rt_info.contains('old_api_map'):
return idx
old_api_map_version = rt_info.get_attribute_version('old_api_map')
old_api_map = rt_info.info['old_api_map', old_api_map_version]
if 'inverse_order' not in old_api_map.info:
return idx
order = old_api_map.info['inverse_order']
node_name = node.soft_get('name', node.id)
if idx < 0:
assert not node.out_port(0).disconnected(), 'Cannot normalize negative axis {} in node {} ' \
'as out port is disconnected.'.format(idx, node_name)
data_rank = len(list(node.out_port(0).data.get_shape()))
idx = data_rank + idx
assert len(order) > idx >= 0, \
'Channel index {} is incompatible with old_api_map in node {}.'.format(idx, node_name)
return list(order).index(idx)
def find_and_replace_pattern(self, graph: Graph):
all_params = [(p.soft_get('name', p.id), p, list(p.out_port(0).data.get_shape()))
for p in graph.get_op_nodes(type='Parameter')]
suitable_params = [(name, p, shape) for name, p, shape in all_params if len(shape) == 4 and shape[1] == 3]
suitable_params = [(name, p, shape) for name, p, shape in all_params if
len(shape) == 4 and shape[self.get_fw_index(p, 1)] == 3]
log.debug('All network inputs: {}'.format({name: shape for name, _, shape in all_params}))
log.debug('Will reverse input channels for: {}'.format({name: shape for name, _, shape in suitable_params}))
@ -66,10 +94,14 @@ class InsertReverseChannels(BackReplacementPattern):
extra={'is_warning': True})
for name, parameter, _ in suitable_params:
reverse_channels = ReverseChannels(graph, {'name': name + '/reverse_input_channels'}).create_node()
parameter.out_port(0).get_connection().set_source(reverse_channels.out_port(0),
attributes_save_mode='source')
parameter.out_port(0).connect(reverse_channels.in_port(0))
reverse_index = self.get_fw_index(parameter, 1)
if parameter.out_port(0).disconnected():
continue
reverse_channels = ReverseChannels(graph, {'name': name + '/reverse_input_channels',
'axis': reverse_index}).create_node()
parameter.out_port(0).get_connection().insert_node(reverse_channels, attributes_save_mode='source')
class ReverseChannelsPropagationDown(BackReplacementPattern):
@ -95,11 +127,30 @@ class ReverseChannelsPropagationDown(BackReplacementPattern):
'Shape': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_shape(node, rc),
'ShapeOf': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_shape(node, rc),
'Pad': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through(node, rc),
'Pad': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_zero_port_only(node, rc),
'Transpose': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_transpose(node, rc),
}
@staticmethod
def pass_rc_through(node: Node, reverse_channels: Node):
def pass_rc_through_transpose(node: Node, reverse_channels: Node):
if node.in_port(1).disconnected() or node.in_port(0).disconnected():
return False
order = node.in_port(1).data.get_value()
reverse_axis = reverse_channels.axis
data_rank = len(list(node.in_port(0).data.get_shape()))
if reverse_axis < 0:
reverse_axis = data_rank + reverse_axis
assert 0 < reverse_axis < data_rank, "Incorrect ReverseChannels axis in node {}.".format(reverse_channels)
if order is None:
return False
new_axis = list(order).index(reverse_axis)
reverse_channels.axis = new_axis
return ReverseChannelsPropagationDown.pass_rc_through_zero_port_only(node, reverse_channels)
@staticmethod
def pass_rc_through_zero_port_only(node: Node, reverse_channels: Node):
r"""
BEFORE AFTER
@ -295,11 +346,30 @@ class ReverseChannelsPropagationUp(BackReplacementPattern):
'Subtract': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
'Pow': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
'Convert': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
'Pad': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_pad(node, rc),
'Pad': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_zero_port_only(node, rc),
'Transpose': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_transpose(node, rc),
}
@staticmethod
def lift_up_through_pad(node: Node, reverse_channels: Node):
def lift_up_through_transpose(node: Node, reverse_channels: Node):
if node.in_port(1).disconnected() or node.in_port(0).disconnected():
return False
order = node.in_port(1).data.get_value()
reverse_axis = reverse_channels.axis
data_rank = len(list(node.in_port(0).data.get_shape()))
if reverse_axis < 0:
reverse_axis = data_rank + reverse_axis
assert 0 < reverse_axis < data_rank, "Incorrect ReverseChannels axis in node {}.".format(reverse_channels)
if order is None:
return False
new_axis = order[reverse_axis]
reverse_channels.axis = new_axis
return ReverseChannelsPropagationUp.lift_up_through_zero_port_only(node, reverse_channels)
@staticmethod
def lift_up_through_zero_port_only(node: Node, reverse_channels: Node):
r"""
BEFORE AFTER
@ -307,7 +377,7 @@ class ReverseChannelsPropagationUp(BackReplacementPattern):
\
previous_op previous_op ReverseChannels previous_op
\ / \ /
Pad Pad
Node Node
| |
ReverseChannels next_op
|
@ -323,7 +393,16 @@ class ReverseChannelsPropagationUp(BackReplacementPattern):
reverse_channels.out_port(0).disconnect()
reverse_channels.in_port(0).disconnect()
src = node_input_port_0.get_connection().get_source()
if src.node.soft_get('type') == 'Parameter':
# For Parameter nodes tensor debug attributes should not move to the last node
# of subgraph. It is needed for the proper mapping of input framework name.
# For this reason "source" mode is used to keep tensor debug attributes at Parameter node.
node_input_port_0.get_connection().set_source(reverse_channels.out_port(0),
attributes_save_mode="source")
else:
node_input_port_0.get_connection().set_source(reverse_channels.out_port(0))
src.connect(reverse_channels.in_port(0))
for reverse_channels_destination in reverse_channels_out_nodes:
node.out_port(0).get_connection().add_destination(reverse_channels_destination)

View File

@ -29,6 +29,17 @@ class RTInfo:
"""
self.info = defaultdict(dict)
def contains(self, attribute_name: str):
attr_count = [key[0] for key in list(self.info.keys())].count(attribute_name)
assert attr_count <= 1, 'Incorrect rt_info attribute, got more than one {}.'.format(attribute_name)
return attr_count > 0
def get_attribute_version(self, attribute_name: str):
for name, version in list(self.info.keys()):
if name == attribute_name:
return version
raise Exception("rt_info does not contain attribute with name {}".format(attribute_name))
class RTInfoElement:
"""

View File

@ -3,13 +3,16 @@
import unittest
from extensions.back.ReverseInputChannels import ReverseChannelsPropagationUp, ReverseChannelsPropagationDown
from mo.graph.graph import Node, Graph
from unit_tests.utils.graph import build_graph, result, connect, regular_op_with_shaped_data, valued_const_with_data
from extensions.back.ReverseInputChannels import ReverseChannelsPropagationUp, ReverseChannelsPropagationDown, \
InsertReverseChannels
from mo.front.common.partial_infer.utils import int64_array, float32_array
from mo.graph.graph import Node, Graph
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.runtime_info import OldAPIMap, RTInfo
from unit_tests.utils.graph import build_graph, result, connect, regular_op_with_shaped_data, valued_const_with_data
nodes = {
**regular_op_with_shaped_data('placeholder1', [1, 3, 10, 10], {'type': 'Parameter'}),
**regular_op_with_shaped_data('placeholder1', [1, 3, 10, 10], {'type': 'Parameter', 'rt_info': RTInfo()}),
**regular_op_with_shaped_data('placeholder2', [1, 1, 1, 1], {'type': 'Parameter'}),
**regular_op_with_shaped_data('mul', [1, 3, 10, 10], {'type': 'Multiply'}),
@ -35,6 +38,17 @@ nodes2 = {
**result('result2'),
}
nodes3 = {
**regular_op_with_shaped_data('placeholder', [1, 3, 10, 10], {'type': 'Parameter'}),
**regular_op_with_shaped_data('transpose', [1, 3, 10, 10], {'type': 'Transpose'}),
**valued_const_with_data('transpose_order', int64_array([0, 3, 1, 2])),
**regular_op_with_shaped_data('reverse_channels_up', [1, 3, 10, 10], {'type': 'ReverseChannels', 'axis': 3}),
**regular_op_with_shaped_data('reverse_channels_down', [1, 3, 10, 10], {'type': 'ReverseChannels', 'axis': 1}),
**result('result'),
**result('result2'),
}
class ReverseInputChannelsTest(unittest.TestCase):
def check_graph_attrs(self, graph: Graph, parameter_node_names: list):
for node in graph.get_op_nodes():
@ -75,12 +89,11 @@ class ReverseInputChannelsTest(unittest.TestCase):
node = Node(graph, 'pad')
reverse_channels = Node(graph, 'reverse_channels')
keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_pad(node, reverse_channels)
keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_zero_port_only(node, reverse_channels)
self.assertTrue(keep_moving_up is True)
self.assertTrue(len(new_reverses) == 1)
self.check_graph_attrs(graph, ['placeholder'])
def test_lift_up_through_pad2(self):
graph = build_graph(nodes2, [*connect('placeholder', '0:mul'), *connect('mul_const', '1:mul'),
*connect('mul', '0:pad'), *connect('pad_const_1', '1:pad'),
@ -91,12 +104,11 @@ class ReverseInputChannelsTest(unittest.TestCase):
node = Node(graph, 'pad')
reverse_channels = Node(graph, 'reverse_channels')
keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_pad(node, reverse_channels)
keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_zero_port_only(node, reverse_channels)
self.assertTrue(keep_moving_up is True)
self.assertTrue(len(new_reverses) == 1)
self.check_graph_attrs(graph, ['placeholder'])
def test_pass_rc_through(self):
graph = build_graph(nodes2, [*connect('placeholder', '0:mul'), *connect('mul_const', '1:mul'),
*connect('mul', 'reverse_channels'), *connect('reverse_channels', '0:pad'),
@ -107,5 +119,114 @@ class ReverseInputChannelsTest(unittest.TestCase):
node = Node(graph, 'pad')
reverse_channels = Node(graph, 'reverse_channels')
ReverseChannelsPropagationDown.pass_rc_through(node, reverse_channels)
ReverseChannelsPropagationDown.pass_rc_through_zero_port_only(node, reverse_channels)
self.check_graph_attrs(graph, ['placeholder'])
def test_lift_up_through_transpose(self):
graph = build_graph(nodes3, [*connect('placeholder', '0:transpose'), *connect('transpose_order', '1:transpose'),
*connect('transpose', 'reverse_channels_down'),
*connect('reverse_channels_down', 'result')])
graph_ref = build_graph(nodes3, [*connect('placeholder', 'reverse_channels_down'),
*connect('transpose_order', '1:transpose'),
*connect('reverse_channels_down', 'transpose'),
*connect('transpose', 'result')])
self.set_graph_attrs(graph, ['placeholder'])
node = Node(graph, 'transpose')
reverse_channels = Node(graph, 'reverse_channels_down')
keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_transpose(node, reverse_channels)
self.assertTrue(keep_moving_up is True)
self.assertTrue(len(new_reverses) == 1)
self.check_graph_attrs(graph, ['placeholder'])
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)
reverse_channels = Node(graph, 'reverse_channels_down')
self.assertTrue(reverse_channels.axis == 3)
def test_lift_down_through_transpose(self):
graph = build_graph(nodes3, [*connect('placeholder', 'reverse_channels_up'),
*connect('transpose_order', '1:transpose'),
*connect('reverse_channels_up', '0:transpose'),
*connect('transpose', 'result')])
graph_ref = build_graph(nodes3, [*connect('placeholder', '0:transpose'),
*connect('transpose_order', '1:transpose'),
*connect('transpose', 'reverse_channels_up'),
*connect('reverse_channels_up', '0:result')])
self.set_graph_attrs(graph, ['placeholder'])
node = Node(graph, 'transpose')
reverse_channels = Node(graph, 'reverse_channels_up')
keep_moving_down = ReverseChannelsPropagationDown.pass_rc_through_transpose(node, reverse_channels)
self.assertTrue(keep_moving_down is True)
self.check_graph_attrs(graph, ['placeholder'])
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)
reverse_channels = Node(graph, 'reverse_channels_down')
self.assertTrue(reverse_channels.axis == 1)
def test_lift_up_through_transpose_negative_axis(self):
graph = build_graph(nodes3, [*connect('placeholder', '0:transpose'), *connect('transpose_order', '1:transpose'),
*connect('transpose', 'reverse_channels_down'),
*connect('reverse_channels_down', 'result')])
graph_ref = build_graph(nodes3, [*connect('placeholder', 'reverse_channels_down'),
*connect('transpose_order', '1:transpose'),
*connect('reverse_channels_down', 'transpose'),
*connect('transpose', 'result')])
self.set_graph_attrs(graph, ['placeholder'])
node = Node(graph, 'transpose')
reverse_channels = Node(graph, 'reverse_channels_down')
reverse_channels.axis = -3
keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_transpose(node, reverse_channels)
self.assertTrue(keep_moving_up is True)
self.assertTrue(len(new_reverses) == 1)
self.check_graph_attrs(graph, ['placeholder'])
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)
reverse_channels = Node(graph, 'reverse_channels_down')
self.assertTrue(reverse_channels.axis == 3)
def test_lift_down_through_transpose_negative_axis(self):
graph = build_graph(nodes3, [*connect('placeholder', 'reverse_channels_up'),
*connect('transpose_order', '1:transpose'),
*connect('reverse_channels_up', '0:transpose'),
*connect('transpose', 'result')])
graph_ref = build_graph(nodes3, [*connect('placeholder', '0:transpose'),
*connect('transpose_order', '1:transpose'),
*connect('transpose', 'reverse_channels_up'),
*connect('reverse_channels_up', '0:result')])
self.set_graph_attrs(graph, ['placeholder'])
node = Node(graph, 'transpose')
reverse_channels = Node(graph, 'reverse_channels_up')
reverse_channels.axis = -1
keep_moving_down = ReverseChannelsPropagationDown.pass_rc_through_transpose(node, reverse_channels)
self.assertTrue(keep_moving_down is True)
self.check_graph_attrs(graph, ['placeholder'])
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)
reverse_channels = Node(graph, 'reverse_channels_down')
self.assertTrue(reverse_channels.axis == 1)
def test_get_fw_index(self):
graph = build_graph(nodes, [*connect('placeholder1', 'result')])
node = Node(graph, 'placeholder1')
old_api_map = OldAPIMap(version=0)
node.rt_info.info[('old_api_map', old_api_map.get_version())] = old_api_map
node.rt_info.info[('old_api_map', old_api_map.get_version())].old_api_transpose_parameter([0, 2, 3, 1])
self.assertTrue(InsertReverseChannels.get_fw_index(node, 0) == 0)
self.assertTrue(InsertReverseChannels.get_fw_index(node, 1) == 3)
self.assertTrue(InsertReverseChannels.get_fw_index(node, 2) == 1)
self.assertTrue(InsertReverseChannels.get_fw_index(node, 3) == 2)
self.assertTrue(InsertReverseChannels.get_fw_index(node, -2) == 1)