openvino/model-optimizer/extensions/back/ReverseInputChannels.py
Evgeny Lazarev 3775dad345 MO dynamic shapes support (#5918)
* Allow MO to generate IR with -1 in dimensions

* Some fixes to support -1 for StridedSlice operation

* Updated TensorArrayGatherV3 shape infer to support dynamic output shape

* Several fixes to support undefined dimensions in the Broadcast, Reshape, Slice and Tile

* Fixed bug in the normalization transformation of TF NMS to opset NMS

* Updated shape infer functions related to StridedSlice and NMS

* Updated Select shape inference function to use common shape broadcasting function supporting dynamism

* Fixed operation TFResize shape infer function to work correctly for case when model is converted with --disable_nhwc_to_nchw

* Dynamic Range and update asserts in NMS

* Changed the way dynamic dimensions are specified. Refactored shape inference functions and common places to use the new approach

* More fixes to support dynamic shapes

* More fixes for support of dynamic shapes

* Fixed generation of IR with dynamic dimensions

* Allow reading IRs with undefined dimensions

* More changes in the IE to support dynamic dimensions

* Fixes for Switch, Merge, Concat shape and value infer related to dynamism

* Fixed TensorArray related ops to properly handle dynamic dimensions. Fixed StridedSlice infer for case with new_axis

* Fixed shape_for_layout function to generate masked array

* Fixed shape inference for Convolution and Poolings to support dynamic spatial dimensions

* Updated shape infer functions for CTCGreedyDecoder, CTCLoss and Enter

* Fixed shape inference with dynamic dimensions for MatMul, Split, Upsample, SpaceToBatch, some fixes for the TI

* Fixes for undefined dimensions support for Proposal and DetectionOutput

* Fixed ExtractImagePatches, DepthToSpace and RegionYolo shape infer functions to work with partially dynamic dimensions

* Changes in tf_window_op_pad_infer to better work with dynamic dimensions

* Fixed output shape calculation for StridedSlice operation

* More StridedSlice fixes

* Fixed resolve_convolution_with_group

* Fixed unit tests

* Fixed unit tests

* Fixed Switch op unit tests

* Fixed shape inference for Upsample operation

* Updated unit tests for the Concat operation

* Fixed eltwise shape infer unit tests

* Fixed shape infer tests for Convolution and DetectionOutput ops

* Fixed Crop shape infer function tests

* Fixed Slice op unit test and minor fix in the shape inference. Fixed emitter

* Updated unit test for telemetry and match_shape function for dynamism

* Fixed unit test for the DetectionOutput

* Added support for the TF ClipByValue operation

* Fixed GatherND shape inference for dynamic shapes support

* Dynamic shapes support for the MO IR Reader

* Fixed BlockLSTM operation to not work as an extractor

* Allow to serialize IRs with partially defined shapes

* Updated SelectBroadcast transformation to not check shape values

* Fixed MO IR comparator

* Fixed SS value propagation when slices are dynamic

* Do not re-run graph clean-up for ProposalMutation

* Fixed InterpolateSequenceToInterpolate transformation to support dynamic dimensions

* Fixed Loop iteration count calculation and reading IteratorGetNext shapes

* Fixed unit test for serialization

* Fixed serialization test

* Fixed RandomUniform shape infer

* Fixed several transformations related to RNN to respect dynamic output shapes

* Fixed Deconvolution shape calculation for dynamic batch. Eltwise shape infer improvements

* Fixed shape infer functions for ExperimentalDetectron ops, reverted changes for NonZero and removed debug prints

* Fixed check for dynamism of a list, fixed value propagation for Concat op and remove redundant shape infer for reshape

* Update Eltwise value propagation to use np.ma

* Fixed ExpandDims shape infer function

* Shape infer functions fixes and improvements

* Remove Accum op from the MO

* Updated activation functions shape infer

* Removed unsupported operation Correlation

* Fixed shape infers for several functions

* Removed unsupported DataAugmentation operation

* Fixed shape infer functions for several ops in extensions directory

* Removed unsupported operation PowerFile

* Removed unsupported SpatialTransformer, SimplerNMS and PredictionHeatmap operations

* More shape infer functions updates

* Merge shape infer fix

* Fixed typo

* Fixed TensorArraySize shape infer function

* Fixed VariadicSplit and Squeeze shape infer

* Fixed ONNX models Parameter extractor

* Updated Select value propagation for the dynamic case

* Fixed ReorgYolo shape infer and test

* Removed unnecessary tests

* Fixed Tile shape infer

* Fixed SparseFillEmptyRows unit tests

* Fixed package bom

* Added extractor for the TF operation Mod

* Fixed value propagation for MatMul operation

* Updated Parameter extender to generate shape_array when shape is partially defined only

* Fixed BOM file

* Fixed issue with the TF OD API models and DetectionOutput op. Now the shape infer function for the DO does not re-infer the "num_classes" attribute value if it is already known

* Fixed unit test for the DO infer

* Fixed num classes calculation for the DO generation for Faster/Mask-RCNN models

* Changed NMS op to produce static output shape

* Restore dynamic output shape calculation for the NMS for NMS-5

* Fixed CellNormalizer transformation. It should work for static shapes only

* RNNCell Op class fixes

* Revert some changes

* Updated documentation with a list of supported operations

* Revert changes

* Fixes for the ConstantFill op

* Removed redundant SequenceLengthToMask transformation

* TensorArray* ops shape infer code style and refactoring

* Reverted some unnecessary changes in the ConvolutionNormalizer

* Fixes and unit tests for shape_array, compare_shapes, is_fully_defined functions

* Implemented shape_insert, shape_delete functions and tests for them

* Modified code to use shape_delete function

* Added usage of shape_insert function where necessary

* Use shape_insert function in many places

* Some fixes in shape inference for various ops

* Updated shape_delete function to support negative indices

* Changes and unit tests for the MatMul infer function

* Removed strange code from the TF Merge infer function

* Merge op shape infer fixes

* Fixed value propagation in the transformation EltwiseInputReshape.py for the dynamic dimension case

* Code cleanup

* Updated GatherND to support dynamic dimensions

* Minor fixes

* Fixed shape_insert and shape_delete to support np.int64 and np.int32 types

* Updated Upsample operation unit tests with dynamic input shapes

* Minor change in the extensions/back/ConvolutionNormalizer.py to make sure that input dimensions are static

* Fixed ConvertGroupedStridedSlice transformation and added unit tests

* Revert debug changes

* Fixed value propagation for Unsqueeze to work with partially defined input values

* Typo fix

* Added unit tests for the Unsqueeze op shape infer

* broadcasting functions changes and unit tests

* Fixed Tile value inference for partially defined input tensor

* Unit tests for Split and VariadicSplit ops

* Fixes for the Concat infer + unit tests

* Removed redundant tf_pack shape infer

* Fixed Concat value infer and added unit tests

* Fixed StridedSlice shape inference for case with dynamic slices

* Fixes related to StridedSlice shape infer, changes in tests

* Unit tests for the eltwise shape and value infer

* Fixed Pad op value propagation to allow dynamic input values to be propagated

* Unit test for Pooling dynamic input shape infer

* Squeeze op unit tests for dynamic input shape

* Added assert to the Squeeze op shape infer for case when squeeze dimension is dynamic value

* Added message to the MO when input shapes are dynamic

* Convolution dynamic unit test

* Removed redundant transformation GroupedConvWeightsNormalize

* Removed non-ascii character from the message

* Fixed typo in the BOM file

* Code style and comment fixes

* Fixed copy-paste issue in the DO shape infer function

* Fixed setting dynamic shape in the MO command line

* Added function to compare tensor with dynamic values. Fixes in the unit tests and shape infer functions

* Improved Reshape shape infer + added unit tests

* Fixed value propagation for Select op

* Renamed several internal functions, minor code fixes.

* Code style fixes

* Modified condition in the _set_shape method of the Port class to not check shape if the "override_output_shape" attribute is specified

* Fixed constant value propagation for ReduceOps when inputs have dynamic values. Added unit test

* Fixed shape infer for the Loop for dynamic dimensions case

* Fix in the NMS shape infer to avoid ragged numpy array generation. Fixed Scatter shape infer validation

* Improved shapes infer for eltwise ops with respect to dynamic dimensions

* Changed code comments

* Renamed tensor names in the ClipByValueTFTransformation

* Changed np.ma.allequal to strict_compare_tensors in the Merge op infer

* Changed np.ma.allequal to strict_compare_tensors.

* Fixed Merge op value infer

* Fixed debug code

* Removed commented line

* Updated condition to check for dynamic shapes in the Partial infer to not fail for MxNet models

* Improvements to the get_shape_from_slice and is_dynamic_slice functions

* Reverted change in the `normalize_slices_attr` for ellipsis mask case

* Updated shape conditions in the ScatterNDBase op to support dynamic dimensions

* Crop op file refactoring

* Set "type" attribute to None for SparseFillEmptyRows op which is not from any opset

* Removed unnecessary extractor test

* Restored Crop operation type

* Removed "type" attribute from the Crop operation and updated the MO code to find Crop by "op" attribute

* Fixed If shape infer function to produce dynamic dimensions

* Updated If shape and value infer to properly work when condition is static

* Fixed fusing transformation check to work with dynamic dimensions. Change comparison in the shape_inference function to not use strict shapes comparison

* Optimize imports in the LayerNorm

* ConvertGroupedStridedSlice minor fixes related to dynamism support

* Fixed ConvertGroupedStridedSlice to properly check if the dimension is sliced
2021-09-01 14:35:06 +03:00

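The commits above repeatedly mention representing dynamic dimensions with numpy masked arrays (`shape_array`, `is_fully_defined`, the `np.ma`-based value propagation). A minimal sketch of that idea, assuming plain numpy; the function names echo the MO utilities named in the commit messages, but the bodies here are illustrative, not the actual implementation:

```python
import numpy as np

# Illustrative sketch only (not the MO implementation): a dynamic (-1)
# dimension is stored as a masked entry of a numpy masked array, so
# shape arithmetic can skip it instead of treating -1 as a real value.
def shape_array(shape):
    """Build a shape with -1 entries masked out as dynamic."""
    return np.ma.masked_array(shape, mask=[d == -1 for d in shape], dtype=np.int64)

def is_fully_defined(shape):
    """A shape is fully defined when no dimension is masked."""
    return not np.ma.is_masked(shape)

print(is_fully_defined(shape_array([1, -1, 224, 224])))  # False
print(is_fully_defined(shape_array([1, 3, 224, 224])))   # True
```

This is why so many shape infer functions in the log switch from plain comparisons to masked-aware helpers such as `strict_compare_tensors`.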

# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import logging as log
import numpy as np
from extensions.ops.gather import Gather
from extensions.ops.split import Split
from mo.back.replacement import BackReplacementPattern
from mo.front.common.partial_infer.utils import int64_array
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph
from mo.graph.graph import Node
from mo.ops.concat import Concat
from mo.ops.op import Op, PermuteAttrs
class ReverseChannels(Op):
"""
    Internal op that is never emitted into IR; it is replaced by other, publicly supported ops
"""
op = 'ReverseChannels'
enabled = True
def __init__(self, graph: Graph, attrs: dict):
super().__init__(graph, {
'op': self.op,
'type': None,
'axis': int64_array(1),
'order': int64_array([2, 1, 0]),
'infer': self.infer,
'in_ports_count': 1,
'out_ports_count': 1,
}, attrs)
@staticmethod
def infer(node):
input_shape = node.in_port(0).data.get_shape()
assert input_shape is not None
node.out_port(0).data.set_shape(input_shape)
PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
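In plain numpy terms, the internal `ReverseChannels` op with its default attributes permutes the channel axis (`axis=1` in NCHW) with `order=[2, 1, 0]`, i.e. swaps RGB and BGR. A small illustrative sketch:

```python
import numpy as np

# Illustrative only: what ReverseChannels(axis=1, order=[2, 1, 0])
# computes on an NCHW tensor - a permutation of the channel axis.
data = np.arange(12).reshape(1, 3, 2, 2)                 # N=1, C=3, H=W=2
reversed_data = np.take(data, indices=[2, 1, 0], axis=1)
print(np.array_equal(reversed_data[0, 0], data[0, 2]))   # True
```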
class InsertReverseChannels(BackReplacementPattern):
"""
    Searches for all suitable nodes with type=Parameter and inserts an internal ReverseChannels op right after them
    TODO: we should provide the user with the ability to explicitly specify nodes for input channel reversing
"""
enabled = False
def find_and_replace_pattern(self, graph: Graph):
all_params = [(p.soft_get('name', p.id), p, list(p.out_port(0).data.get_shape()))
for p in graph.get_op_nodes(type='Parameter')]
suitable_params = [(name, p, shape) for name, p, shape in all_params if len(shape) == 4 and shape[1] == 3]
log.debug('All network inputs: {}'.format({name: shape for name, _, shape in all_params}))
log.debug('Will reverse input channels for: {}'.format({name: shape for name, _, shape in suitable_params}))
if len(suitable_params) < len(all_params):
log.error('Network has {} inputs overall, but only {} of them are suitable for input channels reversing.\n'
                      'Inputs suitable for input channel reversing are 4-dimensional with 3 channels.\nAll inputs: {}\n'
'Suitable inputs {}'.format(len(all_params), len(suitable_params),
{name: shape for name, _, shape in all_params},
{name: shape for name, _, shape in suitable_params}),
extra={'is_warning': True})
for name, parameter, _ in suitable_params:
reverse_channels = ReverseChannels(graph, {'name': name + '/reverse_input_channels'}).create_node()
parameter.out_port(0).get_connection().set_source(reverse_channels.out_port(0),
attributes_save_mode='source')
parameter.out_port(0).connect(reverse_channels.in_port(0))
class ReverseChannelsPropagationDown(BackReplacementPattern):
"""
Propagates ReverseChannels operations down through nodes that we have rules for
"""
enabled = False
propagation_rules = {
'Convolution': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_conv(node, rc),
'ScaleShift': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_eltwise(node, rc),
'Power': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_eltwise(node, rc),
'BatchNormalization': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_eltwise(node, rc),
'FakeQuantize': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_eltwise(node, rc),
'Multiply': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_eltwise(node, rc),
'Divide': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_eltwise(node, rc),
'Add': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_eltwise(node, rc),
'Subtract': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_eltwise(node, rc),
'Pow': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_eltwise(node, rc),
'Convert': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_eltwise(node, rc),
'Shape': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_shape(node, rc),
'ShapeOf': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_shape(node, rc),
'Pad': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through(node, rc),
}
@staticmethod
def pass_rc_through(node: Node, reverse_channels: Node):
r"""
BEFORE AFTER
previous_op
|
ReverseChannels previous_op previous_op previous_op
\ / \ /
Node Node
|
ReverseChannels
        Returns a boolean value indicating whether we should continue propagating the current ReverseChannels operation down
"""
# detaching reverse_channels node from the graph
if reverse_channels.is_in_port_connected(0) and reverse_channels.is_out_port_connected(0)\
and node.is_out_port_connected(0):
reverse_channels.out_port(0).get_connection().set_source(
reverse_channels.in_port(0).get_connection().get_source())
reverse_channels.in_port(0).disconnect()
node.out_port(0).get_connection().set_source(reverse_channels.out_port(0))
node.out_port(0).disconnect()
node.out_port(0).connect(reverse_channels.in_port(0))
return True
return False
@staticmethod
def pass_rc_through_conv(node, reverse_channels):
r"""
For non grouped convolution:
BEFORE AFTER
previous_op weights
| |
ReverseChannels weights previous_op ReverseChannels
\ / \ /
Conv Conv
For grouped convolution:
BEFORE AFTER
previous_op weights
| |
ReverseChannels weights previous_op ReverseChannels
\ / \ /
Conv Conv
|
ReverseChannels
        Returns a boolean value indicating whether we should continue propagating the current ReverseChannels operation down
"""
channel_idx = node.soft_get("input_feature_channel", None)
if channel_idx is None:
# unknown Convolution configuration, won't propagate reverse_channels down the network
return False
weights_shape = node.in_port(1).data.get_shape()
if weights_shape is None or weights_shape[channel_idx] != reverse_channels.order.size:
# unexpected Convolution configuration, won't propagate reverse_channels down the network
return False
# detaching reverse_channels node from the graph
reverse_channels.out_port(0).get_connection().set_source(
reverse_channels.in_port(0).get_connection().get_source())
reverse_channels.in_port(0).disconnect()
group = node.soft_get('group', 1)
# insert ReverseChannels on weights port of Convolution
ric_to_move_to_weights = reverse_channels if group == 1 else reverse_channels.copy_node()
ric_to_move_to_weights['axis'] = np.array(channel_idx)
src = node.in_port(1).get_connection().get_source()
node.in_port(1).get_connection().set_source(ric_to_move_to_weights.out_port(0))
src.disconnect()
src.connect(ric_to_move_to_weights.in_port(0))
if group != 1 and group == reverse_channels.order.size:
            # for grouped Convolution, reversing the weights channels alone is not enough to complete the procedure,
            # so we propagate the ReverseChannels op through the current Convolution with a new order value for channel permutation
bottom_channels = node.out_port(0).data.get_shape()[node.channel_dims[0]]
assert bottom_channels % group == 0
multiplier = int(bottom_channels / group)
new_order = np.take(np.arange(bottom_channels).reshape((group, multiplier)),
indices=reverse_channels.order, axis=0).flatten()
reverse_channels['axis'] = np.array(reverse_channels.axis.copy())
reverse_channels['order'] = np.array(new_order)
node.out_port(0).get_connection().set_source(reverse_channels.out_port(0))
node.out_port(0).disconnect()
node.out_port(0).connect(reverse_channels.in_port(0))
# as described above, we are not done reversing channels yet, so we should continue propagating
# ReverseChannels operation down the network
return True
# we reversed channels for sure, nothing to propagate down the network
return False
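The grouped-Convolution branch above builds a `new_order` so that the `ReverseChannels` kept below the Convolution reverses whole groups of output channels. A worked example with illustrative numbers (group of 3, two output channels per group):

```python
import numpy as np

# Worked example of the new_order expansion used for grouped Convolution
# when group == order.size: illustrative numbers, same np.take formula.
group = 3
order = np.array([2, 1, 0])
bottom_channels = 6                      # total output channels
multiplier = bottom_channels // group    # output channels per group
new_order = np.take(np.arange(bottom_channels).reshape((group, multiplier)),
                    indices=order, axis=0).flatten()
print(new_order)  # [4 5 2 3 0 1]
```

Rows of `arange(6).reshape(3, 2)` are the per-group channel blocks; taking them in reversed group order and flattening yields the per-channel permutation.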
@staticmethod
def pass_rc_through_eltwise(node, reverse_channels):
r"""
BEFORE AFTER
previous_op previous_op'
| |
ReverseChannels previous_op' previous_op ReverseChannels
\ / \ /
Eltwise Eltwise
|
ReverseChannels
        Returns a boolean value indicating whether we should continue propagating the current ReverseChannels operation down
"""
before_shape = reverse_channels.out_port(0).data.get_shape()
port_axis = []
for idx, port in node.in_ports().items():
if port.get_connection().get_source().node.id == reverse_channels.id:
continue
shape = port.data.get_shape()
non_one_dims = np.where(shape != 1)[0]
if shape[reverse_channels.axis] == 1:
continue # nothing to flip for this input
if len(non_one_dims) == 1 and shape[non_one_dims.item()] == reverse_channels.order.size:
new_axis = non_one_dims.item()
elif np.array_equal(before_shape, shape):
new_axis = reverse_channels.axis
else:
# shape has multiple non-one values and shape is not fully broadcasted to value port shape
# it is safe not to propagate reverse channels
return False
port_axis.append((port, new_axis))
# reversing eltwise inputs where applicable
for port, axis in port_axis:
ric_copy = reverse_channels.copy_node({'axis': np.array(axis), 'order': np.array(reverse_channels.order)})
src = port.get_connection().get_source()
port.get_connection().set_source(ric_copy.out_port(0))
src.disconnect()
src.connect(ric_copy.in_port(0))
# detaching reverse_channels node from the graph
reverse_channels.out_port(0).get_connection().set_source(
reverse_channels.in_port(0).get_connection().get_source())
reverse_channels.in_port(0).disconnect()
# propagating reverse_channels node to the output port of eltwise
node.out_port(0).get_connection().set_source(reverse_channels.out_port(0))
node.out_port(0).disconnect()
node.out_port(0).connect(reverse_channels.in_port(0))
# propagated reverse_channels successfully through current node, will continue propagation
return True
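The per-input axis selection inside `pass_rc_through_eltwise` can be distilled into a small standalone function. This sketch is illustrative (the helper name and return convention are mine, not part of the MO code): it returns the axis to flip, `None` when the input has nothing to flip, or `False` when the broadcast is ambiguous and propagation must stop.

```python
import numpy as np

# Distilled sketch of the eltwise axis selection above (hypothetical helper).
def pick_flip_axis(shape, before_shape, axis, order_size):
    shape = np.array(shape)
    non_one_dims = np.where(shape != 1)[0]
    if shape[axis] == 1:
        return None                       # nothing to flip for this input
    if len(non_one_dims) == 1 and shape[non_one_dims.item()] == order_size:
        return non_one_dims.item()        # broadcast vector, flip its own axis
    if np.array_equal(before_shape, shape):
        return axis                       # full-shape input, keep the axis
    return False                          # partial broadcast, unsafe to propagate

print(pick_flip_axis([1, 3, 1, 1], [1, 3, 224, 224], axis=1, order_size=3))      # 1
print(pick_flip_axis([1, 1, 224, 224], [1, 3, 224, 224], axis=1, order_size=3))  # None
```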
@staticmethod
def pass_rc_through_shape(node, reverse_channels):
"""
        Stops propagation of ReverseChannels (RIC) through shape-taking operations, since ReverseChannels does not change the shape
"""
reverse_channels.out_port(0).get_connection().set_source(
reverse_channels.in_port(0).get_connection().get_source())
return False
@staticmethod
def get_non_shape_taking_dst(dsts):
return [dst for dst in dsts if dst.node.soft_get('type') not in ['Shape', 'ShapeOf']]
def check_if_we_propagate_down(self, reverse_channels):
dsts = self.get_non_shape_taking_dst(reverse_channels.out_port(0).get_destinations())
return len(dsts) == 1 and dsts[0].node.soft_get('type') in self.propagation_rules
def find_and_replace_pattern(self, graph: Graph):
for reverse_channels in graph.get_op_nodes(op='ReverseChannels'):
keep_moving_down = True
while keep_moving_down and self.check_if_we_propagate_down(reverse_channels):
next_node = self.get_non_shape_taking_dst(reverse_channels.out_port(0).get_destinations())[0].node
keep_moving_down = self.propagation_rules[next_node.type](next_node, reverse_channels)
class ReverseChannelsPropagationUp(BackReplacementPattern):
"""
Propagates ReverseChannels operations up through nodes that we have rules for
"""
enabled = False
propagation_rules = {
'ScaleShift': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
'Power': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
'BatchNormalization': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
'FakeQuantize': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
'Multiply': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
'Divide': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
'Add': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
'Subtract': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
'Pow': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
'Convert': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
'Pad': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through(node, rc),
}
@staticmethod
def lift_up_through(node: Node, reverse_channels: Node):
r"""
BEFORE AFTER
previous_op
\
previous_op previous_op ReverseChannels previous_op
\ / \ /
Node Node
| |
ReverseChannels next_op
|
next_op
        Returns a boolean value indicating whether we should continue propagating the current ReverseChannels operation up
"""
if node.is_in_port_connected(0):
node_input_port_0 = node.in_port(0)
            reverse_channels_out_node = reverse_channels.out_port(0).get_connection().get_destination().node
            reverse_channels.out_port(0).disconnect()
            src = node_input_port_0.get_connection().get_source()
            node_input_port_0.get_connection().set_source(reverse_channels.out_port(0))
            src.connect(reverse_channels.in_port(0))
            node.out_port(0).get_connection().set_destination(reverse_channels_out_node.in_port(0))
return True
return False
@staticmethod
def lift_up_through_eltwise(node: Node, reverse_channels: Node):
r"""
BEFORE AFTER
previous_op previous_op'
\ /
previous_op previous_op' ReverseChannels ReverseChannels
\ / \ /
Eltwise Eltwise
| |
ReverseChannels next_op
|
next_op
        Returns two objects:
        first - a boolean value indicating whether we should continue propagating the current ReverseChannels operation up
        second - a list of new ReverseChannels operations that were produced while propagating reverse_channels up
"""
before_shape = reverse_channels.in_port(0).data.get_shape()
port_axis = []
for idx, port in node.in_ports().items():
shape = port.data.get_shape()
non_one_dims = np.where(shape != 1)[0]
if shape[reverse_channels.axis] == 1:
continue # nothing to flip for this input
if len(non_one_dims) == 1 and shape[non_one_dims.item()] == reverse_channels.order.size:
axis = non_one_dims.item()
elif np.array_equal(before_shape, shape):
axis = reverse_channels.axis
else:
# shape has multiple non-one values and shape is not fully broadcasted to value port shape
# it is safe not to propagate reverse channels
return False, []
port_axis.append((port, axis))
copies = []
for port, axis in port_axis:
reverse_channels_copy = reverse_channels.copy_node({'axis': np.array(axis)})
src = port.get_connection().get_source()
if src.node.soft_get('type') == 'Parameter':
# For Parameter nodes tensor debug attributes should not move to the last node
# of subgraph. It is needed for the proper mapping of input framework name.
# For this reason "source" mode is used to keep tensor debug attributes at Parameter node.
port.get_connection().set_source(reverse_channels_copy.out_port(0), attributes_save_mode="source")
else:
port.get_connection().set_source(reverse_channels_copy.out_port(0))
src.connect(reverse_channels_copy.in_port(0))
copies.append(reverse_channels_copy)
reverse_channels.out_port(0).get_connection().set_source(
reverse_channels.in_port(0).get_connection().get_source())
reverse_channels.in_port(0).disconnect()
# propagated reverse_channels successfully through current node, will continue propagation
return True, copies
def find_and_replace_pattern(self, graph: Graph):
reverse_channels = set(graph.get_op_nodes(op='ReverseChannels'))
while len(reverse_channels):
keep_moving_up = True
while keep_moving_up:
curr_reverse_channels = reverse_channels.pop()
if curr_reverse_channels.in_port(0).get_source().node.soft_get('type') not in self.propagation_rules:
break
next_op = curr_reverse_channels.in_port(0).get_source().node
keep_moving_up, new_reverses = self.propagation_rules[next_op.type](next_op, curr_reverse_channels)
reverse_channels.update(new_reverses)
class DecomposeReverseChannels(BackReplacementPattern):
"""
Replaces each internal ReverseChannels operation in graph with publicly supported Gather operation
"""
enabled = False
@staticmethod
def replace_with_gather(node):
graph = node.graph
name = node.soft_get('name', node.id)
axis = node.axis
order = node.order
gather = create_op_with_const_inputs(graph, Gather, {1: order, 2: int64_array(axis)}, {'name': name})
node.out_port(0).get_connection().set_source(gather.out_port(0))
node.in_port(0).get_connection().set_destination(gather.in_port(0))
@staticmethod
def replace_with_split_concat(node):
graph = node.graph
name = node.soft_get('name', node.id)
axis = node.axis
order = node.order
split = create_op_with_const_inputs(graph, Split, {1: int64_array(axis)},
{'name': name + '/Split', 'num_splits': order.size})
concat = Concat(graph, {'name': name + '/Concat', 'axis': axis, 'in_ports_count': order.size}).create_node()
for out_port_idx, in_port_idx in enumerate(order):
split.out_port(out_port_idx).connect(concat.in_port(in_port_idx))
node.out_port(0).get_connection().set_source(concat.out_port(0))
node.in_port(0).get_connection().set_destination(split.in_port(0))
graph.remove_node(node.id)
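The Split/Concat decomposition above wires out_port `i` of the Split to in_port `order[i]` of the Concat. A numpy sketch of why that wiring reproduces the channel permutation (illustrative only, not the graph API):

```python
import numpy as np

# Numpy model of replace_with_split_concat: split the channel axis into
# order.size parts and concatenate them with out_port i -> in_port order[i].
axis = 1
order = np.array([2, 1, 0])
data = np.arange(12).reshape(1, 3, 2, 2)
parts = np.split(data, order.size, axis=axis)
concat_inputs = [None] * order.size
for out_port_idx, in_port_idx in enumerate(order):
    concat_inputs[in_port_idx] = parts[out_port_idx]
result = np.concatenate(concat_inputs, axis=axis)
# Equivalent to the Gather form: np.take(data, order, axis=axis)
print(np.array_equal(result, np.take(data, order, axis=axis)))  # True
```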
def find_and_replace_pattern(self, graph: Graph):
for reverse_channels in graph.get_op_nodes(op='ReverseChannels'):
if reverse_channels.in_port(0).disconnected() or reverse_channels.out_port(0).disconnected():
# graph.clean_up will delete it
reverse_channels['need_shape_inference'] = False
continue
self.replace_with_split_concat(reverse_channels)
class ApplyReverseChannels(BackReplacementPattern):
"""
    Reverses input channels for suitable Parameter operations if requested by the user
    Optimizes channel reversing by fusing it into Convolution weights where applicable
"""
enabled = True
run_not_recursively = True
force_clean_up = True
def find_and_replace_pattern(self, graph: Graph):
"""
        The following transformations must run in strict order; that is why they are all disabled and invoked from here
"""
if graph.graph['cmd_params'].reverse_input_channels:
InsertReverseChannels().find_and_replace_pattern(graph)
ReverseChannelsPropagationDown().find_and_replace_pattern(graph)
ReverseChannelsPropagationUp().find_and_replace_pattern(graph)
DecomposeReverseChannels().find_and_replace_pattern(graph)