diff --git a/model-optimizer/extensions/back/ReverseInputChannels.py b/model-optimizer/extensions/back/ReverseInputChannels.py index ce0cbb952ae..8aefd1bc152 100644 --- a/model-optimizer/extensions/back/ReverseInputChannels.py +++ b/model-optimizer/extensions/back/ReverseInputChannels.py @@ -50,10 +50,38 @@ class InsertReverseChannels(BackReplacementPattern): """ enabled = False + @staticmethod + def get_fw_index(node: Node, idx: int): + if not node.has_valid('rt_info'): + return idx + + rt_info = node.rt_info + if not rt_info.contains('old_api_map'): + return idx + + old_api_map_version = rt_info.get_attribute_version('old_api_map') + old_api_map = rt_info.info['old_api_map', old_api_map_version] + if 'inverse_order' not in old_api_map.info: + return idx + + order = old_api_map.info['inverse_order'] + node_name = node.soft_get('name', node.id) + + if idx < 0: + assert not node.out_port(0).disconnected(), 'Cannot normalize negative axis {} in node {} ' \ + 'as out port is disconnected.'.format(idx, node_name) + data_rank = len(list(node.out_port(0).data.get_shape())) + idx = data_rank + idx + + assert len(order) > idx >= 0, \ + 'Channel index {} is incompatible with old_api_map in node {}.'.format(idx, node_name) + return list(order).index(idx) + def find_and_replace_pattern(self, graph: Graph): all_params = [(p.soft_get('name', p.id), p, list(p.out_port(0).data.get_shape())) for p in graph.get_op_nodes(type='Parameter')] - suitable_params = [(name, p, shape) for name, p, shape in all_params if len(shape) == 4 and shape[1] == 3] + suitable_params = [(name, p, shape) for name, p, shape in all_params if + len(shape) == 4 and shape[self.get_fw_index(p, 1)] == 3] log.debug('All network inputs: {}'.format({name: shape for name, _, shape in all_params})) log.debug('Will reverse input channels for: {}'.format({name: shape for name, _, shape in suitable_params})) @@ -66,10 +94,14 @@ class InsertReverseChannels(BackReplacementPattern): extra={'is_warning': True}) for name, parameter, _ in suitable_params: - reverse_channels = ReverseChannels(graph, {'name': name + '/reverse_input_channels'}).create_node() - parameter.out_port(0).get_connection().set_source(reverse_channels.out_port(0), - attributes_save_mode='source') - parameter.out_port(0).connect(reverse_channels.in_port(0)) + reverse_index = self.get_fw_index(parameter, 1) + + if parameter.out_port(0).disconnected(): + continue + + reverse_channels = ReverseChannels(graph, {'name': name + '/reverse_input_channels', + 'axis': reverse_index}).create_node() + parameter.out_port(0).get_connection().insert_node(reverse_channels, attributes_save_mode='source') class ReverseChannelsPropagationDown(BackReplacementPattern): @@ -95,11 +127,30 @@ class ReverseChannelsPropagationDown(BackReplacementPattern): 'Shape': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_shape(node, rc), 'ShapeOf': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_shape(node, rc), - 'Pad': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through(node, rc), + 'Pad': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_zero_port_only(node, rc), + 'Transpose': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_transpose(node, rc), } @staticmethod - def pass_rc_through(node: Node, reverse_channels: Node): + def pass_rc_through_transpose(node: Node, reverse_channels: Node): + if node.in_port(1).disconnected() or node.in_port(0).disconnected(): + return False + order = node.in_port(1).data.get_value() + reverse_axis = reverse_channels.axis + data_rank = len(list(node.in_port(0).data.get_shape())) + + if reverse_axis < 0: + reverse_axis = data_rank + reverse_axis + assert 0 < reverse_axis < data_rank, "Incorrect ReverseChannels axis in node {}.".format(reverse_channels) + + if order is None: + return False + new_axis = list(order).index(reverse_axis) + reverse_channels.axis = new_axis + return ReverseChannelsPropagationDown.pass_rc_through_zero_port_only(node, reverse_channels) + + @staticmethod + def pass_rc_through_zero_port_only(node: Node, reverse_channels: Node): r""" BEFORE AFTER @@ -295,11 +346,30 @@ class ReverseChannelsPropagationUp(BackReplacementPattern): 'Subtract': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc), 'Pow': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc), 'Convert': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc), - 'Pad': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_pad(node, rc), + 'Pad': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_zero_port_only(node, rc), + 'Transpose': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_transpose(node, rc), } @staticmethod - def lift_up_through_pad(node: Node, reverse_channels: Node): + def lift_up_through_transpose(node: Node, reverse_channels: Node): + if node.in_port(1).disconnected() or node.in_port(0).disconnected(): + return False + order = node.in_port(1).data.get_value() + reverse_axis = reverse_channels.axis + data_rank = len(list(node.in_port(0).data.get_shape())) + + if reverse_axis < 0: + reverse_axis = data_rank + reverse_axis + assert 0 < reverse_axis < data_rank, "Incorrect ReverseChannels axis in node {}.".format(reverse_channels) + + if order is None: + return False + new_axis = order[reverse_axis] + reverse_channels.axis = new_axis + return ReverseChannelsPropagationUp.lift_up_through_zero_port_only(node, reverse_channels) + + @staticmethod + def lift_up_through_zero_port_only(node: Node, reverse_channels: Node): r""" BEFORE AFTER @@ -307,7 +377,7 @@ class ReverseChannelsPropagationUp(BackReplacementPattern): \ previous_op previous_op ReverseChannels previous_op \ / \ / - Pad Pad + Node Node | | ReverseChannels next_op | @@ -323,7 +393,16 @@ class ReverseChannelsPropagationUp(BackReplacementPattern): reverse_channels.out_port(0).disconnect() reverse_channels.in_port(0).disconnect() src = node_input_port_0.get_connection().get_source() - node_input_port_0.get_connection().set_source(reverse_channels.out_port(0)) + + if src.node.soft_get('type') == 'Parameter': + # For Parameter nodes tensor debug attributes should not move to the last node + # of subgraph. It is needed for the proper mapping of input framework name. + # For this reason "source" mode is used to keep tensor debug attributes at Parameter node. + node_input_port_0.get_connection().set_source(reverse_channels.out_port(0), + attributes_save_mode="source") + else: + node_input_port_0.get_connection().set_source(reverse_channels.out_port(0)) + src.connect(reverse_channels.in_port(0)) for reverse_channels_destination in reverse_channels_out_nodes: node.out_port(0).get_connection().add_destination(reverse_channels_destination) diff --git a/model-optimizer/mo/utils/runtime_info.py b/model-optimizer/mo/utils/runtime_info.py index 1c61eb0e7fb..3941794686a 100644 --- a/model-optimizer/mo/utils/runtime_info.py +++ b/model-optimizer/mo/utils/runtime_info.py @@ -29,6 +29,17 @@ class RTInfo: """ self.info = defaultdict(dict) + def contains(self, attribute_name: str): + attr_count = [key[0] for key in list(self.info.keys())].count(attribute_name) + assert attr_count <= 1, 'Incorrect rt_info attribute, got more than one {}.'.format(attribute_name) + return attr_count > 0 + + def get_attribute_version(self, attribute_name: str): + for name, version in list(self.info.keys()): + if name == attribute_name: + return version + raise Exception("rt_info does not contain attribute with name {}".format(attribute_name)) + class RTInfoElement: """ diff --git a/model-optimizer/unit_tests/extensions/back/ReverseInputChannels_test.py b/model-optimizer/unit_tests/extensions/back/ReverseInputChannels_test.py index 634f3ea9aef..1b5c8777bc0 100644 --- a/model-optimizer/unit_tests/extensions/back/ReverseInputChannels_test.py +++ b/model-optimizer/unit_tests/extensions/back/ReverseInputChannels_test.py @@ -3,13 +3,16 @@ import unittest -from extensions.back.ReverseInputChannels import ReverseChannelsPropagationUp, ReverseChannelsPropagationDown -from mo.graph.graph import Node, Graph -from unit_tests.utils.graph import build_graph, result, connect, regular_op_with_shaped_data, valued_const_with_data +from extensions.back.ReverseInputChannels import ReverseChannelsPropagationUp, ReverseChannelsPropagationDown, \ + InsertReverseChannels from mo.front.common.partial_infer.utils import int64_array, float32_array +from mo.graph.graph import Node, Graph +from mo.utils.ir_engine.compare_graphs import compare_graphs +from mo.utils.runtime_info import OldAPIMap, RTInfo +from unit_tests.utils.graph import build_graph, result, connect, regular_op_with_shaped_data, valued_const_with_data nodes = { - **regular_op_with_shaped_data('placeholder1', [1, 3, 10, 10], {'type': 'Parameter'}), + **regular_op_with_shaped_data('placeholder1', [1, 3, 10, 10], {'type': 'Parameter', 'rt_info': RTInfo()}), **regular_op_with_shaped_data('placeholder2', [1, 1, 1, 1], {'type': 'Parameter'}), **regular_op_with_shaped_data('mul', [1, 3, 10, 10], {'type': 'Multiply'}), @@ -35,6 +38,17 @@ nodes2 = { **result('result2'), } +nodes3 = { + **regular_op_with_shaped_data('placeholder', [1, 3, 10, 10], {'type': 'Parameter'}), + **regular_op_with_shaped_data('transpose', [1, 3, 10, 10], {'type': 'Transpose'}), + **valued_const_with_data('transpose_order', int64_array([0, 3, 1, 2])), + **regular_op_with_shaped_data('reverse_channels_up', [1, 3, 10, 10], {'type': 'ReverseChannels', 'axis': 3}), + **regular_op_with_shaped_data('reverse_channels_down', [1, 3, 10, 10], {'type': 'ReverseChannels', 'axis': 1}), + **result('result'), + **result('result2'), +} + + class ReverseInputChannelsTest(unittest.TestCase): def check_graph_attrs(self, graph: Graph, parameter_node_names: list): for node in graph.get_op_nodes(): @@ -75,12 +89,11 @@ class ReverseInputChannelsTest(unittest.TestCase): node = Node(graph, 'pad') reverse_channels = Node(graph, 'reverse_channels') - keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_pad(node, reverse_channels) + keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_zero_port_only(node, reverse_channels) self.assertTrue(keep_moving_up is True) self.assertTrue(len(new_reverses) == 1) self.check_graph_attrs(graph, ['placeholder']) - def test_lift_up_through_pad2(self): graph = build_graph(nodes2, [*connect('placeholder', '0:mul'), *connect('mul_const', '1:mul'), *connect('mul', '0:pad'), *connect('pad_const_1', '1:pad'), @@ -91,12 +104,11 @@ class ReverseInputChannelsTest(unittest.TestCase): node = Node(graph, 'pad') reverse_channels = Node(graph, 'reverse_channels') - keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_pad(node, reverse_channels) + keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_zero_port_only(node, reverse_channels) self.assertTrue(keep_moving_up is True) self.assertTrue(len(new_reverses) == 1) self.check_graph_attrs(graph, ['placeholder']) - def test_pass_rc_through(self): graph = build_graph(nodes2, [*connect('placeholder', '0:mul'), *connect('mul_const', '1:mul'), *connect('mul', 'reverse_channels'), *connect('reverse_channels', '0:pad'), @@ -107,5 +119,114 @@ class ReverseInputChannelsTest(unittest.TestCase): node = Node(graph, 'pad') reverse_channels = Node(graph, 'reverse_channels') - ReverseChannelsPropagationDown.pass_rc_through(node, reverse_channels) + ReverseChannelsPropagationDown.pass_rc_through_zero_port_only(node, reverse_channels) self.check_graph_attrs(graph, ['placeholder']) + + def test_lift_up_through_transpose(self): + graph = build_graph(nodes3, [*connect('placeholder', '0:transpose'), *connect('transpose_order', '1:transpose'), + *connect('transpose', 'reverse_channels_down'), + *connect('reverse_channels_down', 'result')]) + graph_ref = build_graph(nodes3, [*connect('placeholder', 'reverse_channels_down'), + *connect('transpose_order', '1:transpose'), + *connect('reverse_channels_down', 'transpose'), + *connect('transpose', 'result')]) + self.set_graph_attrs(graph, ['placeholder']) + + node = Node(graph, 'transpose') + reverse_channels = Node(graph, 'reverse_channels_down') + + keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_transpose(node, reverse_channels) + self.assertTrue(keep_moving_up is True) + self.assertTrue(len(new_reverses) == 1) + self.check_graph_attrs(graph, ['placeholder']) + (flag, resp) = compare_graphs(graph, graph_ref, 'result') + self.assertTrue(flag, resp) + + reverse_channels = Node(graph, 'reverse_channels_down') + self.assertTrue(reverse_channels.axis == 3) + + def test_lift_down_through_transpose(self): + graph = build_graph(nodes3, [*connect('placeholder', 'reverse_channels_up'), + *connect('transpose_order', '1:transpose'), + *connect('reverse_channels_up', '0:transpose'), + *connect('transpose', 'result')]) + graph_ref = build_graph(nodes3, [*connect('placeholder', '0:transpose'), + *connect('transpose_order', '1:transpose'), + *connect('transpose', 'reverse_channels_up'), + *connect('reverse_channels_up', '0:result')]) + self.set_graph_attrs(graph, ['placeholder']) + + node = Node(graph, 'transpose') + reverse_channels = Node(graph, 'reverse_channels_up') + + keep_moving_down = ReverseChannelsPropagationDown.pass_rc_through_transpose(node, reverse_channels) + + self.assertTrue(keep_moving_down is True) + self.check_graph_attrs(graph, ['placeholder']) + (flag, resp) = compare_graphs(graph, graph_ref, 'result') + self.assertTrue(flag, resp) + + reverse_channels = Node(graph, 'reverse_channels_down') + self.assertTrue(reverse_channels.axis == 1) + + def test_lift_up_through_transpose_negative_axis(self): + graph = build_graph(nodes3, [*connect('placeholder', '0:transpose'), *connect('transpose_order', '1:transpose'), + *connect('transpose', 'reverse_channels_down'), + *connect('reverse_channels_down', 'result')]) + graph_ref = build_graph(nodes3, [*connect('placeholder', 'reverse_channels_down'), + *connect('transpose_order', '1:transpose'), + *connect('reverse_channels_down', 'transpose'), + *connect('transpose', 'result')]) + self.set_graph_attrs(graph, ['placeholder']) + + node = Node(graph, 'transpose') + reverse_channels = Node(graph, 'reverse_channels_down') + reverse_channels.axis = -3 + + keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_transpose(node, reverse_channels) + self.assertTrue(keep_moving_up is True) + self.assertTrue(len(new_reverses) == 1) + self.check_graph_attrs(graph, ['placeholder']) + (flag, resp) = compare_graphs(graph, graph_ref, 'result') + self.assertTrue(flag, resp) + + reverse_channels = Node(graph, 'reverse_channels_down') + self.assertTrue(reverse_channels.axis == 3) + + def test_lift_down_through_transpose_negative_axis(self): + graph = build_graph(nodes3, [*connect('placeholder', 'reverse_channels_up'), + *connect('transpose_order', '1:transpose'), + *connect('reverse_channels_up', '0:transpose'), + *connect('transpose', 'result')]) + graph_ref = build_graph(nodes3, [*connect('placeholder', '0:transpose'), + *connect('transpose_order', '1:transpose'), + *connect('transpose', 'reverse_channels_up'), + *connect('reverse_channels_up', '0:result')]) + self.set_graph_attrs(graph, ['placeholder']) + + node = Node(graph, 'transpose') + reverse_channels = Node(graph, 'reverse_channels_up') + reverse_channels.axis = -1 + + keep_moving_down = ReverseChannelsPropagationDown.pass_rc_through_transpose(node, reverse_channels) + + self.assertTrue(keep_moving_down is True) + self.check_graph_attrs(graph, ['placeholder']) + (flag, resp) = compare_graphs(graph, graph_ref, 'result') + self.assertTrue(flag, resp) + + reverse_channels = Node(graph, 'reverse_channels_down') + self.assertTrue(reverse_channels.axis == 1) + + def test_get_fw_index(self): + graph = build_graph(nodes, [*connect('placeholder1', 'result')]) + node = Node(graph, 'placeholder1') + old_api_map = OldAPIMap(version=0) + node.rt_info.info[('old_api_map', old_api_map.get_version())] = old_api_map + node.rt_info.info[('old_api_map', old_api_map.get_version())].old_api_transpose_parameter([0, 2, 3, 1]) + self.assertTrue(InsertReverseChannels.get_fw_index(node, 0) == 0) + self.assertTrue(InsertReverseChannels.get_fw_index(node, 1) == 3) + self.assertTrue(InsertReverseChannels.get_fw_index(node, 2) == 1) + self.assertTrue(InsertReverseChannels.get_fw_index(node, 3) == 2) + self.assertTrue(InsertReverseChannels.get_fw_index(node, -2) == 1) +