[ MO ] Clamp value inference (#1207)

This commit is contained in:
Evgenya Stepyreva 2020-07-03 17:57:10 +03:00 committed by GitHub
parent 951b5eed92
commit 143036f96f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 38 additions and 50 deletions

View File

@ -13,8 +13,8 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
from mo.front.common.partial_infer.elemental import copy_shape_infer
from mo.graph.graph import Graph
from mo.ops.op import Op
@ -27,7 +27,7 @@ class AttributedClamp(Op):
'type': 'Clamp',
'op': self.op,
'version': 'opset1',
'infer': copy_shape_infer,
'infer': self.infer,
'in_ports_count': 1,
'out_ports_count': 1,
}, attrs)
@ -38,6 +38,23 @@ class AttributedClamp(Op):
'min'
]
@staticmethod
def infer(node):
name = node.soft_get('name', node.id)
connected_in_ports = [port for port in node.in_ports().values() if not port.disconnected()]
assert len(connected_in_ports) == 1 and connected_in_ports[0].idx == 0, \
'AttributedClamp should have only one input, but it has {}'.format(len(connected_in_ports))
assert node.has_valid('max') and node.has_valid('min'), \
'Mandatory attributes `max` and `min` were not set for AttributedClamp node: `{}`'.format(name)
assert node.max >= node.min, \
'AttributedClamp max=={} is less than min=={} for node `{}`'.format(node.max, node.min, name)
if node.in_port(0).data.get_value() is not None:
node.out_port(0).data.set_value(np.clip(node.in_port(0).data.get_value(), node['min'], node['max']))
else:
node.out_port(0).data.set_shape(node.in_port(0).data.get_shape())
class Clamp(Op):
op = 'Clamp'
@ -46,7 +63,25 @@ class Clamp(Op):
super().__init__(graph, {
'type': None,
'op': self.op,
'infer': copy_shape_infer,
'infer': self.infer,
'in_ports_count': 3,
'out_ports_count': 1,
}, attrs)
@staticmethod
def infer(node):
name = node.soft_get('name', node.id)
connected_in_ports = [port.idx for port in node.in_ports().values() if not port.disconnected()]
assert len(connected_in_ports) == 3 and sorted(connected_in_ports) == [0, 1, 2], \
'Clamp should have exactly three inputs, but it has {}'.format(len(connected_in_ports))
input_value = node.in_port(0).data.get_value()
min_value = node.in_port(1).data.get_value()
max_value = node.in_port(2).data.get_value()
if input_value is not None and min_value is not None and max_value is not None:
assert np.all(max_value >= min_value), \
'Clamp max_value=={} is less than min_value=={} for node `{}`'.format(max_value, min_value, name)
node.out_port(0).data.set_value(np.clip(node.in_port(0).data.get_value(), min_value, max_value))
else:
node.out_port(0).data.set_shape(node.in_port(0).data.get_shape())

View File

@ -1,47 +0,0 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
from mo.front.common.partial_infer.elemental import copy_shape_infer
from mo.ops.clamp import AttributedClamp
from mo.utils.unittest.graph import build_graph
class TestClampOp(unittest.TestCase):
nodes_attributes = {
'node_1': {
'shape': np.array([227, 227, 227, 227])
},
'clamp_node': {
},
'node_3': {
'kind': 'data'
}
}
def test_clamp_op(self):
graph = build_graph(self.nodes_attributes,
[
('node_1', 'clamp_node'),
('clamp_node', 'node_3')
])
clamp_node = AttributedClamp(graph, self.nodes_attributes['clamp_node']).add_node()
self.assertEqual(clamp_node.type, 'Clamp')
self.assertEqual(clamp_node.op, 'AttributedClamp')
self.assertEqual(clamp_node.infer, copy_shape_infer)