* SplitConcatPairToInterpolate transformation was moved to middle stage and is applied only for 4D and 5D inputs.
171 lines
7.5 KiB
Python
171 lines
7.5 KiB
Python
"""
|
|
Copyright (c) 2020 Intel Corporation
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""
|
|
|
|
import logging as log
|
|
from typing import Optional
|
|
|
|
from extensions.ops.elementwise import Mul
|
|
from extensions.ops.interpolate import Interpolate
|
|
from mo.front.common.partial_infer.utils import int64_array
|
|
from mo.front.tf.graph_utils import create_op_with_const_inputs
|
|
from mo.graph.graph import Graph, Node
|
|
from mo.middle.replacement import MiddleReplacementPattern
|
|
from mo.ops.const import Const
|
|
from mo.ops.shape import Shape
|
|
from mo.ops.strided_slice import StridedSlice
|
|
|
|
|
|
def get_concat_after_split(split: Node) -> Optional[Node]:
|
|
# If number of output nodes of 'split' is not equal to 1, then the transformation is not applicable.
|
|
split_outputs = [d.node for _, p in split.out_ports().items() for d in p.get_connection().get_destinations()]
|
|
names_of_split_outputs = set([n.name for n in split_outputs])
|
|
if len(names_of_split_outputs) != 1:
|
|
return
|
|
|
|
groups_of_inputs = [[d.idx for d in p.get_connection().get_destinations()] for _, p in split.out_ports().items()]
|
|
sizes_of_groups = set([len(g) for g in groups_of_inputs])
|
|
# If numbers of consumer ports are various for various output ports of 'split', then the transformation
|
|
# is not applicable.
|
|
if len(sizes_of_groups) != 1:
|
|
return
|
|
# The transformation is applicable iff output port 0 of 'split' goes to ports [0, ..., m-1] of next node,
|
|
# output port 1 of 'split' goes to ports [m, ..., m + (m-1)] of next node, ..., output port i of 'split'
|
|
# goes to ports [i * m, ..., i * m + (m - 1)], and so on.
|
|
flatten_groups = [i for g in groups_of_inputs for i in g]
|
|
if flatten_groups != list(range(0, len(flatten_groups))):
|
|
return
|
|
|
|
dest = split.out_port(0).get_destinations()[0].node
|
|
# The transformation is applicable, only if next node is Concat.
|
|
return dest if dest.soft_get('type') == 'Concat' else None
|
|
|
|
|
|
def get_interpolate_pattern(split: Node) -> dict:
|
|
split_shape = split.in_port(0).data.get_shape()
|
|
if len(split_shape) not in {4, 5}:
|
|
return {}
|
|
concat = get_concat_after_split(split)
|
|
if concat is None:
|
|
return {}
|
|
return {'split': split, 'concat': concat}
|
|
|
|
|
|
def get_split_scale(split: Node) -> int:
|
|
split_dests = [d.node for _, p in split.out_ports().items() for d in p.get_connection().get_destinations()]
|
|
num_of_split_dests = len(split_dests)
|
|
num_of_split_out_ports = len(split.out_ports())
|
|
fractional_part = num_of_split_dests / num_of_split_out_ports - num_of_split_dests // num_of_split_out_ports
|
|
assert fractional_part == 0, "Number of output ports of Split must be multiple of number of inputs of Concat"
|
|
return len(split_dests) // len(split.out_ports())
|
|
|
|
|
|
def replace_interpolate_pattern(graph: Graph, match: dict):
|
|
split = match['split']
|
|
scale = int64_array([get_split_scale(split)])
|
|
axis = int(split.in_port(1).get_connection().get_source().node.value)
|
|
split_node_name = split.name
|
|
|
|
shape_node = Shape(graph, dict(name=split_node_name + '/Shape_')).create_node()
|
|
scales_node = Const(graph, dict(name=split_node_name + '/scales_', value=scale)).create_node()
|
|
mul_node = Mul(graph, dict(name=split_node_name + '/Mul_')).create_node()
|
|
scales_node.out_port(0).connect(mul_node.in_port(1))
|
|
|
|
strided_slice_node = create_op_with_const_inputs(graph,
|
|
StridedSlice,
|
|
{1: int64_array([axis]), 2: int64_array([axis + 1])},
|
|
{
|
|
'name': split_node_name + '/StridedSlice_',
|
|
'begin_mask': int64_array([1]),
|
|
'end_mask': int64_array([1]),
|
|
'new_axis_mask': int64_array([0]),
|
|
'shrink_axis_mask': int64_array([0]),
|
|
'ellipsis_mask': int64_array([0])
|
|
})
|
|
shape_node.out_port(0).connect(strided_slice_node.in_port(0))
|
|
|
|
strided_slice_node.out_port(0).connect(mul_node.in_port(0))
|
|
|
|
interp_node = Interpolate(graph, dict(name=split_node_name + '/Interpolate_',
|
|
axes=int64_array([axis]),
|
|
mode='nearest')).create_node()
|
|
mul_node.out_port(0).connect(interp_node.in_port(1))
|
|
|
|
match['concat'].out_port(0).get_connection().set_source(interp_node.out_port(0))
|
|
|
|
split_connection = split.in_port(0).get_connection()
|
|
split_connection.set_destination(interp_node.in_port(0))
|
|
split_connection.get_source().connect(shape_node.in_port(0))
|
|
|
|
|
|
class SplitConcatPairToInterpolate(MiddleReplacementPattern):
|
|
"""
|
|
This transformation looks for Interpolation layer implemented using simple operations, i.e. Split and Concat,
|
|
and replaces found pattern with a sequence of Shape, StridedSlice, Const, Mul, Interpolate.
|
|
|
|
Found pattern:
|
|
nodes=[
|
|
('split', dict(kind='op', op='Split')),
|
|
('concat', dict(kind='op', op='Concat')),
|
|
],
|
|
edges=[
|
|
('split', 'concat'),
|
|
]
|
|
|
|
Here we assume that
|
|
1) 'split' is in NDHWC layout and is a 5D-tensor;
|
|
2) split dimensions for 'split' belongs to {1, 2, 3};
|
|
3) all outputs of 'split' go to only inputs of 'concat';
|
|
4) 'concat' takes inputs only from 'split';
|
|
5) split_dim of 'split' is equal to axis of 'concat'.
|
|
|
|
Found pattern will be replaced with
|
|
nodes=[
|
|
('shape', dict(kind='op', op='Shape')),
|
|
('strided_slice', dict(kind='op', op='StridedSlice')),
|
|
('scales', dict(kind='op', op='Const')),
|
|
('scaled_shape', dict(kind='op', op='Mul')),
|
|
('interp', dict(kind='op', op='Interpolate'))
|
|
],
|
|
edges=[
|
|
('shape', 'strided_slice', {'in': 0}),
|
|
('strided_slice', 'scaled_shape', {'in': 0}),
|
|
('scales', 'scaled_shape', {'in': 1}),
|
|
('scaled_shape', 'interp', {'in': 1}),
|
|
]
|
|
|
|
Here scaling factor in Interpolate is equal to a quotient of dividing number of input ports of 'concat'
|
|
by number of output ports of 'split'.
|
|
"""
|
|
enabled = True
|
|
force_clean_up = True
|
|
|
|
def run_before(self):
|
|
from extensions.middle.InterpolateSequenceToInterpolate import InterpolateSequenceToInterpolate
|
|
return [InterpolateSequenceToInterpolate]
|
|
|
|
def find_and_replace_pattern(self, graph: Graph):
|
|
log.debug('Enabled replacement of a pair of Split and Concat with Interpolate.')
|
|
splits = graph.get_op_nodes(op='Split')
|
|
patterns = []
|
|
|
|
for split_node in splits:
|
|
interpolate_pattern = get_interpolate_pattern(split_node)
|
|
if interpolate_pattern:
|
|
patterns.append(interpolate_pattern)
|
|
|
|
for pattern in patterns:
|
|
replace_interpolate_pattern(graph, pattern)
|