[ MO ] Hard-coded Interpolate followed by concat reshape-ability fixing (#818)

This commit is contained in:
Evgenya Stepyreva
2020-06-23 08:27:27 +03:00
committed by GitHub
parent 3490b985dd
commit f40338ff4b
4 changed files with 321 additions and 2 deletions

View File

@@ -128,6 +128,7 @@ extensions/front/global_pooling_to_reduce.py
extensions/front/image_scaler.py
extensions/front/input_cut.py
extensions/front/instance_normalization.py
extensions/front/interpolate_reshape.py
extensions/front/InterpolateNormalizer.py
extensions/front/kaldi/__init__.py
extensions/front/kaldi/add_permute_after_convolution.py

View File

@@ -0,0 +1,181 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
from extensions.ops.gather import Gather
from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.replacement import FrontReplacementPattern
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph, Node
from mo.ops.shape import Shape
class InterpolateWithConcat(FrontReplacementPattern):
"""
Replaces hard-coded 1-port input of Interpolate with reshape-able sub-graph using the following Concat inputs
BEFORE:
input Const
shape=[1, 3, 30, 40] value=[60, 160]
\ /
Interpolate(axes=(2, 3)) input_1
shape=[1, 3, 60, 160] shape=[1, 4, 60, 160]
\ /
Concat(axis=1)
shape=[1, 7, 60, 160]
AFTER:
input
shape=[1, 3, 30, 40] input_1
| shape=[1, 4, 60, 160]
| / |
| ShapeOf |
| | |
| Gather |
| indices=(2, 3); axis=0 |
\ | |
Interpolate(axes=(2, 3)) |
shape=[1, 3, 60, 160] |
\ /
Concat(axis=1)
shape=[1, 7, 60, 160]
1. Searches for Interpolate operation which output is connected to Concat (through identity operation or directly).
Interpolate -- [identity] --> Concat
2. Checks that Interpolate has positive axes parameter
3. Checks that Concat has positive axis (from attribute and N-input)
4. Checks that interpolation takes place over different dimensions than concatenation
5. Searches for Concat sources that are not connected to Interpolate operations
and checks that we have at least one such source (we could create a loop if we won't check)
6. If any of this checks are failed -- transformation doesn't do anything
7. Otherwise, we take the first Concat source from the (5) item.
Taking ShapeOf of this source and Gather'ing dimensions by the Interpolate::axes indices
we connect them to the second Interpolate input
This is how we get updated Interpolate second input that will fit the following Concat operation restrictions.
We perform this transformation of the FRONT phase for MO to be able to reshape this Interpolate layer too.
There is a similar transformation with less restrictions on the BACK phase.
"""
enabled = True
def run_after(self):
from extensions.front.InterpolateNormalizer import InterpolateNormalizer
return [InterpolateNormalizer]
@staticmethod
def get_concat_axis(concat: Node):
# Concat axis may be stored as an attribute and as an input (TF) and this is not resolved yet
# TODO: should be removed after Concat operation normalization
assert concat.soft_get('type') == 'Concat'
if concat.has_valid('axis'):
return concat.axis
if concat.has_valid('N'):
axis_node = concat.in_port(concat.N).get_source().node
if axis_node.has_valid('value'):
return axis_node.value.item(0)
return None
@staticmethod
def get_single_output_destination_safely(node: Node, idx: int = 0):
"""
Checks if node has exactly one used output port and this output port is only used by one consumer
If the checks passed, function returns consumer_node, otherwise None
"""
connected_out_ports = [port for port in node.out_ports().values() if not port.disconnected()]
if len(connected_out_ports) == 1 and connected_out_ports[0].idx == idx:
dsts = node.out_port(idx).get_destinations()
if len(dsts) == 1:
return dsts[0].node
return None
@staticmethod
def get_single_input_source_safely(node: Node, idx: int = 0):
"""
Checks if node has exactly one used input port
If the check passed, function returns input_node otherwise None
"""
connected_in_ports = [port for port in node.in_ports().values() if not port.disconnected()]
if len(connected_in_ports) == 1 and connected_in_ports[0].idx == idx:
return node.in_port(idx).get_source().node
return None
def get_non_interpolate_concat_sources(self, concat: Node):
"""
Traverses Concat input ports up to find which of them are not connected to Interpolate operations directly
or through identity operation sequence. Returns the list of Concat sources that satisfy the condition.
"""
assert concat.soft_get('type') == 'Concat'
sources = []
for in_port in concat.in_ports().values():
if in_port.disconnected():
continue
next_node = in_port.get_source().node
while next_node.soft_get('type') != 'Interpolate' and next_node.has_and_set('identity'):
node = self.get_single_input_source_safely(next_node)
if node is not None:
next_node = node
else:
break
if next_node.soft_get('type') != 'Interpolate':
sources.append(in_port.get_connection().get_source())
return sources
def make_interpolate_reshape_able(self, interpolate: Node, concat: Node):
assert interpolate.soft_get('type') == 'Interpolate'
assert concat.soft_get('type') == 'Concat'
interp_axes = interpolate.soft_get('axes', None)
interp_axes = interp_axes if interp_axes is None else int64_array(interp_axes)
concat_axis = self.get_concat_axis(concat)
if concat_axis is None or interp_axes is None \
or np.any(interp_axes < 0) or concat_axis < 0 \
or concat_axis in interp_axes:
# checks that interpolate axes and concat axis are valid and do not intersect
return
non_interp_concat_srcs = self.get_non_interpolate_concat_sources(concat)
if not len(non_interp_concat_srcs):
# there is no Concat input to take input from
return
graph = interpolate.graph
src = non_interp_concat_srcs[0]
shape = Shape(graph, {'name': src.node.soft_get('name', src.node.id) + '/Shape'}).create_node()
shape.in_port(0).connect(src)
gather = create_op_with_const_inputs(graph, Gather,
{1: np.array(interpolate.axes, dtype=np.int32), 2: int64_array(0)},
{'name': shape.name + '/Gathered'}, input_node=shape)
interpolate.in_port(1).get_connection().set_source(gather.out_port(0))
def find_and_replace_pattern(self, graph: Graph):
for interpolate in graph.get_op_nodes(type='Interpolate'):
if interpolate.in_port(1).get_source().node.soft_get('type') != 'Const':
continue
# Interpolate could be connected to Concat through identity operations, skipping them
next_node = self.get_single_output_destination_safely(interpolate)
while next_node.soft_get('type') != 'Concat' and next_node.has_and_set('identity'):
node = self.get_single_output_destination_safely(next_node)
if node is not None:
next_node = node
else:
break
if next_node.soft_get('type') == 'Concat':
self.make_interpolate_reshape_able(interpolate, next_node)

View File

@@ -0,0 +1,139 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from argparse import Namespace
import numpy as np
from generator import generator, generate
from extensions.front.interpolate_reshape import InterpolateWithConcat
from mo.front.common.partial_infer.utils import int64_array
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph, result, regular_op_with_shaped_data, valued_const_with_data, connect, \
connect_data
nodes = {
**regular_op_with_shaped_data('placeholder', [1, 3, 30, 40], {'type': 'Parameter'}),
**valued_const_with_data('out_shape', np.array([60, 160])),
**regular_op_with_shaped_data('interpolate', [1, 3, 60, 160],
{'type': 'Interpolate', 'axes': int64_array([2, 3])}),
**regular_op_with_shaped_data('identity_00', [1, 3, 60, 160], {'identity': True}),
**regular_op_with_shaped_data('identity_01', [1, 3, 60, 160], {'identity': True}),
**regular_op_with_shaped_data('shape', [4], {'type': 'ShapeOf'}),
**valued_const_with_data('indices', np.array([2, 3])),
**valued_const_with_data('axis', np.array(0)),
**regular_op_with_shaped_data('gather', [2], {'type': 'Gather'}),
**regular_op_with_shaped_data('placeholder_1', [1, 3, 60, 160], {'type': 'Parameter'}),
**regular_op_with_shaped_data('identity_10', [1, 3, 60, 160], {'identity': True}),
**regular_op_with_shaped_data('identity_11', [1, 3, 60, 160], {'identity': True}),
**regular_op_with_shaped_data('concat', [1, 7, 60, 160], {'type': 'Concat', 'axis': 1}),
**result(),
}
@generator
class TestInterpolateConcat(unittest.TestCase):
def test_interpolate_concat_reshape_graph_comparison(self):
graph = build_graph(nodes, [
*connect('placeholder', '0:interpolate'),
*connect('out_shape', '1:interpolate'),
*connect('interpolate', '0:concat'),
*connect('placeholder_1', '1:concat'),
*connect('concat', 'output'),
], nodes_with_edges_only=True)
InterpolateWithConcat().find_and_replace_pattern(graph)
graph.graph['cmd_params'] = Namespace(keep_shape_ops=True)
graph.clean_up()
graph_ref = build_graph(nodes, [
*connect('placeholder', '0:interpolate'),
*connect('placeholder_1', 'shape'),
*connect('shape', '0:gather'),
*connect('indices', '1:gather'),
*connect('axis', '2:gather'),
*connect('gather', '1:interpolate'),
*connect('interpolate', '0:concat'),
*connect_data('placeholder_1', '1:concat'),
*connect('concat', 'output'),
], nodes_with_edges_only=True)
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
self.assertTrue(flag, resp)
def test_interpolate_identity_concat_reshape_graph_comparison(self):
graph = build_graph(nodes, [
*connect('placeholder', '0:interpolate'),
*connect('out_shape', '1:interpolate'),
*connect('interpolate', 'identity_00'),
*connect('identity_00', 'identity_01'),
*connect('identity_01', '0:concat'),
*connect('placeholder_1', 'identity_10'),
*connect('identity_10', 'identity_11'),
*connect('identity_11', '1:concat'),
*connect('concat', 'output'),
], nodes_with_edges_only=True)
InterpolateWithConcat().find_and_replace_pattern(graph)
graph.graph['cmd_params'] = Namespace(keep_shape_ops=True)
graph.clean_up()
graph_ref = build_graph(nodes, [
*connect('placeholder', '0:interpolate'),
*connect_data('identity_11', 'shape'),
*connect('shape', '0:gather'),
*connect('indices', '1:gather'),
*connect('axis', '2:gather'),
*connect('gather', '1:interpolate'),
*connect('interpolate', 'identity_00'),
*connect('identity_00', 'identity_01'),
*connect('identity_01', '0:concat'),
*connect('placeholder_1', 'identity_10'),
*connect('identity_10', 'identity_11'),
*connect('identity_11', '1:concat'),
*connect('concat', 'output'),
], nodes_with_edges_only=True)
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
self.assertTrue(flag, resp)
@generate(*[
{'concat': {'axis': None}},
{'concat': {'axis': 2}},
{'concat': {'axis': -1}},
{'interpolate': {'axes': None}},
{'interpolate': {'axes': np.array([1])}},
{'interpolate': {'axes': np.array([2, -1])}},
])
def test_negative_axes_conditions(self, update_attrs):
graph = build_graph(nodes, [
*connect('placeholder', '0:interpolate'),
*connect('out_shape', '1:interpolate'),
*connect('interpolate', '0:concat'),
*connect('placeholder_1', '1:concat'),
*connect('concat', 'output'),
], update_attributes=update_attrs, nodes_with_edges_only=True)
InterpolateWithConcat().find_and_replace_pattern(graph)
graph_ref = build_graph(nodes, [
*connect('placeholder', '0:interpolate'),
*connect('out_shape', '1:interpolate'),
*connect('interpolate', '0:concat'),
*connect('placeholder_1', '1:concat'),
*connect('concat', 'output'),
], update_attributes=update_attrs, nodes_with_edges_only=True)
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
self.assertTrue(flag, resp)

View File

@@ -25,8 +25,6 @@ class ResizeBilinearFrontExtractor(FrontExtractorOp):
@classmethod
def extract(cls, node):
mapping_rule = {
'pads_begin': 0,
'pads_end': 0,
'align_corners': int(node.pb.attr['align_corners'].b),
'mode': 'linear',
'axes': int64_array([1, 2]),