Files
openvino/model-optimizer/mo/ops/lrn.py

109 lines
4.2 KiB
Python

"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from mo.graph.graph import Graph
from mo.ops.op import Op
class LRN(Op):
    """Local Response Normalization operation, ``opset1`` version.

    The node is registered with both ``type`` and ``op`` equal to ``'LRN'``. It declares two
    input ports and one output port. According to the validation message in ``__init__``, the
    normalization region is not an attribute of this operation; it is supplied as the second
    (``axis``) input. ``AttributedLRN`` is the alternative when the region must be an attribute.
    """

    op = 'LRN'
    enabled = False

    def __init__(self, graph: Graph, attrs: dict):
        """Validate the user attributes and create the node through ``Op.__init__``.

        :param graph: graph in which the node is created
        :param attrs: user attributes; ``alpha``, ``beta``, ``bias`` and ``size`` must be present
            and ``region`` must be absent
        :raises AssertionError: when a mandatory attribute is missing or ``region`` is given
            (only while assertions are enabled)
        """
        required_checks = (
            ('alpha', 'LRN operation should have `alpha` parameter set while creation'),
            ('beta', 'LRN operation should have `beta` parameter set while creation'),
            ('bias', 'LRN operation should have `bias` parameter set while creation'),
            ('size', 'LRN operation should have `size` parameter set while creation'),
        )
        for attr_name, failure_message in required_checks:
            assert attr_name in attrs, failure_message

        # The region travels through input port 1, so an attribute form is rejected here.
        assert 'region' not in attrs, \
            'LRN operation should not have `region` parameter set while creation, please use AttributedLRN operation ' \
            'instead or keep using LRN operation with region expressed as second `axis`-input'

        mandatory_props = dict(
            type=self.op,
            op=self.op,
            version='opset1',
            infer=self.infer,
            in_ports_count=2,
            out_ports_count=1,
        )
        super().__init__(graph, mandatory_props, attrs)

    def supported_attrs(self):
        """Return the list of attribute names declared as supported by this operation."""
        attribute_names = ['alpha', 'beta', 'bias', 'size']
        return attribute_names

    @staticmethod
    def infer(node):
        """Copy the shape of input port 0 to output port 0.

        :param node: LRN node; exactly input ports 0 and 1 must be connected
        :raises AssertionError: when the connected ports are not exactly {0, 1}, or when the
            input shape is unknown (only while assertions are enabled)
        """
        node_name = node.soft_get('name', node.id)

        live_ports = {}
        for port_idx, port in node.in_ports().items():
            if not port.disconnected():
                live_ports[port_idx] = port
        # Exactly two connected ports that are 0 and 1, i.e. the key set is {0, 1}.
        assert set(live_ports) == {0, 1}, \
            'LRN should have 2 connected input ports, but it doesn`t for node: `{}`. Ports: {}'.format(
                node_name, live_ports)

        data_shape = node.in_port(0).data.get_shape()
        assert data_shape is not None, 'Input shape is unknown for node {}'.format(node_name)
        # The output shape is the data input shape, unchanged.
        node.out_port(0).data.set_shape(data_shape)
class AttributedLRN(Op):
    """LRN variant that keeps the normalization region as a node attribute.

    The node is registered with ``op`` ``'AttributedLRN'`` and ``type`` ``'Norm'`` (version
    ``opset1``). It has one input port and one output port. ``bias`` defaults to 1 and
    ``region`` defaults to ``'across'``; ``region`` must end up as ``'across'`` or ``'same'``.
    """

    op = 'AttributedLRN'
    enabled = False

    def __init__(self, graph: Graph, attrs: dict):
        """Validate the user attributes, create the node, then validate the resulting region.

        :param graph: graph in which the node is created
        :param attrs: user attributes; ``alpha``, ``beta`` and ``local_size`` must be present
        :raises AssertionError: when a mandatory attribute is missing or the final ``region`` is
            not ``'across'``/``'same'`` (only while assertions are enabled)
        """
        required_checks = (
            ('alpha', 'AttributedLRN operation should have `alpha` parameter set while creation'),
            ('beta', 'AttributedLRN operation should have `beta` parameter set while creation'),
            ('local_size', 'AttributedLRN operation should have `local_size` parameter set while creation'),
        )
        for attr_name, failure_message in required_checks:
            assert attr_name in attrs, failure_message

        mandatory_props = dict(
            op=self.op,
            type='Norm',
            version='opset1',
            bias=1,
            region='across',
            infer=self.infer,
            in_ports_count=1,
            out_ports_count=1,
        )
        super().__init__(graph, mandatory_props, attrs)

        # Checked after construction because the default 'across' above may be replaced by a
        # user-provided value (presumably Op.__init__ lets ``attrs`` override defaults - confirm).
        assert 'region' in self.attrs, 'AttributedLRN operation should have `region` parameter set while creation'
        assert self.attrs['region'] in ('across', 'same'), \
            'AttributedLRN operation should have `region` parameter set to `across` or `same`, but it is `{}`'.format(
                self.attrs['region'])

    def supported_attrs(self):
        """Return the supported attributes; ``local-size`` is read from the ``local_size`` attribute."""
        def read_local_size(node):
            return node.local_size

        return [
            'alpha',
            'beta',
            ('local-size', read_local_size),
            'region',  # deprecated in V10 attribute, but it is kept for V6 compatibility
        ]

    @staticmethod
    def infer(node):
        """Copy the shape of input port 0 to output port 0.

        :param node: AttributedLRN node; exactly input port 0 must be connected
        :raises AssertionError: when the connected ports are not exactly {0}, or when the input
            shape is unknown (only while assertions are enabled)
        """
        node_name = node.soft_get('name', node.id)

        live_ports = {}
        for port_idx, port in node.in_ports().items():
            if not port.disconnected():
                live_ports[port_idx] = port
        # Exactly one connected port and it is port 0, i.e. the key set is {0}.
        assert set(live_ports) == {0}, \
            'AttributedLRN should have 1 connected input port, but it doesn`t for node: `{}`. Ports: {}'.format(
                node_name, live_ports)

        data_shape = node.in_port(0).data.get_shape()
        assert data_shape is not None, 'Input shape is unknown for node {}'.format(node_name)
        # The output shape is the data input shape, unchanged.
        node.out_port(0).data.set_shape(data_shape)