Fix OneHot transformation for Bert Squad opset 10 (#954)

* Add transformation for squeezing depth input for ONNX OneHot operation because from some TF models it has shape [1] instead of []
This commit is contained in:
Maxim Vafin
2020-06-22 18:58:07 +03:00
committed by GitHub
parent c9eb6ae62b
commit c9fc6f0531
5 changed files with 208 additions and 4 deletions

View File

@@ -238,6 +238,7 @@ extensions/front/mxnet/where_ext.py
extensions/front/mxnet/yolo_v3_mobilenet1_voc.json
extensions/front/mxnet/zeros_ext.py
extensions/front/no_op_eraser.py
extensions/front/OneHotDepthNormalizer.py
extensions/front/onnx/__init__.py
extensions/front/onnx/activation_ext.py
extensions/front/onnx/affine_ext.py

View File

@@ -0,0 +1,43 @@
"""
Copyright (C) 2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.replacement import FrontReplacementPattern
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph
from mo.ops.reshape import Reshape
class OneHotDepthNormalizer(FrontReplacementPattern):
"""
Transformation performs squeezeng one-element tensors on 1st input in OneHot into 0D scalars. This transformation
allows to avoid problems with some models produced by tf2onnx which have 1D depth in OneHot.
"""
enabled = True
def pattern(self):
return dict(
nodes=[
('onehot', dict(kind='op', type='OneHot'))],
edges=[]
)
@staticmethod
def replace_pattern(graph: Graph, match: dict):
node = match['onehot']
node_name = node.soft_get('name', node.id)
reshape = create_op_with_const_inputs(graph, Reshape, {1: int64_array([])}, {'name': node_name + '/Reshape'})
node.in_port(1).get_connection().insert_node(reshape)

View File

@@ -0,0 +1,58 @@
"""
Copyright (C) 2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from extensions.front.OneHotDepthNormalizer import OneHotDepthNormalizer
from mo.front.common.partial_infer.utils import int64_array
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph, result, \
regular_op, const
class OneHotDepthNormalizerTest(unittest.TestCase):
def test(self):
nodes = {
**regular_op('input', {'type': 'Parameter'}),
**const('depth', int64_array([2])),
**regular_op('onehot', {'type': 'OneHot', 'kind': 'op', 'op': 'OneHot'}),
**regular_op('reshape', {'type': 'Reshape', 'kind': 'op', 'op': 'Reshape'}),
**const('reshape_dims', int64_array([])),
**result('result'),
}
edges = [('input', 'onehot'),
('depth', 'onehot'),
('onehot', 'result'),
]
graph = build_graph(nodes, edges)
graph.graph['layout'] = 'NCHW'
graph.stage = 'front'
edges_ref = [('input', 'onehot'),
('depth', 'reshape'),
('reshape_dims', 'reshape'),
('reshape', 'onehot'),
('onehot', 'result'),
]
graph_ref = build_graph(nodes, edges_ref)
OneHotDepthNormalizer().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)

View File

@@ -22,6 +22,7 @@ from mo.front.common.replacement import FrontReplacementSubgraph
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph
from mo.ops.squeeze import Squeeze
from mo.utils.error import Error
class SqueezeAxis(FrontReplacementOp):
@@ -40,11 +41,17 @@ class SqueezeAxis(FrontReplacementOp):
def find_and_replace_pattern(self, graph: Graph):
for node in graph.get_op_nodes(squeeze_axis=True):
name = node.soft_get('name', node.id)
assert node.has_valid('axis'), 'Unknown axis to squeeze for node {}'.format(name)
for out_port in node.out_ports().values():
squeeze_node = create_op_with_const_inputs(graph, Squeeze, {1: np.array(node.axis)},
{'name': name + '/Squeeze_'})
out_port.get_connection().insert_node(squeeze_node)
if node.has_valid('axis'):
squeeze_node = create_op_with_const_inputs(graph, Squeeze, {1: np.array(node.axis)},
{'name': name + '/Squeeze_'})
out_port.get_connection().insert_node(squeeze_node)
elif node.is_in_port_connected(1):
squeeze_node = Squeeze(graph, {'name': name + '/Squeeze_'}).create_node()
out_port.get_connection().insert_node(squeeze_node)
node.in_port(1).get_connection().add_destination(squeeze_node.in_port(1))
else:
raise Error('Unknown axis to squeeze for node {}'.format(name))
class SplitInputsReconnect(FrontReplacementSubgraph):

View File

@@ -0,0 +1,95 @@
"""
Copyright (C) 2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from extensions.front.split_normalizer import SqueezeAxis
from mo.front.common.partial_infer.utils import int64_array
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph, const
nodes_attributes = {
'placeholder': {'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'attr_split': {'type': None, 'kind': 'op', 'op': 'AttributedSplit', 'axis': 0, 'num_splits': 2,
'squeeze_axis': True},
'split': {'type': 'Split', 'kind': 'op', 'op': 'Split', 'num_splits': 2, 'squeeze_axis': True},
**const('split_axis', int64_array(0)),
'concat': {'type': 'Concat', 'kind': 'op', 'op': 'Concat', 'axis': 0},
'result': {'type': 'Result', 'value': None, 'kind': 'op', 'op': 'Result'},
'squeeze1': {'type': 'Squeeze', 'kind': 'op', 'op': 'Squeeze'},
'squeeze2': {'type': 'Squeeze', 'kind': 'op', 'op': 'Squeeze'},
**const('squeeze1_axis', int64_array(0)),
**const('squeeze2_axis', int64_array(0)),
}
class SqueezeAxisTest(unittest.TestCase):
def test_attributed(self):
graph = build_graph(nodes_attributes,
[('placeholder', 'attr_split', {'in': 0, 'out': 0}),
('attr_split', 'concat', {'in': 0, 'out': 0}),
('attr_split', 'concat', {'in': 1, 'out': 1}),
('concat', 'result', {'in': 0, 'out': 0}),
], nodes_with_edges_only=True)
graph_ref = build_graph(nodes_attributes,
[('placeholder', 'attr_split', {'in': 0, 'out': 0}),
('attr_split', 'squeeze1', {'in': 0, 'out': 0}),
('squeeze1_axis', 'squeeze1', {'in': 1, 'out': 0}),
('attr_split', 'squeeze2', {'in': 0, 'out': 1}),
('squeeze2_axis', 'squeeze2', {'in': 1, 'out': 0}),
('squeeze1', 'concat', {'in': 0, 'out': 0}),
('squeeze2', 'concat', {'in': 1, 'out': 0}),
('concat', 'result', {'in': 0, 'out': 0}),
], nodes_with_edges_only=True)
graph.graph['layout'] = 'NCHW'
graph.stage = 'front'
SqueezeAxis().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
def test_split(self):
graph = build_graph(nodes_attributes,
[('placeholder', 'split', {'in': 0, 'out': 0}),
('split_axis', 'split', {'in': 1, 'out': 0}),
('split', 'concat', {'in': 0, 'out': 0}),
('split', 'concat', {'in': 1, 'out': 1}),
('concat', 'result', {'in': 0, 'out': 0}),
], nodes_with_edges_only=True)
graph_ref = build_graph(nodes_attributes,
[('placeholder', 'split', {'in': 0, 'out': 0}),
('split_axis', 'split', {'in': 1, 'out': 0}),
('split', 'squeeze1', {'in': 0, 'out': 0}),
('split_axis', 'squeeze1', {'in': 1, 'out': 0}),
('split', 'squeeze2', {'in': 0, 'out': 1}),
('split_axis', 'squeeze2', {'in': 1, 'out': 0}),
('squeeze1', 'concat', {'in': 0, 'out': 0}),
('squeeze2', 'concat', {'in': 1, 'out': 0}),
('concat', 'result', {'in': 0, 'out': 0}),
], nodes_with_edges_only=True)
graph.graph['layout'] = 'NCHW'
graph.stage = 'front'
SqueezeAxis().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)