"""
 Copyright (C) 2018-2020 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
import logging as log

import numpy as np

from extensions.ops.gather import Gather
from extensions.ops.split import Split
from mo.back.replacement import BackReplacementPattern
from mo.front.common.partial_infer.utils import int64_array
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph
from mo.graph.graph import Node
from mo.ops.concat import Concat
from mo.ops.op import Op, PermuteAttrs


class ReverseChannels(Op):
    """
    Internal op that will never be emitted into IR and replaced by other, publicly supported ops

    Attributes:
        axis  -- axis along which the permutation is applied (defaults to 1)
        order -- permutation of the indices along `axis` (defaults to [2, 1, 0])
    """
    op = 'ReverseChannels'
    enabled = True

    def __init__(self, graph: Graph, attrs: dict):
        super().__init__(graph, {
            'op': self.op,
            'type': None,
            'axis': int64_array(1),
            'order': int64_array([2, 1, 0]),
            'infer': self.infer,
            'in_ports_count': 1,
            'out_ports_count': 1,
        }, attrs)

    @staticmethod
    def infer(node):
        """
        Output shape equals input shape; 'axis' is registered as a layout-dependent attribute of input 0
        so that it is permuted together with the input layout.
        """
        input_shape = node.in_port(0).data.get_shape()
        assert input_shape is not None
        node.out_port(0).data.set_shape(input_shape)

        PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])


class InsertReverseChannels(BackReplacementPattern):
    """
    Searches for all suitable nodes with type=Parameter and inserts internal ReverseChannels op right after them
    TODO: we should provide user an ability to explicitly specify nodes for input channel reversing
    """
    enabled = False

    def find_and_replace_pattern(self, graph: Graph):
        # (name, node, shape) for every network input
        all_params = [(p.soft_get('name', p.id), p, list(p.out_port(0).data.get_shape()))
                      for p in graph.get_op_nodes(type='Parameter')]
        # only 4D inputs with 3 elements in dimension 1 are treated as reversible
        suitable_params = [(name, p, shape) for name, p, shape in all_params if len(shape) == 4 and shape[1] == 3]

        log.debug('All network inputs: {}'.format({name: shape for name, _, shape in all_params}))
        log.debug('Will reverse input channels for: {}'.format({name: shape for name, _, shape in suitable_params}))

        if len(suitable_params) < len(all_params):
            log.error('Network has {} inputs overall, but only {} of them are suitable for input channels reversing.\n'
                      'Suitable for input channel reversing inputs are 4-dimensional with 3 channels\nAll inputs: {}\n'
                      'Suitable inputs {}'.format(len(all_params), len(suitable_params),
                                                  {name: shape for name, _, shape in all_params},
                                                  {name: shape for name, _, shape in suitable_params}),
                      extra={'is_warning': True})

        for name, parameter, _ in suitable_params:
            reverse_channels = ReverseChannels(graph, {'name': name + '/reverse_input_channels'}).create_node()
            # every former consumer of the Parameter now reads from ReverseChannels
            parameter.out_port(0).get_connection().set_source(reverse_channels.out_port(0))
            parameter.out_port(0).connect(reverse_channels.in_port(0))


class ReverseChannelsPropagationDown(BackReplacementPattern):
    """
    Propagates ReverseChannels operations down through nodes that we have rules for
    """
    enabled = False

    propagation_rules = {
        'Convolution': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_conv(node, rc),

        'ScaleShift': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_eltwise(node, rc),
        'Power': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_eltwise(node, rc),
        'BatchNormalization': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_eltwise(node, rc),
        'FakeQuantize': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_eltwise(node, rc),
        'Multiply': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_eltwise(node, rc),
        'Add': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_eltwise(node, rc),
        'Pow': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_eltwise(node, rc),
        'Convert': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_eltwise(node, rc),

        'Shape': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_shape(node, rc),
        'ShapeOf': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_shape(node, rc),
    }

    @staticmethod
    def pass_rc_through_conv(node, reverse_channels):
        """
        For non grouped convolution:
        BEFORE                                  AFTER

          previous_op                                              weights
              |                                                       |
        ReverseChannels    weights             previous_op    ReverseChannels
                     \\     /                             \\     /
                      Conv                                  Conv

        For grouped convolution:
        BEFORE                                  AFTER

          previous_op                                              weights
              |                                                       |
        ReverseChannels    weights             previous_op    ReverseChannels
                     \\     /                             \\     /
                      Conv                                  Conv
                                                              |
                                                       ReverseChannels

        returns boolean value whether we should continue propagating current ReverseChannels operation down or not
        """
        channel_idx = node.soft_get("input_feature_channel", None)
        if channel_idx is None:
            # unknown Convolution configuration, won't propagate reverse_channels down the network
            return False
        weights_shape = node.in_port(1).data.get_shape()
        if weights_shape is None or weights_shape[channel_idx] != reverse_channels.order.size:
            # unexpected Convolution configuration, won't propagate reverse_channels down the network
            return False

        # detaching reverse_channels node from the graph
        reverse_channels.out_port(0).get_connection().set_source(
            reverse_channels.in_port(0).get_connection().get_source())
        reverse_channels.in_port(0).disconnect()

        group = node.soft_get('group', 1)

        # insert ReverseChannels on weights port of Convolution
        # NOTE: for group != 1 a copy is used; if additionally group != order.size the original reverse_channels
        # stays fully detached below - DecomposeReverseChannels and ReverseChannelsPropagationUp skip such nodes
        ric_to_move_to_weights = reverse_channels if group == 1 else reverse_channels.copy_node()
        ric_to_move_to_weights['axis'] = np.array(channel_idx)
        src = node.in_port(1).get_connection().get_source()
        node.in_port(1).get_connection().set_source(ric_to_move_to_weights.out_port(0))
        src.disconnect()
        src.connect(ric_to_move_to_weights.in_port(0))

        if group != 1 and group == reverse_channels.order.size:
            # grouped Convolution weights channel reversing is not enough to complete channel reversing procedure
            # we propagate ReverseChannels op through current Convolution with new order value for channel permutation
            bottom_channels = node.out_port(0).data.get_shape()[node.channel_dims[0]]
            assert bottom_channels % group == 0
            multiplier = int(bottom_channels // group)
            # permute whole groups of output channels: group i of the output moves to position order[i]
            new_order = np.take(np.arange(bottom_channels).reshape((group, multiplier)),
                                indices=reverse_channels.order, axis=0).flatten()
            reverse_channels['axis'] = np.array(reverse_channels.axis.copy())
            reverse_channels['order'] = np.array(new_order)

            node.out_port(0).get_connection().set_source(reverse_channels.out_port(0))
            node.out_port(0).disconnect()
            node.out_port(0).connect(reverse_channels.in_port(0))

            # as described above, we are not done reversing channels yet, so we should continue propagating
            # ReverseChannels operation down the network
            return True
        # we reversed channels for sure, nothing to propagate down the network
        return False

    @staticmethod
    def pass_rc_through_eltwise(node, reverse_channels):
        """
        BEFORE                                  AFTER

          previous_op                                             previous_op'
              |                                                       |
        ReverseChannels  previous_op'          previous_op    ReverseChannels
                     \\     /                             \\     /
                     Eltwise                               Eltwise
                                                              |
                                                       ReverseChannels

        returns boolean value whether we should continue propagating current ReverseChannels operation down or not
        """
        before_shape = reverse_channels.out_port(0).data.get_shape()

        port_axis = []
        for port in node.in_ports().values():
            if port.disconnected():
                # optional input that is not connected - nothing to flip for it
                continue
            if port.get_connection().get_source().node.id == reverse_channels.id:
                continue

            shape = port.data.get_shape()
            if shape is None:
                # unknown shape - can't decide which axis to flip, it is safe not to propagate reverse channels
                return False
            non_one_dims = np.where(shape != 1)[0]
            if len(non_one_dims) == 0:
                # shape contains only ones - nothing to flip for this input
                continue
            if len(non_one_dims) == 1 and shape[non_one_dims.item()] == reverse_channels.order.size:
                new_axis = non_one_dims.item()
            elif np.array_equal(before_shape, shape):
                new_axis = reverse_channels.axis
            else:
                # shape has multiple non-one values and shape is not fully broadcasted to value port shape
                # it is safe not to propagate reverse channels
                return False
            port_axis.append((port, new_axis))

        # reversing eltwise inputs where applicable
        for port, axis in port_axis:
            ric_copy = reverse_channels.copy_node({'axis': np.array(axis), 'order': np.array(reverse_channels.order)})

            src = port.get_connection().get_source()
            port.get_connection().set_source(ric_copy.out_port(0))
            src.disconnect()
            src.connect(ric_copy.in_port(0))

        # detaching reverse_channels node from the graph
        reverse_channels.out_port(0).get_connection().set_source(
            reverse_channels.in_port(0).get_connection().get_source())
        reverse_channels.in_port(0).disconnect()

        # propagating reverse_channels node to the output port of eltwise
        node.out_port(0).get_connection().set_source(reverse_channels.out_port(0))
        node.out_port(0).disconnect()
        node.out_port(0).connect(reverse_channels.in_port(0))

        # propagated reverse_channels successfully through current node, will continue propagation
        return True

    @staticmethod
    def pass_rc_through_shape(node, reverse_channels):
        """
        stops propagation of RIC through shape taking operations, due to RIC does not change shape

        NOTE: find_and_replace_pattern selects the next node via get_non_shape_taking_dst, which filters out
        Shape/ShapeOf consumers, so this rule is currently never dispatched from this class.
        """
        reverse_channels.out_port(0).get_connection().set_source(reverse_channels.in_port(0).get_connection().get_source())
        return False

    @staticmethod
    def get_non_shape_taking_dst(dsts):
        """ Filters out destination ports that belong to Shape/ShapeOf nodes """
        return [dst for dst in dsts if dst.node.soft_get('type') not in ['Shape', 'ShapeOf']]

    def check_if_we_propagate_down(self, reverse_channels):
        """ True when reverse_channels has exactly one non shape-taking consumer and we have a rule for it """
        dsts = self.get_non_shape_taking_dst(reverse_channels.out_port(0).get_destinations())
        return len(dsts) == 1 and dsts[0].node.soft_get('type') in self.propagation_rules

    def find_and_replace_pattern(self, graph: Graph):
        for reverse_channels in graph.get_op_nodes(op='ReverseChannels'):
            keep_moving_down = True
            while keep_moving_down and self.check_if_we_propagate_down(reverse_channels):
                next_node = self.get_non_shape_taking_dst(reverse_channels.out_port(0).get_destinations())[0].node
                keep_moving_down = self.propagation_rules[next_node.type](next_node, reverse_channels)


class ReverseChannelsPropagationUp(BackReplacementPattern):
    """
    Propagates ReverseChannels operations up through nodes that we have rules for
    """
    enabled = False

    propagation_rules = {
        'ScaleShift': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
        'Power': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
        'BatchNormalization': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
        'FakeQuantize': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
        'Multiply': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
        'Add': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
        'Pow': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
        'Convert': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
    }

    @staticmethod
    def lift_up_through_eltwise(node: Node, reverse_channels: Node):
        """
        BEFORE                                  AFTER

                                          previous_op              previous_op'
                                                \\                    /
        previous_op  previous_op'         ReverseChannels     ReverseChannels
                 \\     /                               \\     /
                Eltwise                                 Eltwise
                   |                                       |
            ReverseChannels                             next_op
                  |
                next_op

        returns two objects:
        first - boolean value whether we should continue propagating current ReverseChannels operation up or not
        second - list of new ReverseChannels operations that were produced while propagating reverse_channels up
        """
        # Lifting reverse_channels makes the eltwise output itself channel-reversed. That is only correct when
        # reverse_channels is its sole data consumer (shape-taking consumers don't care: the shape is unchanged).
        data_consumers = ReverseChannelsPropagationDown.get_non_shape_taking_dst(node.out_port(0).get_destinations())
        if len(data_consumers) != 1:
            return False, []

        before_shape = reverse_channels.in_port(0).data.get_shape()

        port_axis = []
        for port in node.in_ports().values():
            if port.disconnected():
                # optional input that is not connected - nothing to flip for it
                continue
            shape = port.data.get_shape()
            if shape is None:
                # unknown shape - can't decide which axis to flip, it is safe not to propagate reverse channels
                return False, []

            non_one_dims = np.where(shape != 1)[0]
            if len(non_one_dims) == 0:
                # shape contains only ones - nothing to flip for this input
                continue
            if len(non_one_dims) == 1 and shape[non_one_dims.item()] == reverse_channels.order.size:
                axis = non_one_dims.item()
            elif np.array_equal(before_shape, shape):
                axis = reverse_channels.axis
            else:
                # shape has multiple non-one values and shape is not fully broadcasted to value port shape
                # it is safe not to propagate reverse channels
                return False, []
            port_axis.append((port, axis))

        copies = []
        for port, axis in port_axis:
            reverse_channels_copy = reverse_channels.copy_node({'axis': np.array(axis)})

            src = port.get_connection().get_source()
            port.get_connection().set_source(reverse_channels_copy.out_port(0))
            src.connect(reverse_channels_copy.in_port(0))

            copies.append(reverse_channels_copy)

        # detaching reverse_channels node from the graph: its consumers now read the eltwise output directly
        reverse_channels.out_port(0).get_connection().set_source(
            reverse_channels.in_port(0).get_connection().get_source())
        reverse_channels.in_port(0).disconnect()

        # propagated reverse_channels successfully through current node, will continue propagation
        return True, copies

    def find_and_replace_pattern(self, graph: Graph):
        # worklist of ReverseChannels nodes to try lifting; copies produced while lifting are appended to it
        reverse_channels = set(graph.get_op_nodes(op='ReverseChannels'))
        while reverse_channels:
            curr_reverse_channels = reverse_channels.pop()
            if curr_reverse_channels.in_port(0).disconnected():
                # detached during down propagation - nothing to lift, graph clean up will remove it
                continue
            next_op = curr_reverse_channels.in_port(0).get_source().node
            if next_op.soft_get('type') not in self.propagation_rules:
                continue
            _, new_reverses = self.propagation_rules[next_op.type](next_op, curr_reverse_channels)
            reverse_channels.update(new_reverses)


class DecomposeReverseChannels(BackReplacementPattern):
    """
    Replaces each internal ReverseChannels operation in graph with publicly supported Gather operation
    """
    enabled = False

    @staticmethod
    def replace_with_gather(node):
        """
        Replaces node with Gather(data, order, axis).
        NOTE: not used by find_and_replace_pattern (it uses replace_with_split_concat); does not remove `node`.
        """
        graph = node.graph

        name = node.soft_get('name', node.id)
        axis = node.axis
        order = node.order

        gather = create_op_with_const_inputs(graph, Gather, {1: order, 2: int64_array(axis)}, {'name': name})

        node.out_port(0).get_connection().set_source(gather.out_port(0))
        node.in_port(0).get_connection().set_destination(gather.in_port(0))

    @staticmethod
    def replace_with_split_concat(node):
        """
        Replaces node with Split(axis, num_splits=order.size) whose i-th output feeds Concat input order[i],
        then removes `node` from the graph.
        """
        graph = node.graph

        name = node.soft_get('name', node.id)
        axis = node.axis
        order = node.order

        split = create_op_with_const_inputs(graph, Split, {1: int64_array(axis)},
                                            {'name': name + '/Split', 'num_splits': order.size})
        concat = Concat(graph, {'name': name + '/Concat', 'axis': axis, 'in_ports_count': order.size}).create_node()

        for out_port_idx, in_port_idx in enumerate(order):
            split.out_port(out_port_idx).connect(concat.in_port(in_port_idx))

        node.out_port(0).get_connection().set_source(concat.out_port(0))
        node.in_port(0).get_connection().set_destination(split.in_port(0))

        graph.remove_node(node.id)

    def find_and_replace_pattern(self, graph: Graph):
        for reverse_channels in graph.get_op_nodes(op='ReverseChannels'):
            if reverse_channels.in_port(0).disconnected() or reverse_channels.out_port(0).disconnected():
                # graph.clean_up will delete it
                reverse_channels['need_shape_inference'] = False
                continue
            self.replace_with_split_concat(reverse_channels)


class ApplyReverseChannels(BackReplacementPattern):
    """
    Reverses input channels for suitable Parameter operation if requested by user
    Optimizes channel reversing by fusion to Convolution weights if applicable
    """
    enabled = True

    run_not_recursively = True
    force_clean_up = True

    def run_before(self):
        from extensions.back.GroupedConvWeightsNormalize import GroupedConvWeightsNormalize
        return [GroupedConvWeightsNormalize]

    def find_and_replace_pattern(self, graph: Graph):
        """
        Following transformations should run in strict order, that is why we disabled them all and run here
        """
        if graph.graph['cmd_params'].reverse_input_channels:
            InsertReverseChannels().find_and_replace_pattern(graph)
        ReverseChannelsPropagationDown().find_and_replace_pattern(graph)
        ReverseChannelsPropagationUp().find_and_replace_pattern(graph)
        DecomposeReverseChannels().find_and_replace_pattern(graph)