Support ONNX Clamp-11 (#538)
parent 04bb8ab51d · commit 8c8629a4af
@@ -12,6 +12,7 @@ extensions/back/ActivationsNormalizer.py
 extensions/back/AvgPool.py
 extensions/back/blob_normalizer.py
 extensions/back/CellNormalizer.py
+extensions/back/ClampNormalizer.py
 extensions/back/compress_quantized_weights.py
 extensions/back/ConvolutionNormalizer.py
 extensions/back/CorrectName.py
@@ -72,6 +73,7 @@ extensions/back/UselessConcatRemoval.py
 extensions/front/__init__.py
 extensions/front/ArgMaxSqueeze.py
 extensions/front/ATenToEmbeddingBag.py
+extensions/front/AttributedClampNormalizer.py
 extensions/front/AttributedGatherNormalizer.py
 extensions/front/AttributedPadToPad.py
 extensions/front/binary_quantize_normalization.py
model-optimizer/extensions/back/ClampNormalizer.py (new file, 71 lines)
@@ -0,0 +1,71 @@
"""
Copyright (C) 2018-2020 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from extensions.ops.elementwise import Minimum, Maximum
from mo.back.replacement import BackReplacementPattern
from mo.graph.graph import Graph, rename_node
from mo.ops.clamp import AttributedClamp


class ClampNormalizer(BackReplacementPattern):
    """
    Replaces Clamp with `min` and `max` as inputs with AttributedClamp with `min` and `max` as attributes.
    """
    enabled = True
    force_clean_up = True

    def pattern(self):
        return dict(
            nodes=[('clamp', dict(op='Clamp'))],
            edges=[]
        )

    def replace_pattern(self, graph: Graph, match: dict):
        clamp = match['clamp']
        name = clamp.soft_get('name', clamp.id)

        min_value = max_value = None
        port_1_exist = clamp.has_port('in', 1) and not clamp.in_port(1).disconnected()
        port_2_exist = clamp.has_port('in', 2) and not clamp.in_port(2).disconnected()
        if port_1_exist and clamp.in_port(1).get_source().node.soft_get('type') == 'Const':
            min_value = clamp.in_port(1).data.get_value()
        if port_2_exist and clamp.in_port(2).get_source().node.soft_get('type') == 'Const':
            max_value = clamp.in_port(2).data.get_value()

        rename_node(clamp, name + '/TBR')
        if min_value is None or max_value is None:
            max_node = min_node = None
            if port_1_exist:
                max_node = Maximum(graph, {}).create_node()
                clamp.in_port(0).get_connection().set_destination(max_node.in_port(0))
                clamp.in_port(1).get_connection().set_destination(max_node.in_port(1))
                clamp.out_port(0).get_connection().set_source(max_node.out_port(0))
            if port_2_exist:
                min_node = Minimum(graph, {}).create_node()
                if max_node is not None:
                    max_node.out_port(0).get_connection().set_source(min_node.out_port(0))
                    max_node.out_port(0).connect(min_node.in_port(0))
                else:
                    clamp.in_port(0).get_connection().set_destination(min_node.in_port(0))
                    clamp.out_port(0).get_connection().set_source(min_node.out_port(0))
                clamp.in_port(2).get_connection().set_destination(min_node.in_port(1))
            assert min_node is not None or max_node is not None, 'Clamp node should have either min or max input used'
            rename_node(max_node if min_node is None else min_node, name)
        else:
            a_clamp = AttributedClamp(graph, {'name': name, 'min': min_value, 'max': max_value}).create_node()
            rename_node(a_clamp, name)
            clamp.in_port(0).get_connection().set_destination(a_clamp.in_port(0))
            clamp.out_port(0).get_connection().set_source(a_clamp.out_port(0))
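Illustration (not part of the commit): the back transformation above relies on the identity clamp(x, min, max) == minimum(maximum(x, min), max), dropping whichever bound is not connected. A minimal NumPy sketch of that equivalence (the function name is made up for illustration):

    import numpy as np

    def clamp_via_min_max(x, min_val=None, max_val=None):
        # Mirrors ClampNormalizer's decomposition: Maximum enforces the lower bound,
        # Minimum enforces the upper bound, and a missing bound is simply skipped.
        if min_val is not None:
            x = np.maximum(x, min_val)
        if max_val is not None:
            x = np.minimum(x, max_val)
        return x

    x = np.array([-5.0, 0.0, 5.0], dtype=np.float32)
    print(clamp_via_min_max(x, -3.5, 3.5))     # [-3.5  0.   3.5]
    print(clamp_via_min_max(x, min_val=-3.5))  # [-3.5  0.   5. ]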
model-optimizer/extensions/back/ClampNormalizer_test.py (new file, 118 lines)
@@ -0,0 +1,118 @@
"""
Copyright (C) 2018-2020 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest

import numpy as np

from extensions.back.ClampNormalizer import ClampNormalizer
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph, regular_op_with_shaped_data, valued_const_with_data, result, connect


class AttributedClampNormalizerTests(unittest.TestCase):

    def test_2_inputs(self):
        nodes = {
            **regular_op_with_shaped_data('placeholder', [1, 3, 20, 20], {'type': 'Parameter'}),
            **regular_op_with_shaped_data('a_clamp', [1, 3, 20, 20], {'type': None, 'op': 'Clamp'}),
            **regular_op_with_shaped_data('clamp', [1, 3, 20, 20],
                                          {'type': 'Clamp', 'op': 'AttributedClamp', 'min': -3.5, 'max': 3.5}),
            **valued_const_with_data('min', np.array(-3.5)),
            **valued_const_with_data('max', np.array(3.5)),
            **result('result'),
        }
        edges = [*connect('placeholder', '0:a_clamp'),
                 *connect('min', '1:a_clamp'),
                 *connect('max', '2:a_clamp'),
                 *connect('a_clamp', 'result'),
                 ]
        graph = build_graph(nodes, edges)
        ClampNormalizer().find_and_replace_pattern(graph)
        ref_graph = build_graph(nodes, [*connect('placeholder', '0:clamp'), *connect('clamp', 'result')])

        (flag, resp) = compare_graphs(graph, ref_graph, 'result')
        self.assertTrue(flag, resp)

    def test_all_dynamic_inputs(self):
        nodes = {
            **regular_op_with_shaped_data('placeholder', [1, 3, 20, 20], {'type': 'Parameter'}),
            **regular_op_with_shaped_data('min', [1, 3, 20, 20], {'type': 'Parameter'}),
            **regular_op_with_shaped_data('max', [1, 3, 20, 20], {'type': 'Parameter'}),
            **regular_op_with_shaped_data('a_clamp', [1, 3, 20, 20], {'type': None, 'op': 'Clamp'}),
            **regular_op_with_shaped_data('maximum', [1, 3, 20, 20], {'type': 'Maximum', 'op': 'Maximum'}),
            **regular_op_with_shaped_data('minimum', [1, 3, 20, 20], {'type': 'Minimum', 'op': 'Minimum'}),
            **result('result'),
        }
        edges = [*connect('placeholder', '0:a_clamp'),
                 *connect('min', '1:a_clamp'),
                 *connect('max', '2:a_clamp'),
                 *connect('a_clamp', 'result'),
                 ]
        graph = build_graph(nodes, edges)
        ClampNormalizer().find_and_replace_pattern(graph)
        ref_graph = build_graph(nodes, [*connect('placeholder', '0:maximum'),
                                        *connect('min', '1:maximum'),
                                        *connect('maximum', '0:minimum'),
                                        *connect('max', '1:minimum'),
                                        *connect('minimum', 'result')
                                        ])

        (flag, resp) = compare_graphs(graph, ref_graph, 'result')
        self.assertTrue(flag, resp)

    def test_no_2nd_input(self):
        nodes = {
            **regular_op_with_shaped_data('placeholder', [1, 3, 20, 20], {'type': 'Parameter'}),
            **regular_op_with_shaped_data('a_clamp', [1, 3, 20, 20], {'type': None, 'op': 'Clamp'}),
            **regular_op_with_shaped_data('maximum', [1, 3, 20, 20], {'type': 'Maximum', 'op': 'Maximum'}),
            **valued_const_with_data('min', np.array(-3.5)),
            **result('result'),
        }
        edges = [*connect('placeholder', '0:a_clamp'),
                 *connect('min', '1:a_clamp'),
                 *connect('a_clamp', 'result'),
                 ]
        graph = build_graph(nodes, edges)
        ClampNormalizer().find_and_replace_pattern(graph)
        ref_graph = build_graph(nodes, [*connect('placeholder', '0:maximum'),
                                        *connect('min', '1:maximum'),
                                        *connect('maximum', 'result')
                                        ])

        (flag, resp) = compare_graphs(graph, ref_graph, 'result')
        self.assertTrue(flag, resp)

    def test_no_1st_input(self):
        nodes = {
            **regular_op_with_shaped_data('placeholder', [1, 3, 20, 20], {'type': 'Parameter'}),
            **regular_op_with_shaped_data('a_clamp', [1, 3, 20, 20], {'type': None, 'op': 'Clamp'}),
            **regular_op_with_shaped_data('minimum', [1, 3, 20, 20], {'type': 'Minimum', 'op': 'Minimum'}),
            **valued_const_with_data('max', np.array(3.5)),
            **result('result'),
        }
        edges = [*connect('placeholder', '0:a_clamp'),
                 *connect('max', '2:a_clamp'),
                 *connect('a_clamp', 'result'),
                 ]
        graph = build_graph(nodes, edges)
        ClampNormalizer().find_and_replace_pattern(graph)
        ref_graph = build_graph(nodes, [*connect('placeholder', '0:minimum'),
                                        *connect('max', '1:minimum'),
                                        *connect('minimum', 'result')
                                        ])

        (flag, resp) = compare_graphs(graph, ref_graph, 'result')
        self.assertTrue(flag, resp)
model-optimizer/extensions/front/AttributedClampNormalizer.py (new file, 46 lines)
@@ -0,0 +1,46 @@
"""
Copyright (C) 2020 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np

from mo.front.common.replacement import FrontReplacementPattern
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph, rename_node
from mo.ops.clamp import Clamp


class AttributedClampNormalizer(FrontReplacementPattern):
    """
    This transformation converts AttributedClamp operation (min/max are specified as attribute) to Clamp
    operation.
    """
    enabled = True

    def find_and_replace_pattern(self, graph: Graph):
        for attr_clamp in graph.get_op_nodes(op='AttributedClamp'):
            original_name = attr_clamp.soft_get('name', attr_clamp.id)

            rename_node(attr_clamp, original_name + '/TBR')
            min_value = attr_clamp.soft_get('min', np.finfo(np.float32).min)
            max_value = attr_clamp.soft_get('max', np.finfo(np.float32).max)
            new_clamp = create_op_with_const_inputs(graph, Clamp,
                                                    {1: np.array(min_value, dtype=np.float32),
                                                     2: np.array(max_value, dtype=np.float32)},
                                                    {'name': original_name})
            rename_node(new_clamp, original_name)

            attr_clamp.in_port(0).get_connection().set_destination(new_clamp.in_port(0))
            attr_clamp.out_port(0).get_connection().set_source(new_clamp.out_port(0))
            graph.remove_node(attr_clamp.id)
model-optimizer/extensions/front/AttributedClampNormalizer_test.py (new file, 62 lines)
@@ -0,0 +1,62 @@
"""
Copyright (C) 2018-2020 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import unittest

import numpy as np

from extensions.front.AttributedClampNormalizer import AttributedClampNormalizer
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph, const

nodes_attributes = {
    'placeholder': {'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
    'attr_clamp': {'type': 'Clamp', 'kind': 'op', 'op': 'AttributedClamp', 'name': 'attr_clamp',
                   'min': np.array(-3.5, dtype=np.float32), 'max': np.array(3.5, dtype=np.float32)},
    'result': {'type': 'Result', 'value': None, 'kind': 'op', 'op': 'Result'},

    # new Clamp layer and inputs
    'clamp': {'type': None, 'kind': 'op', 'op': 'Clamp'},
    **const('min', np.array(-3.5, dtype=np.float32)),
    **const('max', np.array(3.5, dtype=np.float32)),
}


class AttributedClampNormalizerTest(unittest.TestCase):
    def test_1(self):
        graph = build_graph(nodes_attributes,
                            [('placeholder', 'attr_clamp', {'in': 0, 'out': 0}),
                             ('attr_clamp', 'result', {'in': 0, 'out': 0}),
                             ],
                            {}, nodes_with_edges_only=True)

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder', 'clamp', {'in': 0, 'out': 0}),
                                 ('min', 'clamp', {'in': 1, 'out': 0}),
                                 ('max', 'clamp', {'in': 2, 'out': 0}),
                                 ('clamp', 'result')
                                 ],
                                {}, nodes_with_edges_only=True)

        graph.graph['layout'] = 'NCHW'
        graph.stage = 'front'

        replacer = AttributedClampNormalizer()
        replacer.find_and_replace_pattern(graph)

        (flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
        self.assertTrue(flag, resp)
        self.assertTrue(graph.node[graph.get_nodes_with_attributes(op='Clamp')[0]]['name'] == 'attr_clamp')
@@ -22,13 +22,14 @@ from extensions.ops.split import Split
 from mo.front.caffe.extractors.utils import input_as_const
 from mo.front.common.partial_infer.utils import int64_array
 from mo.front.common.replacement import FrontReplacementOp
+from mo.front.tf.graph_utils import create_op_with_const_inputs
 from mo.graph.graph import Node, Graph, Port
 from mo.ops.assign import Assign
 from mo.ops.broadcast import Broadcast
 from mo.ops.clamp import Clamp
-from mo.ops.crop import Crop
 from mo.ops.concat import Concat
 from mo.ops.const import Const
+from mo.ops.crop import Crop
 from mo.ops.read_value import ReadValue
 from mo.ops.result import Result
 from mo.ops.scale_shift import ScaleShiftOp
@@ -238,10 +239,10 @@ class ReplaceLSTMNodePattern(FrontReplacementOp):
         join_forget_remember_sum.in_port(1).connect(join_remember_candidates_mul.out_port(0))

         # (7)Eltwise(sum) -> Clamp
-        join_forget_clamp = Clamp(graph, {'name': 'join_forget_clamp',
-                                          'max': node.clip_value,
-                                          'min': -node.clip_value}).create_node()
-        join_forget_clamp.in_port(0).connect(join_forget_remember_sum.out_port(0))
+        join_forget_clamp = create_op_with_const_inputs(graph, Clamp, {1: np.array(-node.clip_value, dtype=np.float32),
+                                                                       2: np.array(node.clip_value, dtype=np.float32)},
+                                                        {'name': 'join_forget_clamp'},
+                                                        join_forget_remember_sum)
         #
         # Clamp -> (2)Memory(state)
         next_lstm_state = Assign(graph, {'name': 'next_lstm_state',
@@ -16,7 +16,7 @@
 from mo.front.extractor import FrontExtractorOp
 from mo.front.mxnet.extractors.utils import get_mxnet_layer_attrs
 from mo.graph.graph import Node
-from mo.ops.clamp import Clamp
+from mo.ops.clamp import AttributedClamp


 class ClipExt(FrontExtractorOp):
@@ -27,5 +27,5 @@ class ClipExt(FrontExtractorOp):
     def extract(cls, node: Node):
         attrs = get_mxnet_layer_attrs(node.symbol_dict)

-        Clamp.update_node_stat(node, {'min': attrs.float('a_min', None), 'max': attrs.float('a_max', None),})
+        AttributedClamp.update_node_stat(node, {'min': attrs.float('a_min', None), 'max': attrs.float('a_max', None)})
         return cls.enabled
@@ -38,7 +38,7 @@ class CumSumFrontReplacer(FrontReplacementOp):

         node.in_port(0).get_connection().set_destination(cumsum_node.in_port(0))
         if node.has_valid('mx_out_type') and node['mx_out_type'] is not None:
-            rename_node(node=cumsum_node, name=name + '/Clamp')
+            rename_node(node=cumsum_node, name=name + '/CumSum')
             convert = Cast(graph, {'name': name, 'dst_type': node['mx_out_type']}).create_node()
             rename_node(convert, name)
             cumsum_node.out_port(0).connect(convert.in_port(0))
@@ -13,10 +13,11 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+import numpy as np

 from mo.front.extractor import FrontExtractorOp
-from mo.front.onnx.extractors.utils import onnx_attr
-from mo.ops.clamp import Clamp
+from mo.front.onnx.extractors.utils import onnx_attr, get_onnx_opset_version
+from mo.ops.clamp import Clamp, AttributedClamp


 class ClipFrontExtractor(FrontExtractorOp):
@@ -25,9 +26,12 @@ class ClipFrontExtractor(FrontExtractorOp):

     @classmethod
     def extract(cls, node):
-        attrs = {
-            'min': onnx_attr(node, 'min', 'f', -3.4028234663852886e+38),
-            'max': onnx_attr(node, 'max', 'f', 3.4028234663852886e+38),
-        }
-        Clamp.update_node_stat(node, attrs)
+        if get_onnx_opset_version(node) < 11:
+            attrs = {
+                'min': onnx_attr(node, 'min', 'f', np.finfo(np.float32).min),
+                'max': onnx_attr(node, 'max', 'f', np.finfo(np.float32).max),
+            }
+            AttributedClamp.update_node_stat(node, attrs)
+        else:
+            Clamp.update_node_stat(node)
         return cls.enabled
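The opset branch above reflects a change in the ONNX spec: before opset 11, Clip carries its bounds as `min`/`max` attributes, while Clip-11 takes them as optional inputs that may be non-constant. An illustrative sketch with onnx.helper (not part of the commit):

    from onnx import helper

    # Clip before opset 11: bounds are attributes -> extractor produces AttributedClamp.
    clip_pre11 = helper.make_node('Clip', inputs=['x'], outputs=['y'], min=-3.5, max=3.5)

    # Clip-11: bounds are optional inputs (possibly dynamic) -> extractor produces Clamp,
    # which the back ClampNormalizer later turns into AttributedClamp or Maximum/Minimum.
    clip_11 = helper.make_node('Clip', inputs=['x', 'min', 'max'], outputs=['y'])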
@@ -63,6 +63,10 @@ class ONNXLoader(Loader):
         graph.graph['layout'] = 'NCHW'
         graph.graph['fw'] = 'onnx'
         graph.graph['feature_dim'] = 1
+        if hasattr(model_proto, 'opset_import'):
+            graph.graph['fw_opset_version'] = model_proto.opset_import[0].version  # pylint: disable=no-member
+        else:
+            graph.graph['fw_opset_version'] = None

         graph.check_empty_graph('protobuf2nx. It may happen due to problems with loaded model')
         extract_node_attrs(graph, lambda node: onnx_op_extractor(node, check_for_duplicates(onnx_op_extractors)))
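The opset version stored here comes straight from the model protobuf; the get_onnx_opset_version() helper added later in this diff reads it back during extraction. A small sketch of where that value lives, assuming an installed onnx package and a hypothetical model.onnx file (not part of the commit):

    import onnx

    model_proto = onnx.load('model.onnx')  # hypothetical model file
    # This is the value the loader records as graph.graph['fw_opset_version'].
    print(model_proto.opset_import[0].version)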
@@ -120,7 +120,11 @@ class ClampQuantizeMark(MiddleReplacementPattern):
     def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
         clamp = match['clamp']
         quantize = match['quantize']
-        clamp_min, clamp_max = clamp['min'], clamp['max']
+        clamp_min = clamp.in_port(1).data.get_value()
+        clamp_max = clamp.in_port(2).data.get_value()
+        if clamp_min is None or clamp_max is None:
+            log.debug('ReluQuantizeFuse: cannot fuse because Clamp op has dynamic input on the 1st or 2nd port')
+            return

         if not clamp.has_valid('quantized_to_fuse_count'):
             clamp['quantized_to_fuse_count'] = 0
@@ -18,7 +18,7 @@ import numpy as np

 from mo.front.common.partial_infer.eltwise import eltwise_infer
 from mo.graph.graph import Graph, Node
-from mo.ops.clamp import Clamp
+from mo.ops.clamp import AttributedClamp
 from mo.ops.op import Op

 activation_ops = ['Sigmoid', 'Tanh', 'ReLU6', 'Exp', 'Elu', 'LogicalNot', 'Floor', 'Ceiling']
@@ -95,7 +95,7 @@ class Atan(Activation):
     operation = staticmethod(lambda x: np.arctan(x))


-class ReLU6(Clamp):
+class ReLU6(AttributedClamp):
     op = 'ReLU6'

     def __init__(self, graph: Graph, attrs: dict):
@@ -57,6 +57,10 @@ def get_onnx_autopad(auto_pad):
     return auto_pad


+def get_onnx_opset_version(node: Node):
+    return node.graph.graph.get('fw_opset_version', 0)
+
+
 def get_onnx_datatype_as_numpy(value):
     datatype_to_numpy = {
         1: np.float32,
@@ -19,13 +19,13 @@ from mo.graph.graph import Graph
 from mo.ops.op import Op


-class Clamp(Op):
-    op = 'Clamp'
+class AttributedClamp(Op):
+    op = 'AttributedClamp'

     def __init__(self, graph: Graph, attrs: dict):
         super().__init__(graph, {
-            'type': __class__.op,
-            'op': __class__.op,
+            'type': 'Clamp',
+            'op': self.op,
             'version': 'opset1',
             'infer': copy_shape_infer,
             'in_ports_count': 1,
@@ -37,3 +37,16 @@ class Clamp(Op):
             'max',
             'min'
         ]
+
+
+class Clamp(Op):
+    op = 'Clamp'
+
+    def __init__(self, graph: Graph, attrs: dict):
+        super().__init__(graph, {
+            'type': None,
+            'op': self.op,
+            'infer': copy_shape_infer,
+            'in_ports_count': 3,
+            'out_ports_count': 1,
+        }, attrs)
@@ -19,7 +19,7 @@ import unittest
 import numpy as np

 from mo.front.common.partial_infer.elemental import copy_shape_infer
-from mo.ops.clamp import Clamp
+from mo.ops.clamp import AttributedClamp
 from mo.utils.unittest.graph import build_graph


@@ -41,7 +41,7 @@ class TestClampOp(unittest.TestCase):
                               ('node_1', 'clamp_node'),
                               ('clamp_node', 'node_3')
                               ])
-        clamp_node = Clamp(graph, self.nodes_attributes['clamp_node']).add_node()
+        clamp_node = AttributedClamp(graph, self.nodes_attributes['clamp_node']).add_node()
         self.assertEqual(clamp_node.type, 'Clamp')
-        self.assertEqual(clamp_node.op, 'Clamp')
+        self.assertEqual(clamp_node.op, 'AttributedClamp')
         self.assertEqual(clamp_node.infer, copy_shape_infer)
@@ -28,6 +28,7 @@ from extensions.ops.scatter import Scatter
 from extensions.ops.split import Split, VariadicSplit
 from mo.front.common.partial_infer.utils import int64_array
 from mo.graph.graph import Graph, Node
+from mo.ops.clamp import AttributedClamp
 from mo.ops.convolution import Convolution
 from mo.ops.deconvolution import Deconvolution
 from mo.ops.op import Op
@@ -53,6 +54,7 @@ custom_ops = {
     'Split': Split,
     'Subtract': Sub,
     'VariadicSplit': VariadicSplit,
+    'Clamp': AttributedClamp,
 }