[ MO Interpolate ] Fixing broken model reshape-ability (#619)
This commit is contained in:
committed by
GitHub
parent
5cc8114322
commit
e290b14ab1
154
model-optimizer/extensions/back/InterpolateReshape.py
Normal file
154
model-optimizer/extensions/back/InterpolateReshape.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""
|
||||
Copyright (C) 2018-2020 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
from extensions.ops.elementwise import Mul
|
||||
from extensions.ops.gather import Gather
|
||||
from mo.back.replacement import BackReplacementPattern
|
||||
from mo.front.caffe.extractors.utils import get_canonical_axis_index
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.tf.graph_utils import create_op_with_const_inputs, create_op_node_with_second_input
|
||||
from mo.graph.graph import Graph
|
||||
from mo.ops.shape import Shape
|
||||
|
||||
|
||||
class InterpolateConcat(BackReplacementPattern):
|
||||
"""
|
||||
Replaces hard-coded 1-port input of Interpolate with reshape-able sub-graph using the following Concat inputs
|
||||
|
||||
BEFORE:
|
||||
input Const
|
||||
shape=[1, 3, 30, 40] value=[60, 160]
|
||||
\ /
|
||||
Interpolate(axes=(2, 3)) input_1
|
||||
shape=[1, 3, 60, 160] shape=[1, 4, 60, 160]
|
||||
\ /
|
||||
Concat(axis=1)
|
||||
shape=[1, 7, 60, 160]
|
||||
AFTER:
|
||||
input
|
||||
shape=[1, 3, 30, 40] input_1
|
||||
| shape=[1, 4, 60, 160]
|
||||
| / |
|
||||
| ShapeOf |
|
||||
| | |
|
||||
| Gather |
|
||||
| indices=(2, 3); axis=0 |
|
||||
\ | |
|
||||
Interpolate(axes=(2, 3)) |
|
||||
shape=[1, 3, 60, 160] |
|
||||
\ /
|
||||
Concat(axis=1)
|
||||
shape=[1, 7, 60, 160]
|
||||
|
||||
"""
|
||||
enabled = True
|
||||
graph_condition = [lambda graph: graph.graph['cmd_params'].keep_shape_ops]
|
||||
force_shape_inference = True
|
||||
id = 'reshape_interpolate_through_concat'
|
||||
|
||||
@staticmethod
|
||||
def make_interpolate_reshapeable(interpolate, concat):
|
||||
assert interpolate.soft_get('type') == 'Interpolate'
|
||||
assert concat.soft_get('type') == 'Concat'
|
||||
|
||||
output_shape = interpolate.out_port(0).data.get_shape()
|
||||
|
||||
interp_axes = [get_canonical_axis_index(output_shape, axis) for axis in interpolate.axes]
|
||||
concat_axis = get_canonical_axis_index(output_shape, concat.axis)
|
||||
if concat_axis in interp_axes:
|
||||
return
|
||||
|
||||
concat_srcs = [port.get_source() for port in concat.in_ports().values()]
|
||||
non_interp_concat_srcs = [src for src in concat_srcs if src.node.soft_get('type') != 'Interpolate']
|
||||
if len(non_interp_concat_srcs) == 0:
|
||||
return
|
||||
|
||||
graph = interpolate.graph
|
||||
src = non_interp_concat_srcs[0]
|
||||
|
||||
shape = Shape(graph, {'name': src.node.soft_get('name', src.node.id) + '/Shape'}).create_node()
|
||||
shape.in_port(0).connect(src)
|
||||
gather = create_op_with_const_inputs(graph, Gather, {1: np.array(interpolate.axes, dtype=np.int32), 2: int64_array(0)},
|
||||
{'name': shape.name + '/Gathered'}, shape)
|
||||
interpolate.in_port(1).get_connection().set_source(gather.out_port(0))
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
for interpolate in graph.get_op_nodes(type='Interpolate'):
|
||||
if interpolate.in_port(1).get_source().node.soft_get('type') != 'Const':
|
||||
continue
|
||||
dsts = interpolate.out_port(0).get_destinations()
|
||||
if len(dsts) == 1 and dsts[0].node.soft_get('type') == 'Concat':
|
||||
self.make_interpolate_reshapeable(interpolate, dsts[0].node)
|
||||
|
||||
|
||||
class InterpolateReshapeWA(BackReplacementPattern):
|
||||
"""
|
||||
Replaces hard-coded 1-port input of Interpolate with reshape-able sub-graph.
|
||||
WARNING: Could cause troubles if model has hard-coded Interpolate intentionally -- rare situation
|
||||
|
||||
BEFORE:
|
||||
input Const
|
||||
shape=[1, 3, 30, 40] value=[60, 160]
|
||||
\ /
|
||||
Interpolate(axes=(2, 3))
|
||||
shape=[1, 3, 60, 160]
|
||||
|
||||
AFTER:
|
||||
input
|
||||
shape=[1, 3, 30, 40]
|
||||
| \
|
||||
| ShapeOf
|
||||
| |
|
||||
| Gather Const
|
||||
| indices=(2, 3); axis=0 value=[2, 4]
|
||||
| \ /
|
||||
| Multiply
|
||||
| /
|
||||
Interpolate(axes=(2, 3))
|
||||
shape=[1, 3, 60, 160]
|
||||
"""
|
||||
enabled = False
|
||||
graph_condition = [lambda graph: graph.graph['cmd_params'].keep_shape_ops]
|
||||
force_shape_inference = True
|
||||
id = 'reshape_interpolate_wa'
|
||||
|
||||
def run_after(self):
|
||||
return [InterpolateConcat]
|
||||
|
||||
@staticmethod
|
||||
def make_interpolate_reshapeable(interpolate):
|
||||
assert interpolate.soft_get('type') == 'Interpolate'
|
||||
axes = interpolate.axes
|
||||
input_shape = interpolate.in_port(0).data.get_shape()
|
||||
output_shape = interpolate.out_port(0).data.get_shape()
|
||||
if not np.all(np.remainder(output_shape, input_shape) == 0) and \
|
||||
not np.all(np.remainder(input_shape, output_shape) == 0):
|
||||
return
|
||||
graph = interpolate.graph
|
||||
name = interpolate.soft_get('name', interpolate.id)
|
||||
shape = Shape(graph, {'name': name + '/ShapeOf'}).create_node()
|
||||
shape.in_port(0).connect(interpolate.in_port(0).get_source())
|
||||
gather = create_op_with_const_inputs(graph, Gather, {1: np.array(axes, dtype=np.int32), 2: int64_array(0)},
|
||||
{'name': shape.name + '/Gathered'}, shape)
|
||||
multipliers = output_shape[axes] / input_shape[axes]
|
||||
mul = create_op_node_with_second_input(graph, Mul, multipliers, {'name': gather.name + '/Multiplied'}, gather)
|
||||
interpolate.in_port(1).get_connection().set_source(mul.out_port(0))
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
for interpolate in graph.get_op_nodes(type='Interpolate'):
|
||||
if interpolate.in_port(1).get_source().node.soft_get('type') == 'Const':
|
||||
self.make_interpolate_reshapeable(interpolate)
|
||||
97
model-optimizer/extensions/back/InterpolateReshape_test.py
Normal file
97
model-optimizer/extensions/back/InterpolateReshape_test.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
Copyright (C) 2018-2020 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from argparse import Namespace
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.back.InterpolateReshape import InterpolateReshapeWA, InterpolateConcat
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from mo.utils.unittest.graph import build_graph, result, regular_op_with_shaped_data, valued_const_with_data, connect, \
|
||||
connect_data
|
||||
|
||||
nodes = {
|
||||
**regular_op_with_shaped_data('placeholder', [1, 3, 30, 40], {'type': 'Parameter'}),
|
||||
**valued_const_with_data('out_shape', np.array([60, 160])),
|
||||
|
||||
**regular_op_with_shaped_data('interpolate', [1, 3, 60, 160], {'type': 'Interpolate', 'axes': [2, 3]}),
|
||||
|
||||
**regular_op_with_shaped_data('shape', [4], {'type': 'ShapeOf'}),
|
||||
**valued_const_with_data('indices', np.array([2, 3])),
|
||||
**valued_const_with_data('axis', np.array(0)),
|
||||
**regular_op_with_shaped_data('gather', [2], {'type': 'Gather'}),
|
||||
|
||||
**valued_const_with_data('multiplier', np.array([2, 4])),
|
||||
**regular_op_with_shaped_data('mul', [2], {'type': 'Multiply'}),
|
||||
|
||||
**regular_op_with_shaped_data('placeholder_1', [1, 3, 60, 160], {'type': 'Parameter'}),
|
||||
**regular_op_with_shaped_data('concat', [1, 7, 60, 160], {'type': 'Concat', 'axis': 1}),
|
||||
|
||||
**result(),
|
||||
}
|
||||
|
||||
|
||||
class TestInterpolateReshapeWA(unittest.TestCase):
|
||||
def test_interpolate_reshape_graph_comparison(self):
|
||||
graph = build_graph(nodes, [
|
||||
*connect('placeholder', '0:interpolate'),
|
||||
*connect('out_shape', '1:interpolate'),
|
||||
*connect('interpolate', 'output'),
|
||||
], nodes_with_edges_only=True)
|
||||
InterpolateReshapeWA().find_and_replace_pattern(graph)
|
||||
graph.graph['cmd_params'] = Namespace(keep_shape_ops=True)
|
||||
graph.clean_up()
|
||||
graph_ref = build_graph(nodes, [
|
||||
*connect('placeholder', '0:interpolate'),
|
||||
*connect_data('placeholder', 'shape'),
|
||||
*connect('shape', '0:gather'),
|
||||
*connect('indices', '1:gather'),
|
||||
*connect('axis', '2:gather'),
|
||||
*connect('gather', '0:mul'),
|
||||
*connect('multiplier', '1:mul'),
|
||||
*connect('mul', '1:interpolate'),
|
||||
*connect('interpolate', 'output'),
|
||||
], nodes_with_edges_only=True)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
|
||||
class TestInterpolateConcat(unittest.TestCase):
|
||||
def test_interpolate_concat_reshape_graph_comparison(self):
|
||||
graph = build_graph(nodes, [
|
||||
*connect('placeholder', '0:interpolate'),
|
||||
*connect('out_shape', '1:interpolate'),
|
||||
*connect('interpolate', '0:concat'),
|
||||
*connect('placeholder_1', '1:concat'),
|
||||
*connect('concat', 'output'),
|
||||
], nodes_with_edges_only=True)
|
||||
InterpolateConcat().find_and_replace_pattern(graph)
|
||||
graph.graph['cmd_params'] = Namespace(keep_shape_ops=True)
|
||||
graph.clean_up()
|
||||
graph_ref = build_graph(nodes, [
|
||||
*connect('placeholder', '0:interpolate'),
|
||||
*connect('placeholder_1', 'shape'),
|
||||
*connect('shape', '0:gather'),
|
||||
*connect('indices', '1:gather'),
|
||||
*connect('axis', '2:gather'),
|
||||
*connect('gather', '1:interpolate'),
|
||||
*connect('interpolate', '0:concat'),
|
||||
*connect_data('placeholder_1', '1:concat'),
|
||||
*connect('concat', 'output'),
|
||||
], nodes_with_edges_only=True)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
Reference in New Issue
Block a user