Fix of ReverseInputChannels for NHWC layout. (#8504)
* Fixed ReverseInputChannels for new api. * Renamed function, added node name in assert, code refactoring. * Added ReverseChannels propagation through transposes. * Refactored get_fw_index() method. * Fixed wrong name. * Added negative axis support.
This commit is contained in:
parent
925e24525a
commit
f800993e6f
@ -50,10 +50,38 @@ class InsertReverseChannels(BackReplacementPattern):
|
||||
"""
|
||||
enabled = False
|
||||
|
||||
@staticmethod
|
||||
def get_fw_index(node: Node, idx: int):
|
||||
if not node.has_valid('rt_info'):
|
||||
return idx
|
||||
|
||||
rt_info = node.rt_info
|
||||
if not rt_info.contains('old_api_map'):
|
||||
return idx
|
||||
|
||||
old_api_map_version = rt_info.get_attribute_version('old_api_map')
|
||||
old_api_map = rt_info.info['old_api_map', old_api_map_version]
|
||||
if 'inverse_order' not in old_api_map.info:
|
||||
return idx
|
||||
|
||||
order = old_api_map.info['inverse_order']
|
||||
node_name = node.soft_get('name', node.id)
|
||||
|
||||
if idx < 0:
|
||||
assert not node.out_port(0).disconnected(), 'Cannot normalize negative axis {} in node {} ' \
|
||||
'as out port is disconnected.'.format(idx, node_name)
|
||||
data_rank = len(list(node.out_port(0).data.get_shape()))
|
||||
idx = data_rank + idx
|
||||
|
||||
assert len(order) > idx >= 0, \
|
||||
'Channel index {} is incompatible with old_api_map in node {}.'.format(idx, node_name)
|
||||
return list(order).index(idx)
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
all_params = [(p.soft_get('name', p.id), p, list(p.out_port(0).data.get_shape()))
|
||||
for p in graph.get_op_nodes(type='Parameter')]
|
||||
suitable_params = [(name, p, shape) for name, p, shape in all_params if len(shape) == 4 and shape[1] == 3]
|
||||
suitable_params = [(name, p, shape) for name, p, shape in all_params if
|
||||
len(shape) == 4 and shape[self.get_fw_index(p, 1)] == 3]
|
||||
|
||||
log.debug('All network inputs: {}'.format({name: shape for name, _, shape in all_params}))
|
||||
log.debug('Will reverse input channels for: {}'.format({name: shape for name, _, shape in suitable_params}))
|
||||
@ -66,10 +94,14 @@ class InsertReverseChannels(BackReplacementPattern):
|
||||
extra={'is_warning': True})
|
||||
|
||||
for name, parameter, _ in suitable_params:
|
||||
reverse_channels = ReverseChannels(graph, {'name': name + '/reverse_input_channels'}).create_node()
|
||||
parameter.out_port(0).get_connection().set_source(reverse_channels.out_port(0),
|
||||
attributes_save_mode='source')
|
||||
parameter.out_port(0).connect(reverse_channels.in_port(0))
|
||||
reverse_index = self.get_fw_index(parameter, 1)
|
||||
|
||||
if parameter.out_port(0).disconnected():
|
||||
continue
|
||||
|
||||
reverse_channels = ReverseChannels(graph, {'name': name + '/reverse_input_channels',
|
||||
'axis': reverse_index}).create_node()
|
||||
parameter.out_port(0).get_connection().insert_node(reverse_channels, attributes_save_mode='source')
|
||||
|
||||
|
||||
class ReverseChannelsPropagationDown(BackReplacementPattern):
|
||||
@ -95,11 +127,30 @@ class ReverseChannelsPropagationDown(BackReplacementPattern):
|
||||
'Shape': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_shape(node, rc),
|
||||
'ShapeOf': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_shape(node, rc),
|
||||
|
||||
'Pad': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through(node, rc),
|
||||
'Pad': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_zero_port_only(node, rc),
|
||||
'Transpose': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_transpose(node, rc),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def pass_rc_through(node: Node, reverse_channels: Node):
|
||||
def pass_rc_through_transpose(node: Node, reverse_channels: Node):
|
||||
if node.in_port(1).disconnected() or node.in_port(0).disconnected():
|
||||
return False
|
||||
order = node.in_port(1).data.get_value()
|
||||
reverse_axis = reverse_channels.axis
|
||||
data_rank = len(list(node.in_port(0).data.get_shape()))
|
||||
|
||||
if reverse_axis < 0:
|
||||
reverse_axis = data_rank + reverse_axis
|
||||
assert 0 < reverse_axis < data_rank, "Incorrect ReverseChannels axis in node {}.".format(reverse_channels)
|
||||
|
||||
if order is None:
|
||||
return False
|
||||
new_axis = list(order).index(reverse_axis)
|
||||
reverse_channels.axis = new_axis
|
||||
return ReverseChannelsPropagationDown.pass_rc_through_zero_port_only(node, reverse_channels)
|
||||
|
||||
@staticmethod
|
||||
def pass_rc_through_zero_port_only(node: Node, reverse_channels: Node):
|
||||
r"""
|
||||
BEFORE AFTER
|
||||
|
||||
@ -295,11 +346,30 @@ class ReverseChannelsPropagationUp(BackReplacementPattern):
|
||||
'Subtract': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
|
||||
'Pow': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
|
||||
'Convert': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
|
||||
'Pad': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_pad(node, rc),
|
||||
'Pad': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_zero_port_only(node, rc),
|
||||
'Transpose': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_transpose(node, rc),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def lift_up_through_pad(node: Node, reverse_channels: Node):
|
||||
def lift_up_through_transpose(node: Node, reverse_channels: Node):
|
||||
if node.in_port(1).disconnected() or node.in_port(0).disconnected():
|
||||
return False
|
||||
order = node.in_port(1).data.get_value()
|
||||
reverse_axis = reverse_channels.axis
|
||||
data_rank = len(list(node.in_port(0).data.get_shape()))
|
||||
|
||||
if reverse_axis < 0:
|
||||
reverse_axis = data_rank + reverse_axis
|
||||
assert 0 < reverse_axis < data_rank, "Incorrect ReverseChannels axis in node {}.".format(reverse_channels)
|
||||
|
||||
if order is None:
|
||||
return False
|
||||
new_axis = order[reverse_axis]
|
||||
reverse_channels.axis = new_axis
|
||||
return ReverseChannelsPropagationUp.lift_up_through_zero_port_only(node, reverse_channels)
|
||||
|
||||
@staticmethod
|
||||
def lift_up_through_zero_port_only(node: Node, reverse_channels: Node):
|
||||
r"""
|
||||
BEFORE AFTER
|
||||
|
||||
@ -307,7 +377,7 @@ class ReverseChannelsPropagationUp(BackReplacementPattern):
|
||||
\
|
||||
previous_op previous_op ReverseChannels previous_op
|
||||
\ / \ /
|
||||
Pad Pad
|
||||
Node Node
|
||||
| |
|
||||
ReverseChannels next_op
|
||||
|
|
||||
@ -323,7 +393,16 @@ class ReverseChannelsPropagationUp(BackReplacementPattern):
|
||||
reverse_channels.out_port(0).disconnect()
|
||||
reverse_channels.in_port(0).disconnect()
|
||||
src = node_input_port_0.get_connection().get_source()
|
||||
node_input_port_0.get_connection().set_source(reverse_channels.out_port(0))
|
||||
|
||||
if src.node.soft_get('type') == 'Parameter':
|
||||
# For Parameter nodes tensor debug attributes should not move to the last node
|
||||
# of subgraph. It is needed for the proper mapping of input framework name.
|
||||
# For this reason "source" mode is used to keep tensor debug attributes at Parameter node.
|
||||
node_input_port_0.get_connection().set_source(reverse_channels.out_port(0),
|
||||
attributes_save_mode="source")
|
||||
else:
|
||||
node_input_port_0.get_connection().set_source(reverse_channels.out_port(0))
|
||||
|
||||
src.connect(reverse_channels.in_port(0))
|
||||
for reverse_channels_destination in reverse_channels_out_nodes:
|
||||
node.out_port(0).get_connection().add_destination(reverse_channels_destination)
|
||||
|
@ -29,6 +29,17 @@ class RTInfo:
|
||||
"""
|
||||
self.info = defaultdict(dict)
|
||||
|
||||
def contains(self, attribute_name: str):
|
||||
attr_count = [key[0] for key in list(self.info.keys())].count(attribute_name)
|
||||
assert attr_count <= 1, 'Incorrect rt_info attribute, got more than one {}.'.format(attribute_name)
|
||||
return attr_count > 0
|
||||
|
||||
def get_attribute_version(self, attribute_name: str):
|
||||
for name, version in list(self.info.keys()):
|
||||
if name == attribute_name:
|
||||
return version
|
||||
raise Exception("rt_info does not contain attribute with name {}".format(attribute_name))
|
||||
|
||||
|
||||
class RTInfoElement:
|
||||
"""
|
||||
|
@ -3,13 +3,16 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from extensions.back.ReverseInputChannels import ReverseChannelsPropagationUp, ReverseChannelsPropagationDown
|
||||
from mo.graph.graph import Node, Graph
|
||||
from unit_tests.utils.graph import build_graph, result, connect, regular_op_with_shaped_data, valued_const_with_data
|
||||
from extensions.back.ReverseInputChannels import ReverseChannelsPropagationUp, ReverseChannelsPropagationDown, \
|
||||
InsertReverseChannels
|
||||
from mo.front.common.partial_infer.utils import int64_array, float32_array
|
||||
from mo.graph.graph import Node, Graph
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from mo.utils.runtime_info import OldAPIMap, RTInfo
|
||||
from unit_tests.utils.graph import build_graph, result, connect, regular_op_with_shaped_data, valued_const_with_data
|
||||
|
||||
nodes = {
|
||||
**regular_op_with_shaped_data('placeholder1', [1, 3, 10, 10], {'type': 'Parameter'}),
|
||||
**regular_op_with_shaped_data('placeholder1', [1, 3, 10, 10], {'type': 'Parameter', 'rt_info': RTInfo()}),
|
||||
**regular_op_with_shaped_data('placeholder2', [1, 1, 1, 1], {'type': 'Parameter'}),
|
||||
|
||||
**regular_op_with_shaped_data('mul', [1, 3, 10, 10], {'type': 'Multiply'}),
|
||||
@ -35,6 +38,17 @@ nodes2 = {
|
||||
**result('result2'),
|
||||
}
|
||||
|
||||
nodes3 = {
|
||||
**regular_op_with_shaped_data('placeholder', [1, 3, 10, 10], {'type': 'Parameter'}),
|
||||
**regular_op_with_shaped_data('transpose', [1, 3, 10, 10], {'type': 'Transpose'}),
|
||||
**valued_const_with_data('transpose_order', int64_array([0, 3, 1, 2])),
|
||||
**regular_op_with_shaped_data('reverse_channels_up', [1, 3, 10, 10], {'type': 'ReverseChannels', 'axis': 3}),
|
||||
**regular_op_with_shaped_data('reverse_channels_down', [1, 3, 10, 10], {'type': 'ReverseChannels', 'axis': 1}),
|
||||
**result('result'),
|
||||
**result('result2'),
|
||||
}
|
||||
|
||||
|
||||
class ReverseInputChannelsTest(unittest.TestCase):
|
||||
def check_graph_attrs(self, graph: Graph, parameter_node_names: list):
|
||||
for node in graph.get_op_nodes():
|
||||
@ -75,12 +89,11 @@ class ReverseInputChannelsTest(unittest.TestCase):
|
||||
node = Node(graph, 'pad')
|
||||
reverse_channels = Node(graph, 'reverse_channels')
|
||||
|
||||
keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_pad(node, reverse_channels)
|
||||
keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_zero_port_only(node, reverse_channels)
|
||||
self.assertTrue(keep_moving_up is True)
|
||||
self.assertTrue(len(new_reverses) == 1)
|
||||
self.check_graph_attrs(graph, ['placeholder'])
|
||||
|
||||
|
||||
def test_lift_up_through_pad2(self):
|
||||
graph = build_graph(nodes2, [*connect('placeholder', '0:mul'), *connect('mul_const', '1:mul'),
|
||||
*connect('mul', '0:pad'), *connect('pad_const_1', '1:pad'),
|
||||
@ -91,12 +104,11 @@ class ReverseInputChannelsTest(unittest.TestCase):
|
||||
node = Node(graph, 'pad')
|
||||
reverse_channels = Node(graph, 'reverse_channels')
|
||||
|
||||
keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_pad(node, reverse_channels)
|
||||
keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_zero_port_only(node, reverse_channels)
|
||||
self.assertTrue(keep_moving_up is True)
|
||||
self.assertTrue(len(new_reverses) == 1)
|
||||
self.check_graph_attrs(graph, ['placeholder'])
|
||||
|
||||
|
||||
def test_pass_rc_through(self):
|
||||
graph = build_graph(nodes2, [*connect('placeholder', '0:mul'), *connect('mul_const', '1:mul'),
|
||||
*connect('mul', 'reverse_channels'), *connect('reverse_channels', '0:pad'),
|
||||
@ -107,5 +119,114 @@ class ReverseInputChannelsTest(unittest.TestCase):
|
||||
node = Node(graph, 'pad')
|
||||
reverse_channels = Node(graph, 'reverse_channels')
|
||||
|
||||
ReverseChannelsPropagationDown.pass_rc_through(node, reverse_channels)
|
||||
ReverseChannelsPropagationDown.pass_rc_through_zero_port_only(node, reverse_channels)
|
||||
self.check_graph_attrs(graph, ['placeholder'])
|
||||
|
||||
def test_lift_up_through_transpose(self):
|
||||
graph = build_graph(nodes3, [*connect('placeholder', '0:transpose'), *connect('transpose_order', '1:transpose'),
|
||||
*connect('transpose', 'reverse_channels_down'),
|
||||
*connect('reverse_channels_down', 'result')])
|
||||
graph_ref = build_graph(nodes3, [*connect('placeholder', 'reverse_channels_down'),
|
||||
*connect('transpose_order', '1:transpose'),
|
||||
*connect('reverse_channels_down', 'transpose'),
|
||||
*connect('transpose', 'result')])
|
||||
self.set_graph_attrs(graph, ['placeholder'])
|
||||
|
||||
node = Node(graph, 'transpose')
|
||||
reverse_channels = Node(graph, 'reverse_channels_down')
|
||||
|
||||
keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_transpose(node, reverse_channels)
|
||||
self.assertTrue(keep_moving_up is True)
|
||||
self.assertTrue(len(new_reverses) == 1)
|
||||
self.check_graph_attrs(graph, ['placeholder'])
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
reverse_channels = Node(graph, 'reverse_channels_down')
|
||||
self.assertTrue(reverse_channels.axis == 3)
|
||||
|
||||
def test_lift_down_through_transpose(self):
|
||||
graph = build_graph(nodes3, [*connect('placeholder', 'reverse_channels_up'),
|
||||
*connect('transpose_order', '1:transpose'),
|
||||
*connect('reverse_channels_up', '0:transpose'),
|
||||
*connect('transpose', 'result')])
|
||||
graph_ref = build_graph(nodes3, [*connect('placeholder', '0:transpose'),
|
||||
*connect('transpose_order', '1:transpose'),
|
||||
*connect('transpose', 'reverse_channels_up'),
|
||||
*connect('reverse_channels_up', '0:result')])
|
||||
self.set_graph_attrs(graph, ['placeholder'])
|
||||
|
||||
node = Node(graph, 'transpose')
|
||||
reverse_channels = Node(graph, 'reverse_channels_up')
|
||||
|
||||
keep_moving_down = ReverseChannelsPropagationDown.pass_rc_through_transpose(node, reverse_channels)
|
||||
|
||||
self.assertTrue(keep_moving_down is True)
|
||||
self.check_graph_attrs(graph, ['placeholder'])
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
reverse_channels = Node(graph, 'reverse_channels_down')
|
||||
self.assertTrue(reverse_channels.axis == 1)
|
||||
|
||||
def test_lift_up_through_transpose_negative_axis(self):
|
||||
graph = build_graph(nodes3, [*connect('placeholder', '0:transpose'), *connect('transpose_order', '1:transpose'),
|
||||
*connect('transpose', 'reverse_channels_down'),
|
||||
*connect('reverse_channels_down', 'result')])
|
||||
graph_ref = build_graph(nodes3, [*connect('placeholder', 'reverse_channels_down'),
|
||||
*connect('transpose_order', '1:transpose'),
|
||||
*connect('reverse_channels_down', 'transpose'),
|
||||
*connect('transpose', 'result')])
|
||||
self.set_graph_attrs(graph, ['placeholder'])
|
||||
|
||||
node = Node(graph, 'transpose')
|
||||
reverse_channels = Node(graph, 'reverse_channels_down')
|
||||
reverse_channels.axis = -3
|
||||
|
||||
keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_transpose(node, reverse_channels)
|
||||
self.assertTrue(keep_moving_up is True)
|
||||
self.assertTrue(len(new_reverses) == 1)
|
||||
self.check_graph_attrs(graph, ['placeholder'])
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
reverse_channels = Node(graph, 'reverse_channels_down')
|
||||
self.assertTrue(reverse_channels.axis == 3)
|
||||
|
||||
def test_lift_down_through_transpose_negative_axis(self):
|
||||
graph = build_graph(nodes3, [*connect('placeholder', 'reverse_channels_up'),
|
||||
*connect('transpose_order', '1:transpose'),
|
||||
*connect('reverse_channels_up', '0:transpose'),
|
||||
*connect('transpose', 'result')])
|
||||
graph_ref = build_graph(nodes3, [*connect('placeholder', '0:transpose'),
|
||||
*connect('transpose_order', '1:transpose'),
|
||||
*connect('transpose', 'reverse_channels_up'),
|
||||
*connect('reverse_channels_up', '0:result')])
|
||||
self.set_graph_attrs(graph, ['placeholder'])
|
||||
|
||||
node = Node(graph, 'transpose')
|
||||
reverse_channels = Node(graph, 'reverse_channels_up')
|
||||
reverse_channels.axis = -1
|
||||
|
||||
keep_moving_down = ReverseChannelsPropagationDown.pass_rc_through_transpose(node, reverse_channels)
|
||||
|
||||
self.assertTrue(keep_moving_down is True)
|
||||
self.check_graph_attrs(graph, ['placeholder'])
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
reverse_channels = Node(graph, 'reverse_channels_down')
|
||||
self.assertTrue(reverse_channels.axis == 1)
|
||||
|
||||
def test_get_fw_index(self):
|
||||
graph = build_graph(nodes, [*connect('placeholder1', 'result')])
|
||||
node = Node(graph, 'placeholder1')
|
||||
old_api_map = OldAPIMap(version=0)
|
||||
node.rt_info.info[('old_api_map', old_api_map.get_version())] = old_api_map
|
||||
node.rt_info.info[('old_api_map', old_api_map.get_version())].old_api_transpose_parameter([0, 2, 3, 1])
|
||||
self.assertTrue(InsertReverseChannels.get_fw_index(node, 0) == 0)
|
||||
self.assertTrue(InsertReverseChannels.get_fw_index(node, 1) == 3)
|
||||
self.assertTrue(InsertReverseChannels.get_fw_index(node, 2) == 1)
|
||||
self.assertTrue(InsertReverseChannels.get_fw_index(node, 3) == 2)
|
||||
self.assertTrue(InsertReverseChannels.get_fw_index(node, -2) == 1)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user