Fix OneHot transformation for Bert Squad opset 10 (#954)
* Add transformation for squeezing depth input for ONNX OneHot operation because from some TF models it has shape [1] instead of []
This commit is contained in:
@@ -238,6 +238,7 @@ extensions/front/mxnet/where_ext.py
|
||||
extensions/front/mxnet/yolo_v3_mobilenet1_voc.json
|
||||
extensions/front/mxnet/zeros_ext.py
|
||||
extensions/front/no_op_eraser.py
|
||||
extensions/front/OneHotDepthNormalizer.py
|
||||
extensions/front/onnx/__init__.py
|
||||
extensions/front/onnx/activation_ext.py
|
||||
extensions/front/onnx/affine_ext.py
|
||||
|
||||
43
model-optimizer/extensions/front/OneHotDepthNormalizer.py
Normal file
43
model-optimizer/extensions/front/OneHotDepthNormalizer.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""
|
||||
Copyright (C) 2020 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.common.replacement import FrontReplacementPattern
|
||||
from mo.front.tf.graph_utils import create_op_with_const_inputs
|
||||
from mo.graph.graph import Graph
|
||||
from mo.ops.reshape import Reshape
|
||||
|
||||
|
||||
class OneHotDepthNormalizer(FrontReplacementPattern):
|
||||
"""
|
||||
Transformation performs squeezeng one-element tensors on 1st input in OneHot into 0D scalars. This transformation
|
||||
allows to avoid problems with some models produced by tf2onnx which have 1D depth in OneHot.
|
||||
"""
|
||||
enabled = True
|
||||
|
||||
def pattern(self):
|
||||
return dict(
|
||||
nodes=[
|
||||
('onehot', dict(kind='op', type='OneHot'))],
|
||||
edges=[]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def replace_pattern(graph: Graph, match: dict):
|
||||
node = match['onehot']
|
||||
node_name = node.soft_get('name', node.id)
|
||||
reshape = create_op_with_const_inputs(graph, Reshape, {1: int64_array([])}, {'name': node_name + '/Reshape'})
|
||||
node.in_port(1).get_connection().insert_node(reshape)
|
||||
@@ -0,0 +1,58 @@
|
||||
"""
|
||||
Copyright (C) 2020 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
from extensions.front.OneHotDepthNormalizer import OneHotDepthNormalizer
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from mo.utils.unittest.graph import build_graph, result, \
|
||||
regular_op, const
|
||||
|
||||
|
||||
class OneHotDepthNormalizerTest(unittest.TestCase):
|
||||
def test(self):
|
||||
nodes = {
|
||||
**regular_op('input', {'type': 'Parameter'}),
|
||||
**const('depth', int64_array([2])),
|
||||
**regular_op('onehot', {'type': 'OneHot', 'kind': 'op', 'op': 'OneHot'}),
|
||||
|
||||
**regular_op('reshape', {'type': 'Reshape', 'kind': 'op', 'op': 'Reshape'}),
|
||||
**const('reshape_dims', int64_array([])),
|
||||
**result('result'),
|
||||
}
|
||||
edges = [('input', 'onehot'),
|
||||
('depth', 'onehot'),
|
||||
('onehot', 'result'),
|
||||
]
|
||||
graph = build_graph(nodes, edges)
|
||||
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
graph.stage = 'front'
|
||||
|
||||
edges_ref = [('input', 'onehot'),
|
||||
('depth', 'reshape'),
|
||||
('reshape_dims', 'reshape'),
|
||||
('reshape', 'onehot'),
|
||||
('onehot', 'result'),
|
||||
]
|
||||
|
||||
graph_ref = build_graph(nodes, edges_ref)
|
||||
|
||||
OneHotDepthNormalizer().find_and_replace_pattern(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
|
||||
self.assertTrue(flag, resp)
|
||||
@@ -22,6 +22,7 @@ from mo.front.common.replacement import FrontReplacementSubgraph
|
||||
from mo.front.tf.graph_utils import create_op_with_const_inputs
|
||||
from mo.graph.graph import Graph
|
||||
from mo.ops.squeeze import Squeeze
|
||||
from mo.utils.error import Error
|
||||
|
||||
|
||||
class SqueezeAxis(FrontReplacementOp):
|
||||
@@ -40,11 +41,17 @@ class SqueezeAxis(FrontReplacementOp):
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
for node in graph.get_op_nodes(squeeze_axis=True):
|
||||
name = node.soft_get('name', node.id)
|
||||
assert node.has_valid('axis'), 'Unknown axis to squeeze for node {}'.format(name)
|
||||
for out_port in node.out_ports().values():
|
||||
squeeze_node = create_op_with_const_inputs(graph, Squeeze, {1: np.array(node.axis)},
|
||||
{'name': name + '/Squeeze_'})
|
||||
out_port.get_connection().insert_node(squeeze_node)
|
||||
if node.has_valid('axis'):
|
||||
squeeze_node = create_op_with_const_inputs(graph, Squeeze, {1: np.array(node.axis)},
|
||||
{'name': name + '/Squeeze_'})
|
||||
out_port.get_connection().insert_node(squeeze_node)
|
||||
elif node.is_in_port_connected(1):
|
||||
squeeze_node = Squeeze(graph, {'name': name + '/Squeeze_'}).create_node()
|
||||
out_port.get_connection().insert_node(squeeze_node)
|
||||
node.in_port(1).get_connection().add_destination(squeeze_node.in_port(1))
|
||||
else:
|
||||
raise Error('Unknown axis to squeeze for node {}'.format(name))
|
||||
|
||||
|
||||
class SplitInputsReconnect(FrontReplacementSubgraph):
|
||||
|
||||
95
model-optimizer/extensions/front/split_normalizer_test.py
Normal file
95
model-optimizer/extensions/front/split_normalizer_test.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""
|
||||
Copyright (C) 2020 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
from extensions.front.split_normalizer import SqueezeAxis
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from mo.utils.unittest.graph import build_graph, const
|
||||
|
||||
nodes_attributes = {
|
||||
'placeholder': {'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
||||
'attr_split': {'type': None, 'kind': 'op', 'op': 'AttributedSplit', 'axis': 0, 'num_splits': 2,
|
||||
'squeeze_axis': True},
|
||||
'split': {'type': 'Split', 'kind': 'op', 'op': 'Split', 'num_splits': 2, 'squeeze_axis': True},
|
||||
**const('split_axis', int64_array(0)),
|
||||
'concat': {'type': 'Concat', 'kind': 'op', 'op': 'Concat', 'axis': 0},
|
||||
'result': {'type': 'Result', 'value': None, 'kind': 'op', 'op': 'Result'},
|
||||
|
||||
'squeeze1': {'type': 'Squeeze', 'kind': 'op', 'op': 'Squeeze'},
|
||||
'squeeze2': {'type': 'Squeeze', 'kind': 'op', 'op': 'Squeeze'},
|
||||
**const('squeeze1_axis', int64_array(0)),
|
||||
**const('squeeze2_axis', int64_array(0)),
|
||||
}
|
||||
|
||||
|
||||
class SqueezeAxisTest(unittest.TestCase):
|
||||
def test_attributed(self):
|
||||
graph = build_graph(nodes_attributes,
|
||||
[('placeholder', 'attr_split', {'in': 0, 'out': 0}),
|
||||
('attr_split', 'concat', {'in': 0, 'out': 0}),
|
||||
('attr_split', 'concat', {'in': 1, 'out': 1}),
|
||||
('concat', 'result', {'in': 0, 'out': 0}),
|
||||
], nodes_with_edges_only=True)
|
||||
|
||||
graph_ref = build_graph(nodes_attributes,
|
||||
[('placeholder', 'attr_split', {'in': 0, 'out': 0}),
|
||||
('attr_split', 'squeeze1', {'in': 0, 'out': 0}),
|
||||
('squeeze1_axis', 'squeeze1', {'in': 1, 'out': 0}),
|
||||
('attr_split', 'squeeze2', {'in': 0, 'out': 1}),
|
||||
('squeeze2_axis', 'squeeze2', {'in': 1, 'out': 0}),
|
||||
('squeeze1', 'concat', {'in': 0, 'out': 0}),
|
||||
('squeeze2', 'concat', {'in': 1, 'out': 0}),
|
||||
('concat', 'result', {'in': 0, 'out': 0}),
|
||||
], nodes_with_edges_only=True)
|
||||
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
graph.stage = 'front'
|
||||
|
||||
SqueezeAxis().find_and_replace_pattern(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_split(self):
|
||||
graph = build_graph(nodes_attributes,
|
||||
[('placeholder', 'split', {'in': 0, 'out': 0}),
|
||||
('split_axis', 'split', {'in': 1, 'out': 0}),
|
||||
('split', 'concat', {'in': 0, 'out': 0}),
|
||||
('split', 'concat', {'in': 1, 'out': 1}),
|
||||
('concat', 'result', {'in': 0, 'out': 0}),
|
||||
], nodes_with_edges_only=True)
|
||||
|
||||
graph_ref = build_graph(nodes_attributes,
|
||||
[('placeholder', 'split', {'in': 0, 'out': 0}),
|
||||
('split_axis', 'split', {'in': 1, 'out': 0}),
|
||||
('split', 'squeeze1', {'in': 0, 'out': 0}),
|
||||
('split_axis', 'squeeze1', {'in': 1, 'out': 0}),
|
||||
('split', 'squeeze2', {'in': 0, 'out': 1}),
|
||||
('split_axis', 'squeeze2', {'in': 1, 'out': 0}),
|
||||
('squeeze1', 'concat', {'in': 0, 'out': 0}),
|
||||
('squeeze2', 'concat', {'in': 1, 'out': 0}),
|
||||
('concat', 'result', {'in': 0, 'out': 0}),
|
||||
], nodes_with_edges_only=True)
|
||||
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
graph.stage = 'front'
|
||||
|
||||
SqueezeAxis().find_and_replace_pattern(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
Reference in New Issue
Block a user