[ MO Interpolate ] Fixing broken model reshape-ability (#619)

This commit is contained in:
Evgenya Stepyreva
2020-05-29 09:15:47 +03:00
committed by GitHub
parent 5cc8114322
commit e290b14ab1
3 changed files with 252 additions and 0 deletions

View File

@@ -0,0 +1,154 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
from extensions.ops.elementwise import Mul
from extensions.ops.gather import Gather
from mo.back.replacement import BackReplacementPattern
from mo.front.caffe.extractors.utils import get_canonical_axis_index
from mo.front.common.partial_infer.utils import int64_array
from mo.front.tf.graph_utils import create_op_with_const_inputs, create_op_node_with_second_input
from mo.graph.graph import Graph
from mo.ops.shape import Shape
class InterpolateConcat(BackReplacementPattern):
"""
Replaces hard-coded 1-port input of Interpolate with reshape-able sub-graph using the following Concat inputs
BEFORE:
input Const
shape=[1, 3, 30, 40] value=[60, 160]
\ /
Interpolate(axes=(2, 3)) input_1
shape=[1, 3, 60, 160] shape=[1, 4, 60, 160]
\ /
Concat(axis=1)
shape=[1, 7, 60, 160]
AFTER:
input
shape=[1, 3, 30, 40] input_1
| shape=[1, 4, 60, 160]
| / |
| ShapeOf |
| | |
| Gather |
| indices=(2, 3); axis=0 |
\ | |
Interpolate(axes=(2, 3)) |
shape=[1, 3, 60, 160] |
\ /
Concat(axis=1)
shape=[1, 7, 60, 160]
"""
enabled = True
graph_condition = [lambda graph: graph.graph['cmd_params'].keep_shape_ops]
force_shape_inference = True
id = 'reshape_interpolate_through_concat'
@staticmethod
def make_interpolate_reshapeable(interpolate, concat):
assert interpolate.soft_get('type') == 'Interpolate'
assert concat.soft_get('type') == 'Concat'
output_shape = interpolate.out_port(0).data.get_shape()
interp_axes = [get_canonical_axis_index(output_shape, axis) for axis in interpolate.axes]
concat_axis = get_canonical_axis_index(output_shape, concat.axis)
if concat_axis in interp_axes:
return
concat_srcs = [port.get_source() for port in concat.in_ports().values()]
non_interp_concat_srcs = [src for src in concat_srcs if src.node.soft_get('type') != 'Interpolate']
if len(non_interp_concat_srcs) == 0:
return
graph = interpolate.graph
src = non_interp_concat_srcs[0]
shape = Shape(graph, {'name': src.node.soft_get('name', src.node.id) + '/Shape'}).create_node()
shape.in_port(0).connect(src)
gather = create_op_with_const_inputs(graph, Gather, {1: np.array(interpolate.axes, dtype=np.int32), 2: int64_array(0)},
{'name': shape.name + '/Gathered'}, shape)
interpolate.in_port(1).get_connection().set_source(gather.out_port(0))
def find_and_replace_pattern(self, graph: Graph):
for interpolate in graph.get_op_nodes(type='Interpolate'):
if interpolate.in_port(1).get_source().node.soft_get('type') != 'Const':
continue
dsts = interpolate.out_port(0).get_destinations()
if len(dsts) == 1 and dsts[0].node.soft_get('type') == 'Concat':
self.make_interpolate_reshapeable(interpolate, dsts[0].node)
class InterpolateReshapeWA(BackReplacementPattern):
"""
Replaces hard-coded 1-port input of Interpolate with reshape-able sub-graph.
WARNING: Could cause troubles if model has hard-coded Interpolate intentionally -- rare situation
BEFORE:
input Const
shape=[1, 3, 30, 40] value=[60, 160]
\ /
Interpolate(axes=(2, 3))
shape=[1, 3, 60, 160]
AFTER:
input
shape=[1, 3, 30, 40]
| \
| ShapeOf
| |
| Gather Const
| indices=(2, 3); axis=0 value=[2, 4]
| \ /
| Multiply
| /
Interpolate(axes=(2, 3))
shape=[1, 3, 60, 160]
"""
enabled = False
graph_condition = [lambda graph: graph.graph['cmd_params'].keep_shape_ops]
force_shape_inference = True
id = 'reshape_interpolate_wa'
def run_after(self):
return [InterpolateConcat]
@staticmethod
def make_interpolate_reshapeable(interpolate):
assert interpolate.soft_get('type') == 'Interpolate'
axes = interpolate.axes
input_shape = interpolate.in_port(0).data.get_shape()
output_shape = interpolate.out_port(0).data.get_shape()
if not np.all(np.remainder(output_shape, input_shape) == 0) and \
not np.all(np.remainder(input_shape, output_shape) == 0):
return
graph = interpolate.graph
name = interpolate.soft_get('name', interpolate.id)
shape = Shape(graph, {'name': name + '/ShapeOf'}).create_node()
shape.in_port(0).connect(interpolate.in_port(0).get_source())
gather = create_op_with_const_inputs(graph, Gather, {1: np.array(axes, dtype=np.int32), 2: int64_array(0)},
{'name': shape.name + '/Gathered'}, shape)
multipliers = output_shape[axes] / input_shape[axes]
mul = create_op_node_with_second_input(graph, Mul, multipliers, {'name': gather.name + '/Multiplied'}, gather)
interpolate.in_port(1).get_connection().set_source(mul.out_port(0))
def find_and_replace_pattern(self, graph: Graph):
for interpolate in graph.get_op_nodes(type='Interpolate'):
if interpolate.in_port(1).get_source().node.soft_get('type') == 'Const':
self.make_interpolate_reshapeable(interpolate)

View File

@@ -0,0 +1,97 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from argparse import Namespace
import numpy as np
from extensions.back.InterpolateReshape import InterpolateReshapeWA, InterpolateConcat
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph, result, regular_op_with_shaped_data, valued_const_with_data, connect, \
connect_data
nodes = {
**regular_op_with_shaped_data('placeholder', [1, 3, 30, 40], {'type': 'Parameter'}),
**valued_const_with_data('out_shape', np.array([60, 160])),
**regular_op_with_shaped_data('interpolate', [1, 3, 60, 160], {'type': 'Interpolate', 'axes': [2, 3]}),
**regular_op_with_shaped_data('shape', [4], {'type': 'ShapeOf'}),
**valued_const_with_data('indices', np.array([2, 3])),
**valued_const_with_data('axis', np.array(0)),
**regular_op_with_shaped_data('gather', [2], {'type': 'Gather'}),
**valued_const_with_data('multiplier', np.array([2, 4])),
**regular_op_with_shaped_data('mul', [2], {'type': 'Multiply'}),
**regular_op_with_shaped_data('placeholder_1', [1, 3, 60, 160], {'type': 'Parameter'}),
**regular_op_with_shaped_data('concat', [1, 7, 60, 160], {'type': 'Concat', 'axis': 1}),
**result(),
}
class TestInterpolateReshapeWA(unittest.TestCase):
def test_interpolate_reshape_graph_comparison(self):
graph = build_graph(nodes, [
*connect('placeholder', '0:interpolate'),
*connect('out_shape', '1:interpolate'),
*connect('interpolate', 'output'),
], nodes_with_edges_only=True)
InterpolateReshapeWA().find_and_replace_pattern(graph)
graph.graph['cmd_params'] = Namespace(keep_shape_ops=True)
graph.clean_up()
graph_ref = build_graph(nodes, [
*connect('placeholder', '0:interpolate'),
*connect_data('placeholder', 'shape'),
*connect('shape', '0:gather'),
*connect('indices', '1:gather'),
*connect('axis', '2:gather'),
*connect('gather', '0:mul'),
*connect('multiplier', '1:mul'),
*connect('mul', '1:interpolate'),
*connect('interpolate', 'output'),
], nodes_with_edges_only=True)
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
self.assertTrue(flag, resp)
class TestInterpolateConcat(unittest.TestCase):
def test_interpolate_concat_reshape_graph_comparison(self):
graph = build_graph(nodes, [
*connect('placeholder', '0:interpolate'),
*connect('out_shape', '1:interpolate'),
*connect('interpolate', '0:concat'),
*connect('placeholder_1', '1:concat'),
*connect('concat', 'output'),
], nodes_with_edges_only=True)
InterpolateConcat().find_and_replace_pattern(graph)
graph.graph['cmd_params'] = Namespace(keep_shape_ops=True)
graph.clean_up()
graph_ref = build_graph(nodes, [
*connect('placeholder', '0:interpolate'),
*connect('placeholder_1', 'shape'),
*connect('shape', '0:gather'),
*connect('indices', '1:gather'),
*connect('axis', '2:gather'),
*connect('gather', '1:interpolate'),
*connect('interpolate', '0:concat'),
*connect_data('placeholder_1', '1:concat'),
*connect('concat', 'output'),
], nodes_with_edges_only=True)
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
self.assertTrue(flag, resp)