Files
openvino/model-optimizer/extensions/middle/SplitConcatPairToInterpolate.py
Vladimir Gavrilov 33aca7d2c4 SplitConcatPairToInterpolate inserts Interpolate when input is 2D (#596)
* SplitConcatPairToInterpolate transformation was moved to middle stage and is applied only for 4D and 5D inputs.
2020-05-28 18:08:24 +03:00

171 lines
7.5 KiB
Python

"""
Copyright (c) 2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging as log
from typing import Optional
from extensions.ops.elementwise import Mul
from extensions.ops.interpolate import Interpolate
from mo.front.common.partial_infer.utils import int64_array
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph, Node
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.const import Const
from mo.ops.shape import Shape
from mo.ops.strided_slice import StridedSlice
def get_concat_after_split(split: Node) -> Optional[Node]:
    """
    Find the Concat node that consumes all the outputs of 'split' in the expected order.

    The consumer is returned only if all of the following holds:
    1) all the outputs of 'split' go to a single node (consumers are told apart by their names);
    2) every output port of 'split' feeds the same number 'm' of input ports of that node;
    3) walking through the output ports of 'split' in the order given by split.out_ports(), the fed input
       ports of the consumer are exactly 0, 1, 2, ..., i.e. the i-th output port goes to the ports
       [i * m, ..., i * m + (m - 1)];
    4) attribute 'type' of the consumer is 'Concat'.

    :param split: node whose consumers are inspected.
    :return: the consumer node if all the conditions are fulfilled, None otherwise.
    """
    # Destination ports of the consumers, grouped by output port of 'split'.
    dest_ports_per_output = [port.get_connection().get_destinations() for port in split.out_ports().values()]

    # Exactly one consumer is allowed.
    consumer_names = {dest_port.node.name for dest_ports in dest_ports_per_output for dest_port in dest_ports}
    if len(consumer_names) != 1:
        return None

    # Each output port of 'split' must be connected with the same number of ports of the consumer.
    if len({len(dest_ports) for dest_ports in dest_ports_per_output}) != 1:
        return None

    # Being read group by group, indices of the input ports of the consumer must form the sequence 0, 1, 2, ...
    consumer_port_indices = [dest_port.idx for dest_ports in dest_ports_per_output for dest_port in dest_ports]
    if consumer_port_indices != list(range(len(consumer_port_indices))):
        return None

    consumer = split.out_port(0).get_destinations()[0].node
    # Only Concat is a suitable consumer.
    if consumer.soft_get('type') == 'Concat':
        return consumer
    return None
def get_interpolate_pattern(split: Node) -> dict:
    """
    Build a match dictionary for 'split' if it starts a Split -> Concat pair to be replaced.

    The pair is accepted when the data on input port 0 of 'split' is 4D or 5D and get_concat_after_split()
    finds a suitable Concat.

    NOTE(review): the split axis and the axis of the found Concat are not compared in this function.

    :param split: candidate node.
    :return: {'split': split, 'concat': concat} for an accepted pair, an empty dictionary otherwise.
    """
    input_rank = len(split.in_port(0).data.get_shape())
    if input_rank in {4, 5}:
        concat = get_concat_after_split(split)
        if concat is not None:
            return {'split': split, 'concat': concat}
    return {}
def get_split_scale(split: Node) -> int:
    """
    Calculate the scaling factor encoded by the connections between 'split' and its consumer.

    The factor is the total number of consumer input ports fed by 'split' divided by the number of output
    ports of 'split', i.e. how many times each output of 'split' is repeated in the consumer.

    :param split: node whose output ports and their destinations are counted.
    :return: the integer scaling factor.
    :raises AssertionError: if the number of fed input ports is not a multiple of the number of output ports.
    """
    out_ports = split.out_ports()
    num_of_split_out_ports = len(out_ports)
    num_of_split_dests = sum(len(port.get_connection().get_destinations()) for port in out_ports.values())
    # divmod keeps the divisibility check in exact integer arithmetic (no float division is involved).
    scale, remainder = divmod(num_of_split_dests, num_of_split_out_ports)
    assert remainder == 0, "Number of inputs of Concat must be a multiple of the number of output ports of Split"
    return scale
def replace_interpolate_pattern(graph: Graph, match: dict):
    """
    Replace a matched Split -> Concat pair with Shape -> StridedSlice -> Mul -> Interpolate.

    The size of the input along the split axis is extracted with Shape and StridedSlice, multiplied by the
    scale returned by get_split_scale() and passed to input port 1 of an Interpolate working in the 'nearest'
    mode. The Interpolate takes over the data input of 'split' and the consumers of the output of 'concat'.

    :param graph: graph to modify in place.
    :param match: dictionary with the keys 'split' and 'concat' (see get_interpolate_pattern()).
    """
    split = match['split']
    scale = int64_array([get_split_scale(split)])
    # The axis is taken from attribute 'value' of the node feeding input port 1 of 'split'
    # (presumably a Const node - the code does not check it).
    axis = int(split.in_port(1).get_connection().get_source().node.value)
    if axis < 0:
        # A negative axis counts from the end. Used as is, axis == -1 would give the slice [-1:0] of the shape,
        # which is empty, and a negative value would be written into 'axes' of Interpolate. So the axis is
        # converted to a non-negative one when the rank of the input is known.
        input_shape = split.in_port(0).data.get_shape()
        if input_shape is not None:
            axis += len(input_shape)
    split_node_name = split.name

    shape_node = Shape(graph, dict(name=split_node_name + '/Shape_')).create_node()
    scales_node = Const(graph, dict(name=split_node_name + '/scales_', value=scale)).create_node()
    mul_node = Mul(graph, dict(name=split_node_name + '/Mul_')).create_node()
    scales_node.out_port(0).connect(mul_node.in_port(1))

    # StridedSlice cuts the element [axis : axis + 1] out of the shape, i.e. the size of the scaled dimension.
    strided_slice_node = create_op_with_const_inputs(graph,
                                                     StridedSlice,
                                                     {1: int64_array([axis]), 2: int64_array([axis + 1])},
                                                     {
                                                         'name': split_node_name + '/StridedSlice_',
                                                         'begin_mask': int64_array([1]),
                                                         'end_mask': int64_array([1]),
                                                         'new_axis_mask': int64_array([0]),
                                                         'shrink_axis_mask': int64_array([0]),
                                                         'ellipsis_mask': int64_array([0])
                                                     })
    shape_node.out_port(0).connect(strided_slice_node.in_port(0))
    strided_slice_node.out_port(0).connect(mul_node.in_port(0))

    interp_node = Interpolate(graph, dict(name=split_node_name + '/Interpolate_',
                                          axes=int64_array([axis]),
                                          mode='nearest')).create_node()
    # The scaled size goes to input port 1 of Interpolate.
    mul_node.out_port(0).connect(interp_node.in_port(1))

    # Consumers of 'concat' are re-attached to the output of Interpolate.
    match['concat'].out_port(0).get_connection().set_source(interp_node.out_port(0))

    # The former data input of 'split' now feeds both Interpolate (port 0) and Shape.
    split_connection = split.in_port(0).get_connection()
    split_connection.set_destination(interp_node.in_port(0))
    split_connection.get_source().connect(shape_node.in_port(0))
class SplitConcatPairToInterpolate(MiddleReplacementPattern):
    """
    This transformation looks for an Interpolate layer expressed by means of simple operations, namely by a pair
    of Split and Concat, and replaces the pair with the sequence Shape, StridedSlice, Const, Mul, Interpolate.

    The searched pattern is
        nodes=[
            ('split', dict(kind='op', op='Split')),
            ('concat', dict(kind='op', op='Concat')),
        ],
        edges=[
            ('split', 'concat'),
        ]

    The pair is processed only if
        1) the input of 'split' is a 4D or 5D tensor;
        2) all outputs of 'split' go only to inputs of 'concat';
        3) 'concat' takes inputs only from 'split', in the order of the output ports of 'split'.
    Besides, it is assumed (these assumptions are not verified by the code of this file) that
        4) the input of 'split' is in the NHWC or NDHWC layout and the split dimension is a spatial one;
        5) split_dim of 'split' is equal to the axis of 'concat'.

    The found pair is replaced with
        nodes=[
            ('shape', dict(kind='op', op='Shape')),
            ('strided_slice', dict(kind='op', op='StridedSlice')),
            ('scales', dict(kind='op', op='Const')),
            ('scaled_shape', dict(kind='op', op='Mul')),
            ('interp', dict(kind='op', op='Interpolate'))
        ],
        edges=[
            ('shape', 'strided_slice', {'in': 0}),
            ('strided_slice', 'scaled_shape', {'in': 0}),
            ('scales', 'scaled_shape', {'in': 1}),
            ('scaled_shape', 'interp', {'in': 1}),
        ]

    The scaling factor stored in 'scales' is the quotient of the number of input ports of 'concat' and
    the number of output ports of 'split'.
    """
    enabled = True
    force_clean_up = True

    def run_before(self):
        # The import is local to the method, as in the other transformations of this kind
        # (presumably to avoid circular imports between the transformation modules).
        from extensions.middle.InterpolateSequenceToInterpolate import InterpolateSequenceToInterpolate
        return [InterpolateSequenceToInterpolate]

    def find_and_replace_pattern(self, graph: Graph):
        log.debug('Enabled replacement of a pair of Split and Concat with Interpolate.')
        # All the pairs are collected before the first replacement, so the search never sees
        # a partially transformed graph.
        candidates = (get_interpolate_pattern(split_node) for split_node in graph.get_op_nodes(op='Split'))
        found_pairs = [pair for pair in candidates if pair]
        for pair in found_pairs:
            replace_interpolate_pattern(graph, pair)