[ MO ] Hard-coded Interpolate followed by concat reshape-ability fixing (#818)
This commit is contained in:
committed by
GitHub
parent
3490b985dd
commit
f40338ff4b
@@ -128,6 +128,7 @@ extensions/front/global_pooling_to_reduce.py
|
||||
extensions/front/image_scaler.py
|
||||
extensions/front/input_cut.py
|
||||
extensions/front/instance_normalization.py
|
||||
extensions/front/interpolate_reshape.py
|
||||
extensions/front/InterpolateNormalizer.py
|
||||
extensions/front/kaldi/__init__.py
|
||||
extensions/front/kaldi/add_permute_after_convolution.py
|
||||
|
||||
181
model-optimizer/extensions/front/interpolate_reshape.py
Normal file
181
model-optimizer/extensions/front/interpolate_reshape.py
Normal file
@@ -0,0 +1,181 @@
|
||||
"""
|
||||
Copyright (C) 2018-2020 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
from extensions.ops.gather import Gather
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.common.replacement import FrontReplacementPattern
|
||||
from mo.front.tf.graph_utils import create_op_with_const_inputs
|
||||
from mo.graph.graph import Graph, Node
|
||||
from mo.ops.shape import Shape
|
||||
|
||||
|
||||
class InterpolateWithConcat(FrontReplacementPattern):
|
||||
"""
|
||||
Replaces hard-coded 1-port input of Interpolate with reshape-able sub-graph using the following Concat inputs
|
||||
|
||||
BEFORE:
|
||||
input Const
|
||||
shape=[1, 3, 30, 40] value=[60, 160]
|
||||
\ /
|
||||
Interpolate(axes=(2, 3)) input_1
|
||||
shape=[1, 3, 60, 160] shape=[1, 4, 60, 160]
|
||||
\ /
|
||||
Concat(axis=1)
|
||||
shape=[1, 7, 60, 160]
|
||||
AFTER:
|
||||
input
|
||||
shape=[1, 3, 30, 40] input_1
|
||||
| shape=[1, 4, 60, 160]
|
||||
| / |
|
||||
| ShapeOf |
|
||||
| | |
|
||||
| Gather |
|
||||
| indices=(2, 3); axis=0 |
|
||||
\ | |
|
||||
Interpolate(axes=(2, 3)) |
|
||||
shape=[1, 3, 60, 160] |
|
||||
\ /
|
||||
Concat(axis=1)
|
||||
shape=[1, 7, 60, 160]
|
||||
|
||||
1. Searches for Interpolate operation which output is connected to Concat (through identity operation or directly).
|
||||
Interpolate -- [identity] --> Concat
|
||||
2. Checks that Interpolate has positive axes parameter
|
||||
3. Checks that Concat has positive axis (from attribute and N-input)
|
||||
4. Checks that interpolation takes place over different dimensions than concatenation
|
||||
5. Searches for Concat sources that are not connected to Interpolate operations
|
||||
and checks that we have at least one such source (we could create a loop if we won't check)
|
||||
6. If any of this checks are failed -- transformation doesn't do anything
|
||||
7. Otherwise, we take the first Concat source from the (5) item.
|
||||
Taking ShapeOf of this source and Gather'ing dimensions by the Interpolate::axes indices
|
||||
we connect them to the second Interpolate input
|
||||
|
||||
This is how we get updated Interpolate second input that will fit the following Concat operation restrictions.
|
||||
|
||||
|
||||
We perform this transformation of the FRONT phase for MO to be able to reshape this Interpolate layer too.
|
||||
There is a similar transformation with less restrictions on the BACK phase.
|
||||
"""
|
||||
enabled = True
|
||||
|
||||
def run_after(self):
|
||||
from extensions.front.InterpolateNormalizer import InterpolateNormalizer
|
||||
return [InterpolateNormalizer]
|
||||
|
||||
@staticmethod
|
||||
def get_concat_axis(concat: Node):
|
||||
# Concat axis may be stored as an attribute and as an input (TF) and this is not resolved yet
|
||||
# TODO: should be removed after Concat operation normalization
|
||||
assert concat.soft_get('type') == 'Concat'
|
||||
if concat.has_valid('axis'):
|
||||
return concat.axis
|
||||
if concat.has_valid('N'):
|
||||
axis_node = concat.in_port(concat.N).get_source().node
|
||||
if axis_node.has_valid('value'):
|
||||
return axis_node.value.item(0)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_single_output_destination_safely(node: Node, idx: int = 0):
|
||||
"""
|
||||
Checks if node has exactly one used output port and this output port is only used by one consumer
|
||||
If the checks passed, function returns consumer_node, otherwise None
|
||||
"""
|
||||
connected_out_ports = [port for port in node.out_ports().values() if not port.disconnected()]
|
||||
if len(connected_out_ports) == 1 and connected_out_ports[0].idx == idx:
|
||||
dsts = node.out_port(idx).get_destinations()
|
||||
if len(dsts) == 1:
|
||||
return dsts[0].node
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_single_input_source_safely(node: Node, idx: int = 0):
|
||||
"""
|
||||
Checks if node has exactly one used input port
|
||||
If the check passed, function returns input_node otherwise None
|
||||
"""
|
||||
connected_in_ports = [port for port in node.in_ports().values() if not port.disconnected()]
|
||||
if len(connected_in_ports) == 1 and connected_in_ports[0].idx == idx:
|
||||
return node.in_port(idx).get_source().node
|
||||
return None
|
||||
|
||||
def get_non_interpolate_concat_sources(self, concat: Node):
|
||||
"""
|
||||
Traverses Concat input ports up to find which of them are not connected to Interpolate operations directly
|
||||
or through identity operation sequence. Returns the list of Concat sources that satisfy the condition.
|
||||
"""
|
||||
assert concat.soft_get('type') == 'Concat'
|
||||
sources = []
|
||||
|
||||
for in_port in concat.in_ports().values():
|
||||
if in_port.disconnected():
|
||||
continue
|
||||
next_node = in_port.get_source().node
|
||||
while next_node.soft_get('type') != 'Interpolate' and next_node.has_and_set('identity'):
|
||||
node = self.get_single_input_source_safely(next_node)
|
||||
if node is not None:
|
||||
next_node = node
|
||||
else:
|
||||
break
|
||||
if next_node.soft_get('type') != 'Interpolate':
|
||||
sources.append(in_port.get_connection().get_source())
|
||||
return sources
|
||||
|
||||
def make_interpolate_reshape_able(self, interpolate: Node, concat: Node):
|
||||
assert interpolate.soft_get('type') == 'Interpolate'
|
||||
assert concat.soft_get('type') == 'Concat'
|
||||
|
||||
interp_axes = interpolate.soft_get('axes', None)
|
||||
interp_axes = interp_axes if interp_axes is None else int64_array(interp_axes)
|
||||
concat_axis = self.get_concat_axis(concat)
|
||||
|
||||
if concat_axis is None or interp_axes is None \
|
||||
or np.any(interp_axes < 0) or concat_axis < 0 \
|
||||
or concat_axis in interp_axes:
|
||||
# checks that interpolate axes and concat axis are valid and do not intersect
|
||||
return
|
||||
|
||||
non_interp_concat_srcs = self.get_non_interpolate_concat_sources(concat)
|
||||
if not len(non_interp_concat_srcs):
|
||||
# there is no Concat input to take input from
|
||||
return
|
||||
|
||||
graph = interpolate.graph
|
||||
src = non_interp_concat_srcs[0]
|
||||
|
||||
shape = Shape(graph, {'name': src.node.soft_get('name', src.node.id) + '/Shape'}).create_node()
|
||||
shape.in_port(0).connect(src)
|
||||
gather = create_op_with_const_inputs(graph, Gather,
|
||||
{1: np.array(interpolate.axes, dtype=np.int32), 2: int64_array(0)},
|
||||
{'name': shape.name + '/Gathered'}, input_node=shape)
|
||||
interpolate.in_port(1).get_connection().set_source(gather.out_port(0))
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
for interpolate in graph.get_op_nodes(type='Interpolate'):
|
||||
if interpolate.in_port(1).get_source().node.soft_get('type') != 'Const':
|
||||
continue
|
||||
|
||||
# Interpolate could be connected to Concat through identity operations, skipping them
|
||||
next_node = self.get_single_output_destination_safely(interpolate)
|
||||
while next_node.soft_get('type') != 'Concat' and next_node.has_and_set('identity'):
|
||||
node = self.get_single_output_destination_safely(next_node)
|
||||
if node is not None:
|
||||
next_node = node
|
||||
else:
|
||||
break
|
||||
if next_node.soft_get('type') == 'Concat':
|
||||
self.make_interpolate_reshape_able(interpolate, next_node)
|
||||
139
model-optimizer/extensions/front/interpolate_reshape_test.py
Normal file
139
model-optimizer/extensions/front/interpolate_reshape_test.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""
|
||||
Copyright (C) 2018-2020 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
import unittest
|
||||
from argparse import Namespace
|
||||
|
||||
import numpy as np
|
||||
from generator import generator, generate
|
||||
|
||||
from extensions.front.interpolate_reshape import InterpolateWithConcat
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from mo.utils.unittest.graph import build_graph, result, regular_op_with_shaped_data, valued_const_with_data, connect, \
|
||||
connect_data
|
||||
|
||||
nodes = {
|
||||
**regular_op_with_shaped_data('placeholder', [1, 3, 30, 40], {'type': 'Parameter'}),
|
||||
**valued_const_with_data('out_shape', np.array([60, 160])),
|
||||
|
||||
**regular_op_with_shaped_data('interpolate', [1, 3, 60, 160],
|
||||
{'type': 'Interpolate', 'axes': int64_array([2, 3])}),
|
||||
**regular_op_with_shaped_data('identity_00', [1, 3, 60, 160], {'identity': True}),
|
||||
**regular_op_with_shaped_data('identity_01', [1, 3, 60, 160], {'identity': True}),
|
||||
|
||||
**regular_op_with_shaped_data('shape', [4], {'type': 'ShapeOf'}),
|
||||
**valued_const_with_data('indices', np.array([2, 3])),
|
||||
**valued_const_with_data('axis', np.array(0)),
|
||||
**regular_op_with_shaped_data('gather', [2], {'type': 'Gather'}),
|
||||
|
||||
**regular_op_with_shaped_data('placeholder_1', [1, 3, 60, 160], {'type': 'Parameter'}),
|
||||
**regular_op_with_shaped_data('identity_10', [1, 3, 60, 160], {'identity': True}),
|
||||
**regular_op_with_shaped_data('identity_11', [1, 3, 60, 160], {'identity': True}),
|
||||
**regular_op_with_shaped_data('concat', [1, 7, 60, 160], {'type': 'Concat', 'axis': 1}),
|
||||
|
||||
**result(),
|
||||
}
|
||||
|
||||
|
||||
@generator
|
||||
class TestInterpolateConcat(unittest.TestCase):
|
||||
def test_interpolate_concat_reshape_graph_comparison(self):
|
||||
graph = build_graph(nodes, [
|
||||
*connect('placeholder', '0:interpolate'),
|
||||
*connect('out_shape', '1:interpolate'),
|
||||
*connect('interpolate', '0:concat'),
|
||||
*connect('placeholder_1', '1:concat'),
|
||||
*connect('concat', 'output'),
|
||||
], nodes_with_edges_only=True)
|
||||
|
||||
InterpolateWithConcat().find_and_replace_pattern(graph)
|
||||
graph.graph['cmd_params'] = Namespace(keep_shape_ops=True)
|
||||
graph.clean_up()
|
||||
graph_ref = build_graph(nodes, [
|
||||
*connect('placeholder', '0:interpolate'),
|
||||
*connect('placeholder_1', 'shape'),
|
||||
*connect('shape', '0:gather'),
|
||||
*connect('indices', '1:gather'),
|
||||
*connect('axis', '2:gather'),
|
||||
*connect('gather', '1:interpolate'),
|
||||
*connect('interpolate', '0:concat'),
|
||||
*connect_data('placeholder_1', '1:concat'),
|
||||
*connect('concat', 'output'),
|
||||
], nodes_with_edges_only=True)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_interpolate_identity_concat_reshape_graph_comparison(self):
|
||||
graph = build_graph(nodes, [
|
||||
*connect('placeholder', '0:interpolate'),
|
||||
*connect('out_shape', '1:interpolate'),
|
||||
*connect('interpolate', 'identity_00'),
|
||||
*connect('identity_00', 'identity_01'),
|
||||
*connect('identity_01', '0:concat'),
|
||||
*connect('placeholder_1', 'identity_10'),
|
||||
*connect('identity_10', 'identity_11'),
|
||||
*connect('identity_11', '1:concat'),
|
||||
*connect('concat', 'output'),
|
||||
], nodes_with_edges_only=True)
|
||||
|
||||
InterpolateWithConcat().find_and_replace_pattern(graph)
|
||||
graph.graph['cmd_params'] = Namespace(keep_shape_ops=True)
|
||||
graph.clean_up()
|
||||
graph_ref = build_graph(nodes, [
|
||||
*connect('placeholder', '0:interpolate'),
|
||||
*connect_data('identity_11', 'shape'),
|
||||
*connect('shape', '0:gather'),
|
||||
*connect('indices', '1:gather'),
|
||||
*connect('axis', '2:gather'),
|
||||
*connect('gather', '1:interpolate'),
|
||||
*connect('interpolate', 'identity_00'),
|
||||
*connect('identity_00', 'identity_01'),
|
||||
*connect('identity_01', '0:concat'),
|
||||
*connect('placeholder_1', 'identity_10'),
|
||||
*connect('identity_10', 'identity_11'),
|
||||
*connect('identity_11', '1:concat'),
|
||||
*connect('concat', 'output'),
|
||||
], nodes_with_edges_only=True)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
@generate(*[
|
||||
{'concat': {'axis': None}},
|
||||
{'concat': {'axis': 2}},
|
||||
{'concat': {'axis': -1}},
|
||||
{'interpolate': {'axes': None}},
|
||||
{'interpolate': {'axes': np.array([1])}},
|
||||
{'interpolate': {'axes': np.array([2, -1])}},
|
||||
])
|
||||
def test_negative_axes_conditions(self, update_attrs):
|
||||
graph = build_graph(nodes, [
|
||||
*connect('placeholder', '0:interpolate'),
|
||||
*connect('out_shape', '1:interpolate'),
|
||||
*connect('interpolate', '0:concat'),
|
||||
*connect('placeholder_1', '1:concat'),
|
||||
*connect('concat', 'output'),
|
||||
], update_attributes=update_attrs, nodes_with_edges_only=True)
|
||||
InterpolateWithConcat().find_and_replace_pattern(graph)
|
||||
graph_ref = build_graph(nodes, [
|
||||
*connect('placeholder', '0:interpolate'),
|
||||
*connect('out_shape', '1:interpolate'),
|
||||
*connect('interpolate', '0:concat'),
|
||||
*connect('placeholder_1', '1:concat'),
|
||||
*connect('concat', 'output'),
|
||||
], update_attributes=update_attrs, nodes_with_edges_only=True)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
@@ -25,8 +25,6 @@ class ResizeBilinearFrontExtractor(FrontExtractorOp):
|
||||
@classmethod
|
||||
def extract(cls, node):
|
||||
mapping_rule = {
|
||||
'pads_begin': 0,
|
||||
'pads_end': 0,
|
||||
'align_corners': int(node.pb.attr['align_corners'].b),
|
||||
'mode': 'linear',
|
||||
'axes': int64_array([1, 2]),
|
||||
|
||||
Reference in New Issue
Block a user