From 8c8629a4afcfd21a1bbcaec244bb37a5737d4919 Mon Sep 17 00:00:00 2001 From: Maxim Vafin Date: Mon, 25 May 2020 19:59:07 +0300 Subject: [PATCH] Support ONNX Clamp-11 (#538) --- model-optimizer/automation/package_BOM.txt | 2 + .../extensions/back/ClampNormalizer.py | 71 +++++++++++ .../extensions/back/ClampNormalizer_test.py | 118 ++++++++++++++++++ .../front/AttributedClampNormalizer.py | 46 +++++++ .../front/AttributedClampNormalizer_test.py | 62 +++++++++ .../front/kaldi/replace_lstm_node_pattern.py | 11 +- .../extensions/front/mxnet/clip_ext.py | 4 +- .../extensions/front/mxnet/cumsum.py | 2 +- .../extensions/front/onnx/clip_ext.py | 18 +-- .../extensions/load/onnx/loader.py | 4 + .../extensions/middle/ReluQuantizeFuse.py | 6 +- .../extensions/ops/activation_ops.py | 4 +- .../mo/front/onnx/extractors/utils.py | 4 + model-optimizer/mo/ops/clamp.py | 21 +++- model-optimizer/mo/ops/clamp_test.py | 6 +- .../mo/utils/ir_reader/layer_to_class.py | 2 + 16 files changed, 356 insertions(+), 25 deletions(-) create mode 100644 model-optimizer/extensions/back/ClampNormalizer.py create mode 100644 model-optimizer/extensions/back/ClampNormalizer_test.py create mode 100644 model-optimizer/extensions/front/AttributedClampNormalizer.py create mode 100644 model-optimizer/extensions/front/AttributedClampNormalizer_test.py diff --git a/model-optimizer/automation/package_BOM.txt b/model-optimizer/automation/package_BOM.txt index 104d86f2e92..dfbd2e6e7ed 100644 --- a/model-optimizer/automation/package_BOM.txt +++ b/model-optimizer/automation/package_BOM.txt @@ -12,6 +12,7 @@ extensions/back/ActivationsNormalizer.py extensions/back/AvgPool.py extensions/back/blob_normalizer.py extensions/back/CellNormalizer.py +extensions/back/ClampNormalizer.py extensions/back/compress_quantized_weights.py extensions/back/ConvolutionNormalizer.py extensions/back/CorrectName.py @@ -72,6 +73,7 @@ extensions/back/UselessConcatRemoval.py extensions/front/__init__.py extensions/front/ArgMaxSqueeze.py extensions/front/ATenToEmbeddingBag.py +extensions/front/AttributedClampNormalizer.py extensions/front/AttributedGatherNormalizer.py extensions/front/AttributedPadToPad.py extensions/front/binary_quantize_normalization.py diff --git a/model-optimizer/extensions/back/ClampNormalizer.py b/model-optimizer/extensions/back/ClampNormalizer.py new file mode 100644 index 00000000000..52fe3c73535 --- /dev/null +++ b/model-optimizer/extensions/back/ClampNormalizer.py @@ -0,0 +1,71 @@ +""" + Copyright (C) 2018-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from extensions.ops.elementwise import Minimum, Maximum +from mo.back.replacement import BackReplacementPattern +from mo.graph.graph import Graph, rename_node +from mo.ops.clamp import AttributedClamp + + +class ClampNormalizer(BackReplacementPattern): + """ + Replaces Clamp with `min` and `max` as inputs with AttributedClamp with `min` and `max` as attributes. + """ + enabled = True + force_clean_up = True + + def pattern(self): + return dict( + nodes=[('clamp', dict(op='Clamp'))], + edges=[] + ) + + def replace_pattern(self, graph: Graph, match: dict): + clamp = match['clamp'] + name = clamp.soft_get('name', clamp.id) + + min_value = max_value = None + port_1_exist = clamp.has_port('in', 1) and not clamp.in_port(1).disconnected() + port_2_exist = clamp.has_port('in', 2) and not clamp.in_port(2).disconnected() + if port_1_exist and clamp.in_port(1).get_source().node.soft_get('type') == 'Const': + min_value = clamp.in_port(1).data.get_value() + if port_2_exist and clamp.in_port(2).get_source().node.soft_get('type') == 'Const': + max_value = clamp.in_port(2).data.get_value() + + rename_node(clamp, name + '/TBR') + if min_value is None or max_value is None: + max_node = min_node = None + if port_1_exist: + max_node = Maximum(graph, {}).create_node() + clamp.in_port(0).get_connection().set_destination(max_node.in_port(0)) + clamp.in_port(1).get_connection().set_destination(max_node.in_port(1)) + clamp.out_port(0).get_connection().set_source(max_node.out_port(0)) + if port_2_exist: + min_node = Minimum(graph, {}).create_node() + if max_node is not None: + max_node.out_port(0).get_connection().set_source(min_node.out_port(0)) + max_node.out_port(0).connect(min_node.in_port(0)) + else: + clamp.in_port(0).get_connection().set_destination(min_node.in_port(0)) + clamp.out_port(0).get_connection().set_source(min_node.out_port(0)) + clamp.in_port(2).get_connection().set_destination(min_node.in_port(1)) + assert min_node is not None or max_node is not None, 'Clamp node should have either min or max input used' + rename_node(max_node if min_node is None else min_node, name) + else: + a_clamp = AttributedClamp(graph, {'name': name, 'min': min_value, 'max': max_value}).create_node() + rename_node(a_clamp, name) + clamp.in_port(0).get_connection().set_destination(a_clamp.in_port(0)) + clamp.out_port(0).get_connection().set_source(a_clamp.out_port(0)) diff --git a/model-optimizer/extensions/back/ClampNormalizer_test.py b/model-optimizer/extensions/back/ClampNormalizer_test.py new file mode 100644 index 00000000000..eb5b07e7a53 --- /dev/null +++ b/model-optimizer/extensions/back/ClampNormalizer_test.py @@ -0,0 +1,118 @@ +""" + Copyright (C) 2018-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import unittest + +import numpy as np + +from extensions.back.ClampNormalizer import ClampNormalizer +from mo.utils.ir_engine.compare_graphs import compare_graphs +from mo.utils.unittest.graph import build_graph, regular_op_with_shaped_data, valued_const_with_data, result, connect + + +class AttributedClampNormalizerTests(unittest.TestCase): + + def test_2_inputs(self): + nodes = { + **regular_op_with_shaped_data('placeholder', [1, 3, 20, 20], {'type': 'Parameter'}), + **regular_op_with_shaped_data('a_clamp', [1, 3, 20, 20], {'type': None, 'op': 'Clamp'}), + **regular_op_with_shaped_data('clamp', [1, 3, 20, 20], + {'type': 'Clamp', 'op': 'AttributedClamp', 'min': -3.5, 'max': 3.5}), + **valued_const_with_data('min', np.array(-3.5)), + **valued_const_with_data('max', np.array(3.5)), + **result('result'), + } + edges = [*connect('placeholder', '0:a_clamp'), + *connect('min', '1:a_clamp'), + *connect('max', '2:a_clamp'), + *connect('a_clamp', 'result'), + ] + graph = build_graph(nodes, edges) + ClampNormalizer().find_and_replace_pattern(graph) + ref_graph = build_graph(nodes, [*connect('placeholder', '0:clamp'), *connect('clamp', 'result')]) + + (flag, resp) = compare_graphs(graph, ref_graph, 'result') + self.assertTrue(flag, resp) + + def test_all_dynamic_inputs(self): + nodes = { + **regular_op_with_shaped_data('placeholder', [1, 3, 20, 20], {'type': 'Parameter'}), + **regular_op_with_shaped_data('min', [1, 3, 20, 20], {'type': 'Parameter'}), + **regular_op_with_shaped_data('max', [1, 3, 20, 20], {'type': 'Parameter'}), + **regular_op_with_shaped_data('a_clamp', [1, 3, 20, 20], {'type': None, 'op': 'Clamp'}), + **regular_op_with_shaped_data('maximum', [1, 3, 20, 20], {'type': 'Maximum', 'op': 'Maximum'}), + **regular_op_with_shaped_data('minimum', [1, 3, 20, 20], {'type': 'Minimum', 'op': 'Minimum'}), + **result('result'), + } + edges = [*connect('placeholder', '0:a_clamp'), + *connect('min', '1:a_clamp'), + *connect('max', '2:a_clamp'), + *connect('a_clamp', 'result'), + ] + graph = build_graph(nodes, edges) + ClampNormalizer().find_and_replace_pattern(graph) + ref_graph = build_graph(nodes, [*connect('placeholder', '0:maximum'), + *connect('min', '1:maximum'), + *connect('maximum', '0:minimum'), + *connect('max', '1:minimum'), + *connect('minimum', 'result') + ]) + + (flag, resp) = compare_graphs(graph, ref_graph, 'result') + self.assertTrue(flag, resp) + + def test_no_2nd_input(self): + nodes = { + **regular_op_with_shaped_data('placeholder', [1, 3, 20, 20], {'type': 'Parameter'}), + **regular_op_with_shaped_data('a_clamp', [1, 3, 20, 20], {'type': None, 'op': 'Clamp'}), + **regular_op_with_shaped_data('maximum', [1, 3, 20, 20], {'type': 'Maximum', 'op': 'Maximum'}), + **valued_const_with_data('min', np.array(-3.5)), + **result('result'), + } + edges = [*connect('placeholder', '0:a_clamp'), + *connect('min', '1:a_clamp'), + *connect('a_clamp', 'result'), + ] + graph = build_graph(nodes, edges) + ClampNormalizer().find_and_replace_pattern(graph) + ref_graph = build_graph(nodes, [*connect('placeholder', '0:maximum'), + *connect('min', '1:maximum'), + *connect('maximum', 'result') + ]) + + (flag, resp) = compare_graphs(graph, ref_graph, 'result') + self.assertTrue(flag, resp) + + def test_no_1st_input(self): + nodes = { + **regular_op_with_shaped_data('placeholder', [1, 3, 20, 20], {'type': 'Parameter'}), + **regular_op_with_shaped_data('a_clamp', [1, 3, 20, 20], {'type': None, 'op': 'Clamp'}), + **regular_op_with_shaped_data('minimum', [1, 3, 20, 20], {'type': 'Minimum', 'op': 'Minimum'}), + **valued_const_with_data('max', np.array(3.5)), + **result('result'), + } + edges = [*connect('placeholder', '0:a_clamp'), + *connect('max', '2:a_clamp'), + *connect('a_clamp', 'result'), + ] + graph = build_graph(nodes, edges) + ClampNormalizer().find_and_replace_pattern(graph) + ref_graph = build_graph(nodes, [*connect('placeholder', '0:minimum'), + *connect('max', '1:minimum'), + *connect('minimum', 'result') + ]) + + (flag, resp) = compare_graphs(graph, ref_graph, 'result') + self.assertTrue(flag, resp) diff --git a/model-optimizer/extensions/front/AttributedClampNormalizer.py b/model-optimizer/extensions/front/AttributedClampNormalizer.py new file mode 100644 index 00000000000..0b5559666f6 --- /dev/null +++ b/model-optimizer/extensions/front/AttributedClampNormalizer.py @@ -0,0 +1,46 @@ +""" + Copyright (C) 2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import numpy as np + +from mo.front.common.replacement import FrontReplacementPattern +from mo.front.tf.graph_utils import create_op_with_const_inputs +from mo.graph.graph import Graph, rename_node +from mo.ops.clamp import Clamp + + +class AttributedClampNormalizer(FrontReplacementPattern): + """ + This transformation converts AttributedClamp operation (min/max are specified as attribute) to Clamp + operation. + """ + enabled = True + + def find_and_replace_pattern(self, graph: Graph): + for attr_clamp in graph.get_op_nodes(op='AttributedClamp'): + original_name = attr_clamp.soft_get('name', attr_clamp.id) + + rename_node(attr_clamp, original_name + '/TBR') + min_value = attr_clamp.soft_get('min', np.finfo(np.float32).min) + max_value = attr_clamp.soft_get('max', np.finfo(np.float32).max) + new_clamp = create_op_with_const_inputs(graph, Clamp, + {1: np.array(min_value, dtype=np.float32), + 2: np.array(max_value, dtype=np.float32)}, + {'name': original_name}) + rename_node(new_clamp, original_name) + + attr_clamp.in_port(0).get_connection().set_destination(new_clamp.in_port(0)) + attr_clamp.out_port(0).get_connection().set_source(new_clamp.out_port(0)) + graph.remove_node(attr_clamp.id) diff --git a/model-optimizer/extensions/front/AttributedClampNormalizer_test.py b/model-optimizer/extensions/front/AttributedClampNormalizer_test.py new file mode 100644 index 00000000000..0eb1bb7b4a8 --- /dev/null +++ b/model-optimizer/extensions/front/AttributedClampNormalizer_test.py @@ -0,0 +1,62 @@ +""" + Copyright (C) 2018-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import unittest + +import numpy as np + +from extensions.front.AttributedClampNormalizer import AttributedClampNormalizer +from mo.utils.ir_engine.compare_graphs import compare_graphs +from mo.utils.unittest.graph import build_graph, const + +nodes_attributes = { + 'placeholder': {'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, + 'attr_clamp': {'type': 'Clamp', 'kind': 'op', 'op': 'AttributedClamp', 'name': 'attr_clamp', + 'min': np.array(-3.5, dtype=np.float32), 'max': np.array(3.5, dtype=np.float32)}, + 'result': {'type': 'Result', 'value': None, 'kind': 'op', 'op': 'Result'}, + + # new Clamp layer and inputs + 'clamp': {'type': None, 'kind': 'op', 'op': 'Clamp'}, + **const('min', np.array(-3.5, dtype=np.float32)), + **const('max', np.array(3.5, dtype=np.float32)), +} + + +class AttributedClampNormalizerTest(unittest.TestCase): + def test_1(self): + graph = build_graph(nodes_attributes, + [('placeholder', 'attr_clamp', {'in': 0, 'out': 0}), + ('attr_clamp', 'result', {'in': 0, 'out': 0}), + ], + {}, nodes_with_edges_only=True) + + graph_ref = build_graph(nodes_attributes, + [('placeholder', 'clamp', {'in': 0, 'out': 0}), + ('min', 'clamp', {'in': 1, 'out': 0}), + ('max', 'clamp', {'in': 2, 'out': 0}), + ('clamp', 'result') + ], + {}, nodes_with_edges_only=True) + + graph.graph['layout'] = 'NCHW' + graph.stage = 'front' + + replacer = AttributedClampNormalizer() + replacer.find_and_replace_pattern(graph) + + (flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True) + self.assertTrue(flag, resp) + self.assertTrue(graph.node[graph.get_nodes_with_attributes(op='Clamp')[0]]['name'] == 'attr_clamp') diff --git a/model-optimizer/extensions/front/kaldi/replace_lstm_node_pattern.py b/model-optimizer/extensions/front/kaldi/replace_lstm_node_pattern.py index 05c85002324..8f29bae85d4 100644 --- a/model-optimizer/extensions/front/kaldi/replace_lstm_node_pattern.py +++ b/model-optimizer/extensions/front/kaldi/replace_lstm_node_pattern.py @@ -22,13 +22,14 @@ from extensions.ops.split import Split from mo.front.caffe.extractors.utils import input_as_const from mo.front.common.partial_infer.utils import int64_array from mo.front.common.replacement import FrontReplacementOp +from mo.front.tf.graph_utils import create_op_with_const_inputs from mo.graph.graph import Node, Graph, Port from mo.ops.assign import Assign from mo.ops.broadcast import Broadcast from mo.ops.clamp import Clamp -from mo.ops.crop import Crop from mo.ops.concat import Concat from mo.ops.const import Const +from mo.ops.crop import Crop from mo.ops.read_value import ReadValue from mo.ops.result import Result from mo.ops.scale_shift import ScaleShiftOp @@ -238,10 +239,10 @@ class ReplaceLSTMNodePattern(FrontReplacementOp): join_forget_remember_sum.in_port(1).connect(join_remember_candidates_mul.out_port(0)) # (7)Eltwise(sum) -> Clamp - join_forget_clamp = Clamp(graph, {'name': 'join_forget_clamp', - 'max': node.clip_value, - 'min': -node.clip_value}).create_node() - join_forget_clamp.in_port(0).connect(join_forget_remember_sum.out_port(0)) + join_forget_clamp = create_op_with_const_inputs(graph, Clamp, {1: np.array(-node.clip_value, dtype=np.float32), + 2: np.array(node.clip_value, dtype=np.float32)}, + {'name': 'join_forget_clamp'}, + join_forget_remember_sum) # # Clamp -> (2)Memory(state) next_lstm_state = Assign(graph, {'name': 'next_lstm_state', diff --git a/model-optimizer/extensions/front/mxnet/clip_ext.py b/model-optimizer/extensions/front/mxnet/clip_ext.py index dbb0af863b1..ad8b2baa646 100644 --- a/model-optimizer/extensions/front/mxnet/clip_ext.py +++ b/model-optimizer/extensions/front/mxnet/clip_ext.py @@ -16,7 +16,7 @@ from mo.front.extractor import FrontExtractorOp from mo.front.mxnet.extractors.utils import get_mxnet_layer_attrs from mo.graph.graph import Node -from mo.ops.clamp import Clamp +from mo.ops.clamp import AttributedClamp class ClipExt(FrontExtractorOp): @@ -27,5 +27,5 @@ class ClipExt(FrontExtractorOp): def extract(cls, node: Node): attrs = get_mxnet_layer_attrs(node.symbol_dict) - Clamp.update_node_stat(node, {'min': attrs.float('a_min', None), 'max': attrs.float('a_max', None),}) + AttributedClamp.update_node_stat(node, {'min': attrs.float('a_min', None), 'max': attrs.float('a_max', None)}) return cls.enabled diff --git a/model-optimizer/extensions/front/mxnet/cumsum.py b/model-optimizer/extensions/front/mxnet/cumsum.py index db98c65bc73..84717e6dba6 100644 --- a/model-optimizer/extensions/front/mxnet/cumsum.py +++ b/model-optimizer/extensions/front/mxnet/cumsum.py @@ -38,7 +38,7 @@ class CumSumFrontReplacer(FrontReplacementOp): node.in_port(0).get_connection().set_destination(cumsum_node.in_port(0)) if node.has_valid('mx_out_type') and node['mx_out_type'] is not None: - rename_node(node=cumsum_node, name=name + '/Clamp') + rename_node(node=cumsum_node, name=name + '/CumSum') convert = Cast(graph, {'name': name, 'dst_type': node['mx_out_type']}).create_node() rename_node(convert, name) cumsum_node.out_port(0).connect(convert.in_port(0)) diff --git a/model-optimizer/extensions/front/onnx/clip_ext.py b/model-optimizer/extensions/front/onnx/clip_ext.py index 748d1fbff0a..1883b78a13c 100644 --- a/model-optimizer/extensions/front/onnx/clip_ext.py +++ b/model-optimizer/extensions/front/onnx/clip_ext.py @@ -13,10 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. """ +import numpy as np from mo.front.extractor import FrontExtractorOp -from mo.front.onnx.extractors.utils import onnx_attr -from mo.ops.clamp import Clamp +from mo.front.onnx.extractors.utils import onnx_attr, get_onnx_opset_version +from mo.ops.clamp import Clamp, AttributedClamp class ClipFrontExtractor(FrontExtractorOp): @@ -25,9 +26,12 @@ class ClipFrontExtractor(FrontExtractorOp): @classmethod def extract(cls, node): - attrs = { - 'min': onnx_attr(node, 'min', 'f', -3.4028234663852886e+38), - 'max': onnx_attr(node, 'max', 'f', 3.4028234663852886e+38), - } - Clamp.update_node_stat(node, attrs) + if get_onnx_opset_version(node) < 11: + attrs = { + 'min': onnx_attr(node, 'min', 'f', np.finfo(np.float32).min), + 'max': onnx_attr(node, 'max', 'f', np.finfo(np.float32).max), + } + AttributedClamp.update_node_stat(node, attrs) + else: + Clamp.update_node_stat(node) return cls.enabled diff --git a/model-optimizer/extensions/load/onnx/loader.py b/model-optimizer/extensions/load/onnx/loader.py index 6da5eacf572..c92058bb78a 100644 --- a/model-optimizer/extensions/load/onnx/loader.py +++ b/model-optimizer/extensions/load/onnx/loader.py @@ -63,6 +63,10 @@ class ONNXLoader(Loader): graph.graph['layout'] = 'NCHW' graph.graph['fw'] = 'onnx' graph.graph['feature_dim'] = 1 + if hasattr(model_proto, 'opset_import'): + graph.graph['fw_opset_version'] = model_proto.opset_import[0].version # pylint: disable=no-member + else: + graph.graph['fw_opset_version'] = None graph.check_empty_graph('protobuf2nx. It may happen due to problems with loaded model') extract_node_attrs(graph, lambda node: onnx_op_extractor(node, check_for_duplicates(onnx_op_extractors))) diff --git a/model-optimizer/extensions/middle/ReluQuantizeFuse.py b/model-optimizer/extensions/middle/ReluQuantizeFuse.py index ca89b4c78b9..3e025f01402 100644 --- a/model-optimizer/extensions/middle/ReluQuantizeFuse.py +++ b/model-optimizer/extensions/middle/ReluQuantizeFuse.py @@ -120,7 +120,11 @@ class ClampQuantizeMark(MiddleReplacementPattern): def replace_pattern(self, graph: Graph, match: Dict[str, Node]): clamp = match['clamp'] quantize = match['quantize'] - clamp_min, clamp_max = clamp['min'], clamp['max'] + clamp_min = clamp.in_port(1).data.get_value() + clamp_max = clamp.in_port(2).data.get_value() + if clamp_min is None or clamp_max is None: + log.debug('ReluQuantizeFuse: cannot fuse because Clamp op has dynamic input on the 1st or 2nd port') + return if not clamp.has_valid('quantized_to_fuse_count'): clamp['quantized_to_fuse_count'] = 0 diff --git a/model-optimizer/extensions/ops/activation_ops.py b/model-optimizer/extensions/ops/activation_ops.py index cd4282bf4ae..877acaae9d2 100644 --- a/model-optimizer/extensions/ops/activation_ops.py +++ b/model-optimizer/extensions/ops/activation_ops.py @@ -18,7 +18,7 @@ import numpy as np from mo.front.common.partial_infer.eltwise import eltwise_infer from mo.graph.graph import Graph, Node -from mo.ops.clamp import Clamp +from mo.ops.clamp import AttributedClamp from mo.ops.op import Op activation_ops = ['Sigmoid', 'Tanh', 'ReLU6', 'Exp', 'Elu', 'LogicalNot', 'Floor', 'Ceiling'] @@ -95,7 +95,7 @@ class Atan(Activation): operation = staticmethod(lambda x: np.arctan(x)) -class ReLU6(Clamp): +class ReLU6(AttributedClamp): op = 'ReLU6' def __init__(self, graph: Graph, attrs: dict): diff --git a/model-optimizer/mo/front/onnx/extractors/utils.py b/model-optimizer/mo/front/onnx/extractors/utils.py index baf9ea3a3bf..6bff1a79c44 100644 --- a/model-optimizer/mo/front/onnx/extractors/utils.py +++ b/model-optimizer/mo/front/onnx/extractors/utils.py @@ -57,6 +57,10 @@ def get_onnx_autopad(auto_pad): return auto_pad +def get_onnx_opset_version(node: Node): + return node.graph.graph.get('fw_opset_version', 0) + + def get_onnx_datatype_as_numpy(value): datatype_to_numpy = { 1: np.float32, diff --git a/model-optimizer/mo/ops/clamp.py b/model-optimizer/mo/ops/clamp.py index 8205bc97116..bd6cd76fa1c 100644 --- a/model-optimizer/mo/ops/clamp.py +++ b/model-optimizer/mo/ops/clamp.py @@ -19,13 +19,13 @@ from mo.graph.graph import Graph from mo.ops.op import Op -class Clamp(Op): - op = 'Clamp' +class AttributedClamp(Op): + op = 'AttributedClamp' def __init__(self, graph: Graph, attrs: dict): super().__init__(graph, { - 'type': __class__.op, - 'op': __class__.op, + 'type': 'Clamp', + 'op': self.op, 'version': 'opset1', 'infer': copy_shape_infer, 'in_ports_count': 1, @@ -37,3 +37,16 @@ class Clamp(Op): 'max', 'min' ] + + +class Clamp(Op): + op = 'Clamp' + + def __init__(self, graph: Graph, attrs: dict): + super().__init__(graph, { + 'type': None, + 'op': self.op, + 'infer': copy_shape_infer, + 'in_ports_count': 3, + 'out_ports_count': 1, + }, attrs) diff --git a/model-optimizer/mo/ops/clamp_test.py b/model-optimizer/mo/ops/clamp_test.py index 5cd7927cdab..71ebffe371a 100644 --- a/model-optimizer/mo/ops/clamp_test.py +++ b/model-optimizer/mo/ops/clamp_test.py @@ -19,7 +19,7 @@ import unittest import numpy as np from mo.front.common.partial_infer.elemental import copy_shape_infer -from mo.ops.clamp import Clamp +from mo.ops.clamp import AttributedClamp from mo.utils.unittest.graph import build_graph @@ -41,7 +41,7 @@ class TestClampOp(unittest.TestCase): ('node_1', 'clamp_node'), ('clamp_node', 'node_3') ]) - clamp_node = Clamp(graph, self.nodes_attributes['clamp_node']).add_node() + clamp_node = AttributedClamp(graph, self.nodes_attributes['clamp_node']).add_node() self.assertEqual(clamp_node.type, 'Clamp') - self.assertEqual(clamp_node.op, 'Clamp') + self.assertEqual(clamp_node.op, 'AttributedClamp') self.assertEqual(clamp_node.infer, copy_shape_infer) diff --git a/model-optimizer/mo/utils/ir_reader/layer_to_class.py b/model-optimizer/mo/utils/ir_reader/layer_to_class.py index 83b63f9afc7..c9ad763c87f 100644 --- a/model-optimizer/mo/utils/ir_reader/layer_to_class.py +++ b/model-optimizer/mo/utils/ir_reader/layer_to_class.py @@ -28,6 +28,7 @@ from extensions.ops.scatter import Scatter from extensions.ops.split import Split, VariadicSplit from mo.front.common.partial_infer.utils import int64_array from mo.graph.graph import Graph, Node +from mo.ops.clamp import AttributedClamp from mo.ops.convolution import Convolution from mo.ops.deconvolution import Deconvolution from mo.ops.op import Op @@ -53,6 +54,7 @@ custom_ops = { 'Split': Split, 'Subtract': Sub, 'VariadicSplit': VariadicSplit, + 'Clamp': AttributedClamp, }