[ MO ] Clamp value inference (#1207)
This commit is contained in:
parent
951b5eed92
commit
143036f96f
@ -13,8 +13,8 @@
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
from mo.front.common.partial_infer.elemental import copy_shape_infer
|
||||
from mo.graph.graph import Graph
|
||||
from mo.ops.op import Op
|
||||
|
||||
@ -27,7 +27,7 @@ class AttributedClamp(Op):
|
||||
'type': 'Clamp',
|
||||
'op': self.op,
|
||||
'version': 'opset1',
|
||||
'infer': copy_shape_infer,
|
||||
'infer': self.infer,
|
||||
'in_ports_count': 1,
|
||||
'out_ports_count': 1,
|
||||
}, attrs)
|
||||
@ -38,6 +38,23 @@ class AttributedClamp(Op):
|
||||
'min'
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def infer(node):
|
||||
name = node.soft_get('name', node.id)
|
||||
connected_in_ports = [port for port in node.in_ports().values() if not port.disconnected()]
|
||||
|
||||
assert len(connected_in_ports) == 1 and connected_in_ports[0].idx == 0, \
|
||||
'AttributedClamp should have only one input, but it has {}'.format(len(connected_in_ports))
|
||||
assert node.has_valid('max') and node.has_valid('min'), \
|
||||
'Mandatory attributes `max` and `min` were not set for AttributedClamp node: `{}`'.format(name)
|
||||
assert node.max >= node.min, \
|
||||
'AttributedClamp max=={} is less than min=={} for node `{}`'.format(node.max, node.min, name)
|
||||
|
||||
if node.in_port(0).data.get_value() is not None:
|
||||
node.out_port(0).data.set_value(np.clip(node.in_port(0).data.get_value(), node['min'], node['max']))
|
||||
else:
|
||||
node.out_port(0).data.set_shape(node.in_port(0).data.get_shape())
|
||||
|
||||
|
||||
class Clamp(Op):
|
||||
op = 'Clamp'
|
||||
@ -46,7 +63,25 @@ class Clamp(Op):
|
||||
super().__init__(graph, {
|
||||
'type': None,
|
||||
'op': self.op,
|
||||
'infer': copy_shape_infer,
|
||||
'infer': self.infer,
|
||||
'in_ports_count': 3,
|
||||
'out_ports_count': 1,
|
||||
}, attrs)
|
||||
|
||||
@staticmethod
|
||||
def infer(node):
|
||||
name = node.soft_get('name', node.id)
|
||||
connected_in_ports = [port.idx for port in node.in_ports().values() if not port.disconnected()]
|
||||
|
||||
assert len(connected_in_ports) == 3 and sorted(connected_in_ports) == [0, 1, 2], \
|
||||
'Clamp should have exactly three inputs, but it has {}'.format(len(connected_in_ports))
|
||||
|
||||
input_value = node.in_port(0).data.get_value()
|
||||
min_value = node.in_port(1).data.get_value()
|
||||
max_value = node.in_port(2).data.get_value()
|
||||
if input_value is not None and min_value is not None and max_value is not None:
|
||||
assert np.all(max_value >= min_value), \
|
||||
'Clamp max_value=={} is less than min_value=={} for node `{}`'.format(max_value, min_value, name)
|
||||
node.out_port(0).data.set_value(np.clip(node.in_port(0).data.get_value(), min_value, max_value))
|
||||
else:
|
||||
node.out_port(0).data.set_shape(node.in_port(0).data.get_shape())
|
||||
|
@ -1,47 +0,0 @@
|
||||
"""
|
||||
Copyright (C) 2018-2020 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mo.front.common.partial_infer.elemental import copy_shape_infer
|
||||
from mo.ops.clamp import AttributedClamp
|
||||
from mo.utils.unittest.graph import build_graph
|
||||
|
||||
|
||||
class TestClampOp(unittest.TestCase):
|
||||
nodes_attributes = {
|
||||
'node_1': {
|
||||
'shape': np.array([227, 227, 227, 227])
|
||||
},
|
||||
'clamp_node': {
|
||||
},
|
||||
'node_3': {
|
||||
'kind': 'data'
|
||||
}
|
||||
}
|
||||
|
||||
def test_clamp_op(self):
|
||||
graph = build_graph(self.nodes_attributes,
|
||||
[
|
||||
('node_1', 'clamp_node'),
|
||||
('clamp_node', 'node_3')
|
||||
])
|
||||
clamp_node = AttributedClamp(graph, self.nodes_attributes['clamp_node']).add_node()
|
||||
self.assertEqual(clamp_node.type, 'Clamp')
|
||||
self.assertEqual(clamp_node.op, 'AttributedClamp')
|
||||
self.assertEqual(clamp_node.infer, copy_shape_infer)
|
Loading…
Reference in New Issue
Block a user