Support ONNX Clamp-11 (#538)

This commit is contained in:
Maxim Vafin 2020-05-25 19:59:07 +03:00 committed by GitHub
parent 04bb8ab51d
commit 8c8629a4af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 356 additions and 25 deletions

View File

@ -12,6 +12,7 @@ extensions/back/ActivationsNormalizer.py
extensions/back/AvgPool.py
extensions/back/blob_normalizer.py
extensions/back/CellNormalizer.py
extensions/back/ClampNormalizer.py
extensions/back/compress_quantized_weights.py
extensions/back/ConvolutionNormalizer.py
extensions/back/CorrectName.py
@ -72,6 +73,7 @@ extensions/back/UselessConcatRemoval.py
extensions/front/__init__.py
extensions/front/ArgMaxSqueeze.py
extensions/front/ATenToEmbeddingBag.py
extensions/front/AttributedClampNormalizer.py
extensions/front/AttributedGatherNormalizer.py
extensions/front/AttributedPadToPad.py
extensions/front/binary_quantize_normalization.py

View File

@ -0,0 +1,71 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from extensions.ops.elementwise import Minimum, Maximum
from mo.back.replacement import BackReplacementPattern
from mo.graph.graph import Graph, rename_node
from mo.ops.clamp import AttributedClamp
class ClampNormalizer(BackReplacementPattern):
"""
Replaces Clamp with `min` and `max` as inputs with AttributedClamp with `min` and `max` as attributes.
"""
enabled = True
force_clean_up = True
def pattern(self):
return dict(
nodes=[('clamp', dict(op='Clamp'))],
edges=[]
)
def replace_pattern(self, graph: Graph, match: dict):
clamp = match['clamp']
name = clamp.soft_get('name', clamp.id)
min_value = max_value = None
port_1_exist = clamp.has_port('in', 1) and not clamp.in_port(1).disconnected()
port_2_exist = clamp.has_port('in', 2) and not clamp.in_port(2).disconnected()
if port_1_exist and clamp.in_port(1).get_source().node.soft_get('type') == 'Const':
min_value = clamp.in_port(1).data.get_value()
if port_2_exist and clamp.in_port(2).get_source().node.soft_get('type') == 'Const':
max_value = clamp.in_port(2).data.get_value()
rename_node(clamp, name + '/TBR')
if min_value is None or max_value is None:
max_node = min_node = None
if port_1_exist:
max_node = Maximum(graph, {}).create_node()
clamp.in_port(0).get_connection().set_destination(max_node.in_port(0))
clamp.in_port(1).get_connection().set_destination(max_node.in_port(1))
clamp.out_port(0).get_connection().set_source(max_node.out_port(0))
if port_2_exist:
min_node = Minimum(graph, {}).create_node()
if max_node is not None:
max_node.out_port(0).get_connection().set_source(min_node.out_port(0))
max_node.out_port(0).connect(min_node.in_port(0))
else:
clamp.in_port(0).get_connection().set_destination(min_node.in_port(0))
clamp.out_port(0).get_connection().set_source(min_node.out_port(0))
clamp.in_port(2).get_connection().set_destination(min_node.in_port(1))
assert min_node is not None or max_node is not None, 'Clamp node should have either min or max input used'
rename_node(max_node if min_node is None else min_node, name)
else:
a_clamp = AttributedClamp(graph, {'name': name, 'min': min_value, 'max': max_value}).create_node()
rename_node(a_clamp, name)
clamp.in_port(0).get_connection().set_destination(a_clamp.in_port(0))
clamp.out_port(0).get_connection().set_source(a_clamp.out_port(0))

View File

@ -0,0 +1,118 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
from extensions.back.ClampNormalizer import ClampNormalizer
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph, regular_op_with_shaped_data, valued_const_with_data, result, connect
class AttributedClampNormalizerTests(unittest.TestCase):
def test_2_inputs(self):
nodes = {
**regular_op_with_shaped_data('placeholder', [1, 3, 20, 20], {'type': 'Parameter'}),
**regular_op_with_shaped_data('a_clamp', [1, 3, 20, 20], {'type': None, 'op': 'Clamp'}),
**regular_op_with_shaped_data('clamp', [1, 3, 20, 20],
{'type': 'Clamp', 'op': 'AttributedClamp', 'min': -3.5, 'max': 3.5}),
**valued_const_with_data('min', np.array(-3.5)),
**valued_const_with_data('max', np.array(3.5)),
**result('result'),
}
edges = [*connect('placeholder', '0:a_clamp'),
*connect('min', '1:a_clamp'),
*connect('max', '2:a_clamp'),
*connect('a_clamp', 'result'),
]
graph = build_graph(nodes, edges)
ClampNormalizer().find_and_replace_pattern(graph)
ref_graph = build_graph(nodes, [*connect('placeholder', '0:clamp'), *connect('clamp', 'result')])
(flag, resp) = compare_graphs(graph, ref_graph, 'result')
self.assertTrue(flag, resp)
def test_all_dynamic_inputs(self):
nodes = {
**regular_op_with_shaped_data('placeholder', [1, 3, 20, 20], {'type': 'Parameter'}),
**regular_op_with_shaped_data('min', [1, 3, 20, 20], {'type': 'Parameter'}),
**regular_op_with_shaped_data('max', [1, 3, 20, 20], {'type': 'Parameter'}),
**regular_op_with_shaped_data('a_clamp', [1, 3, 20, 20], {'type': None, 'op': 'Clamp'}),
**regular_op_with_shaped_data('maximum', [1, 3, 20, 20], {'type': 'Maximum', 'op': 'Maximum'}),
**regular_op_with_shaped_data('minimum', [1, 3, 20, 20], {'type': 'Minimum', 'op': 'Minimum'}),
**result('result'),
}
edges = [*connect('placeholder', '0:a_clamp'),
*connect('min', '1:a_clamp'),
*connect('max', '2:a_clamp'),
*connect('a_clamp', 'result'),
]
graph = build_graph(nodes, edges)
ClampNormalizer().find_and_replace_pattern(graph)
ref_graph = build_graph(nodes, [*connect('placeholder', '0:maximum'),
*connect('min', '1:maximum'),
*connect('maximum', '0:minimum'),
*connect('max', '1:minimum'),
*connect('minimum', 'result')
])
(flag, resp) = compare_graphs(graph, ref_graph, 'result')
self.assertTrue(flag, resp)
def test_no_2nd_input(self):
nodes = {
**regular_op_with_shaped_data('placeholder', [1, 3, 20, 20], {'type': 'Parameter'}),
**regular_op_with_shaped_data('a_clamp', [1, 3, 20, 20], {'type': None, 'op': 'Clamp'}),
**regular_op_with_shaped_data('maximum', [1, 3, 20, 20], {'type': 'Maximum', 'op': 'Maximum'}),
**valued_const_with_data('min', np.array(-3.5)),
**result('result'),
}
edges = [*connect('placeholder', '0:a_clamp'),
*connect('min', '1:a_clamp'),
*connect('a_clamp', 'result'),
]
graph = build_graph(nodes, edges)
ClampNormalizer().find_and_replace_pattern(graph)
ref_graph = build_graph(nodes, [*connect('placeholder', '0:maximum'),
*connect('min', '1:maximum'),
*connect('maximum', 'result')
])
(flag, resp) = compare_graphs(graph, ref_graph, 'result')
self.assertTrue(flag, resp)
def test_no_1st_input(self):
nodes = {
**regular_op_with_shaped_data('placeholder', [1, 3, 20, 20], {'type': 'Parameter'}),
**regular_op_with_shaped_data('a_clamp', [1, 3, 20, 20], {'type': None, 'op': 'Clamp'}),
**regular_op_with_shaped_data('minimum', [1, 3, 20, 20], {'type': 'Minimum', 'op': 'Minimum'}),
**valued_const_with_data('max', np.array(3.5)),
**result('result'),
}
edges = [*connect('placeholder', '0:a_clamp'),
*connect('max', '2:a_clamp'),
*connect('a_clamp', 'result'),
]
graph = build_graph(nodes, edges)
ClampNormalizer().find_and_replace_pattern(graph)
ref_graph = build_graph(nodes, [*connect('placeholder', '0:minimum'),
*connect('max', '1:minimum'),
*connect('minimum', 'result')
])
(flag, resp) = compare_graphs(graph, ref_graph, 'result')
self.assertTrue(flag, resp)

View File

@ -0,0 +1,46 @@
"""
Copyright (C) 2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
from mo.front.common.replacement import FrontReplacementPattern
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph, rename_node
from mo.ops.clamp import Clamp
class AttributedClampNormalizer(FrontReplacementPattern):
"""
This transformation converts AttributedClamp operation (min/max are specified as attribute) to Clamp
operation.
"""
enabled = True
def find_and_replace_pattern(self, graph: Graph):
for attr_clamp in graph.get_op_nodes(op='AttributedClamp'):
original_name = attr_clamp.soft_get('name', attr_clamp.id)
rename_node(attr_clamp, original_name + '/TBR')
min_value = attr_clamp.soft_get('min', np.finfo(np.float32).min)
max_value = attr_clamp.soft_get('max', np.finfo(np.float32).max)
new_clamp = create_op_with_const_inputs(graph, Clamp,
{1: np.array(min_value, dtype=np.float32),
2: np.array(max_value, dtype=np.float32)},
{'name': original_name})
rename_node(new_clamp, original_name)
attr_clamp.in_port(0).get_connection().set_destination(new_clamp.in_port(0))
attr_clamp.out_port(0).get_connection().set_source(new_clamp.out_port(0))
graph.remove_node(attr_clamp.id)

View File

@ -0,0 +1,62 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
from extensions.front.AttributedClampNormalizer import AttributedClampNormalizer
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph, const
nodes_attributes = {
'placeholder': {'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'attr_clamp': {'type': 'Clamp', 'kind': 'op', 'op': 'AttributedClamp', 'name': 'attr_clamp',
'min': np.array(-3.5, dtype=np.float32), 'max': np.array(3.5, dtype=np.float32)},
'result': {'type': 'Result', 'value': None, 'kind': 'op', 'op': 'Result'},
# new Clamp layer and inputs
'clamp': {'type': None, 'kind': 'op', 'op': 'Clamp'},
**const('min', np.array(-3.5, dtype=np.float32)),
**const('max', np.array(3.5, dtype=np.float32)),
}
class AttributedClampNormalizerTest(unittest.TestCase):
def test_1(self):
graph = build_graph(nodes_attributes,
[('placeholder', 'attr_clamp', {'in': 0, 'out': 0}),
('attr_clamp', 'result', {'in': 0, 'out': 0}),
],
{}, nodes_with_edges_only=True)
graph_ref = build_graph(nodes_attributes,
[('placeholder', 'clamp', {'in': 0, 'out': 0}),
('min', 'clamp', {'in': 1, 'out': 0}),
('max', 'clamp', {'in': 2, 'out': 0}),
('clamp', 'result')
],
{}, nodes_with_edges_only=True)
graph.graph['layout'] = 'NCHW'
graph.stage = 'front'
replacer = AttributedClampNormalizer()
replacer.find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
self.assertTrue(graph.node[graph.get_nodes_with_attributes(op='Clamp')[0]]['name'] == 'attr_clamp')

View File

@ -22,13 +22,14 @@ from extensions.ops.split import Split
from mo.front.caffe.extractors.utils import input_as_const
from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.replacement import FrontReplacementOp
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Node, Graph, Port
from mo.ops.assign import Assign
from mo.ops.broadcast import Broadcast
from mo.ops.clamp import Clamp
from mo.ops.crop import Crop
from mo.ops.concat import Concat
from mo.ops.const import Const
from mo.ops.crop import Crop
from mo.ops.read_value import ReadValue
from mo.ops.result import Result
from mo.ops.scale_shift import ScaleShiftOp
@ -238,10 +239,10 @@ class ReplaceLSTMNodePattern(FrontReplacementOp):
join_forget_remember_sum.in_port(1).connect(join_remember_candidates_mul.out_port(0))
# (7)Eltwise(sum) -> Clamp
join_forget_clamp = Clamp(graph, {'name': 'join_forget_clamp',
'max': node.clip_value,
'min': -node.clip_value}).create_node()
join_forget_clamp.in_port(0).connect(join_forget_remember_sum.out_port(0))
join_forget_clamp = create_op_with_const_inputs(graph, Clamp, {1: np.array(-node.clip_value, dtype=np.float32),
2: np.array(node.clip_value, dtype=np.float32)},
{'name': 'join_forget_clamp'},
join_forget_remember_sum)
#
# Clamp -> (2)Memory(state)
next_lstm_state = Assign(graph, {'name': 'next_lstm_state',

View File

@ -16,7 +16,7 @@
from mo.front.extractor import FrontExtractorOp
from mo.front.mxnet.extractors.utils import get_mxnet_layer_attrs
from mo.graph.graph import Node
from mo.ops.clamp import Clamp
from mo.ops.clamp import AttributedClamp
class ClipExt(FrontExtractorOp):
@ -27,5 +27,5 @@ class ClipExt(FrontExtractorOp):
def extract(cls, node: Node):
attrs = get_mxnet_layer_attrs(node.symbol_dict)
Clamp.update_node_stat(node, {'min': attrs.float('a_min', None), 'max': attrs.float('a_max', None),})
AttributedClamp.update_node_stat(node, {'min': attrs.float('a_min', None), 'max': attrs.float('a_max', None)})
return cls.enabled

View File

@ -38,7 +38,7 @@ class CumSumFrontReplacer(FrontReplacementOp):
node.in_port(0).get_connection().set_destination(cumsum_node.in_port(0))
if node.has_valid('mx_out_type') and node['mx_out_type'] is not None:
rename_node(node=cumsum_node, name=name + '/Clamp')
rename_node(node=cumsum_node, name=name + '/CumSum')
convert = Cast(graph, {'name': name, 'dst_type': node['mx_out_type']}).create_node()
rename_node(convert, name)
cumsum_node.out_port(0).connect(convert.in_port(0))

View File

@ -13,10 +13,11 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
from mo.front.extractor import FrontExtractorOp
from mo.front.onnx.extractors.utils import onnx_attr
from mo.ops.clamp import Clamp
from mo.front.onnx.extractors.utils import onnx_attr, get_onnx_opset_version
from mo.ops.clamp import Clamp, AttributedClamp
class ClipFrontExtractor(FrontExtractorOp):
@ -25,9 +26,12 @@ class ClipFrontExtractor(FrontExtractorOp):
@classmethod
def extract(cls, node):
attrs = {
'min': onnx_attr(node, 'min', 'f', -3.4028234663852886e+38),
'max': onnx_attr(node, 'max', 'f', 3.4028234663852886e+38),
}
Clamp.update_node_stat(node, attrs)
if get_onnx_opset_version(node) < 11:
attrs = {
'min': onnx_attr(node, 'min', 'f', np.finfo(np.float32).min),
'max': onnx_attr(node, 'max', 'f', np.finfo(np.float32).max),
}
AttributedClamp.update_node_stat(node, attrs)
else:
Clamp.update_node_stat(node)
return cls.enabled

View File

@ -63,6 +63,10 @@ class ONNXLoader(Loader):
graph.graph['layout'] = 'NCHW'
graph.graph['fw'] = 'onnx'
graph.graph['feature_dim'] = 1
if hasattr(model_proto, 'opset_import'):
graph.graph['fw_opset_version'] = model_proto.opset_import[0].version # pylint: disable=no-member
else:
graph.graph['fw_opset_version'] = None
graph.check_empty_graph('protobuf2nx. It may happen due to problems with loaded model')
extract_node_attrs(graph, lambda node: onnx_op_extractor(node, check_for_duplicates(onnx_op_extractors)))

View File

@ -120,7 +120,11 @@ class ClampQuantizeMark(MiddleReplacementPattern):
def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
clamp = match['clamp']
quantize = match['quantize']
clamp_min, clamp_max = clamp['min'], clamp['max']
clamp_min = clamp.in_port(1).data.get_value()
clamp_max = clamp.in_port(2).data.get_value()
if clamp_min is None or clamp_max is None:
log.debug('ReluQuantizeFuse: cannot fuse because Clamp op has dynamic input on the 1st or 2nd port')
return
if not clamp.has_valid('quantized_to_fuse_count'):
clamp['quantized_to_fuse_count'] = 0

View File

@ -18,7 +18,7 @@ import numpy as np
from mo.front.common.partial_infer.eltwise import eltwise_infer
from mo.graph.graph import Graph, Node
from mo.ops.clamp import Clamp
from mo.ops.clamp import AttributedClamp
from mo.ops.op import Op
activation_ops = ['Sigmoid', 'Tanh', 'ReLU6', 'Exp', 'Elu', 'LogicalNot', 'Floor', 'Ceiling']
@ -95,7 +95,7 @@ class Atan(Activation):
operation = staticmethod(lambda x: np.arctan(x))
class ReLU6(Clamp):
class ReLU6(AttributedClamp):
op = 'ReLU6'
def __init__(self, graph: Graph, attrs: dict):

View File

@ -57,6 +57,10 @@ def get_onnx_autopad(auto_pad):
return auto_pad
def get_onnx_opset_version(node: Node):
return node.graph.graph.get('fw_opset_version', 0)
def get_onnx_datatype_as_numpy(value):
datatype_to_numpy = {
1: np.float32,

View File

@ -19,13 +19,13 @@ from mo.graph.graph import Graph
from mo.ops.op import Op
class Clamp(Op):
op = 'Clamp'
class AttributedClamp(Op):
op = 'AttributedClamp'
def __init__(self, graph: Graph, attrs: dict):
super().__init__(graph, {
'type': __class__.op,
'op': __class__.op,
'type': 'Clamp',
'op': self.op,
'version': 'opset1',
'infer': copy_shape_infer,
'in_ports_count': 1,
@ -37,3 +37,16 @@ class Clamp(Op):
'max',
'min'
]
class Clamp(Op):
op = 'Clamp'
def __init__(self, graph: Graph, attrs: dict):
super().__init__(graph, {
'type': None,
'op': self.op,
'infer': copy_shape_infer,
'in_ports_count': 3,
'out_ports_count': 1,
}, attrs)

View File

@ -19,7 +19,7 @@ import unittest
import numpy as np
from mo.front.common.partial_infer.elemental import copy_shape_infer
from mo.ops.clamp import Clamp
from mo.ops.clamp import AttributedClamp
from mo.utils.unittest.graph import build_graph
@ -41,7 +41,7 @@ class TestClampOp(unittest.TestCase):
('node_1', 'clamp_node'),
('clamp_node', 'node_3')
])
clamp_node = Clamp(graph, self.nodes_attributes['clamp_node']).add_node()
clamp_node = AttributedClamp(graph, self.nodes_attributes['clamp_node']).add_node()
self.assertEqual(clamp_node.type, 'Clamp')
self.assertEqual(clamp_node.op, 'Clamp')
self.assertEqual(clamp_node.op, 'AttributedClamp')
self.assertEqual(clamp_node.infer, copy_shape_infer)

View File

@ -28,6 +28,7 @@ from extensions.ops.scatter import Scatter
from extensions.ops.split import Split, VariadicSplit
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Graph, Node
from mo.ops.clamp import AttributedClamp
from mo.ops.convolution import Convolution
from mo.ops.deconvolution import Deconvolution
from mo.ops.op import Op
@ -53,6 +54,7 @@ custom_ops = {
'Split': Split,
'Subtract': Sub,
'VariadicSplit': VariadicSplit,
'Clamp': AttributedClamp,
}