Files
openvino/model-optimizer/extensions/front/mxnet/elementwise_ext.py

417 lines
9.8 KiB
Python

"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
from extensions.ops.elementwise import Mul, Sub, Add, Maximum, Minimum, Div, Greater, GreaterEqual, Equal, Less, \
LessEqual, Pow, NotEqual, LogicalAnd, LogicalOr
from mo.front.extractor import FrontExtractorOp
from mo.front.mxnet.extractors.utils import get_mxnet_layer_attrs
from mo.graph.graph import Node
from mo.ops.eltwise_n import EltwiseNAdd
from mo.ops.power import AttributedPower
class PlusExtractor(FrontExtractorOp):
    """Front extractor declared for the '_Plus' op."""

    enabled = True
    op = '_Plus'

    @classmethod
    def extract(cls, node: Node):
        """Apply ``Add.update_node_stat`` to ``node``.

        :param node: graph node being extracted.
        :return: the class-level ``enabled`` flag.
        """
        Add.update_node_stat(node)
        return cls.enabled
class BroadcastAddFrontExtractor(FrontExtractorOp):
    """Front extractor declared for the 'broadcast_add' op."""

    enabled = True
    op = 'broadcast_add'

    @classmethod
    def extract(cls, node):
        """Apply ``Add.update_node_stat`` to ``node``.

        :param node: graph node being extracted.
        :return: the class-level ``enabled`` flag.
        """
        Add.update_node_stat(node)
        return cls.enabled
class BroadcastDivFrontExtractor(FrontExtractorOp):
    """Front extractor declared for the 'broadcast_div' op."""

    enabled = True
    op = 'broadcast_div'

    @classmethod
    def extract(cls, node):
        """Apply ``Div.update_node_stat`` to ``node``.

        :param node: graph node being extracted.
        :return: the class-level ``enabled`` flag.
        """
        Div.update_node_stat(node)
        return cls.enabled
class BroadcastSubFrontExtractor(FrontExtractorOp):
    """Front extractor declared for the 'broadcast_sub' op."""

    enabled = True
    op = 'broadcast_sub'

    @classmethod
    def extract(cls, node):
        """Apply ``Sub.update_node_stat`` to ``node``.

        :param node: graph node being extracted.
        :return: the class-level ``enabled`` flag.
        """
        Sub.update_node_stat(node)
        return cls.enabled
class ElementwiseAddExtractor(FrontExtractorOp):
    """Front extractor declared for the 'elemwise_add' op."""

    enabled = True
    op = 'elemwise_add'

    @classmethod
    def extract(cls, node: Node):
        """Apply ``Add.update_node_stat`` to ``node``.

        :param node: graph node being extracted.
        :return: the class-level ``enabled`` flag.
        """
        Add.update_node_stat(node)
        return cls.enabled
class ElementWiseSum(FrontExtractorOp):
    """Front extractor declared for the 'ElementWiseSum' op."""

    enabled = True
    op = 'ElementWiseSum'

    @classmethod
    def extract(cls, node: Node):
        """Apply ``EltwiseNAdd.update_node_stat`` to ``node``.

        :param node: graph node being extracted.
        :return: the class-level ``enabled`` flag.
        """
        EltwiseNAdd.update_node_stat(node)
        return cls.enabled
class AddNExtractor(FrontExtractorOp):
    """Front extractor declared for the 'add_n' op."""

    enabled = True
    op = 'add_n'

    @classmethod
    def extract(cls, node: Node):
        """Apply ``EltwiseNAdd.update_node_stat`` to ``node``.

        :param node: graph node being extracted.
        :return: the class-level ``enabled`` flag.
        """
        EltwiseNAdd.update_node_stat(node)
        return cls.enabled
class ElementwiseMulExtractor(FrontExtractorOp):
    """Front extractor declared for the 'elemwise_mul' op."""

    enabled = True
    op = 'elemwise_mul'

    @classmethod
    def extract(cls, node: Node):
        """Apply ``Mul.update_node_stat`` to ``node``.

        :param node: graph node being extracted.
        :return: the class-level ``enabled`` flag.
        """
        Mul.update_node_stat(node)
        return cls.enabled
class BroadcastMulFrontExtractor(FrontExtractorOp):
    """Front extractor declared for the 'broadcast_mul' op."""

    enabled = True
    op = 'broadcast_mul'

    @classmethod
    def extract(cls, node):
        """Apply ``Mul.update_node_stat`` to ``node``.

        :param node: graph node being extracted.
        :return: the class-level ``enabled`` flag.
        """
        Mul.update_node_stat(node)
        return cls.enabled
class ElemwiseSubFrontExtractor(FrontExtractorOp):
    """Front extractor declared for the 'elemwise_sub' op."""

    enabled = True
    op = 'elemwise_sub'

    @classmethod
    def extract(cls, node):
        """Apply ``Sub.update_node_stat`` to ``node`` with an empty attribute dict.

        :param node: graph node being extracted.
        :return: the class-level ``enabled`` flag.
        """
        # An explicit empty dict is passed as the second argument.
        extra_attrs = {}
        Sub.update_node_stat(node, extra_attrs)
        return cls.enabled
class ElemwiseDivFrontExtractor(FrontExtractorOp):
    """Front extractor declared for the 'elemwise_div' op."""

    enabled = True
    op = 'elemwise_div'

    @classmethod
    def extract(cls, node):
        """Apply ``Div.update_node_stat`` to ``node`` with an empty attribute dict.

        :param node: graph node being extracted.
        :return: the class-level ``enabled`` flag.
        """
        # An explicit empty dict is passed as the second argument.
        extra_attrs = {}
        Div.update_node_stat(node, extra_attrs)
        return cls.enabled
class BroadcastMaximumFrontExtractor(FrontExtractorOp):
    """Front extractor declared for the 'broadcast_maximum' op."""

    enabled = True
    op = 'broadcast_maximum'

    @classmethod
    def extract(cls, node):
        """Apply ``Maximum.update_node_stat`` to ``node``.

        :param node: graph node being extracted.
        :return: the class-level ``enabled`` flag.
        """
        Maximum.update_node_stat(node)
        return cls.enabled
class BroadcastMinimumFrontExtractor(FrontExtractorOp):
    """Front extractor declared for the 'broadcast_minimum' op."""

    enabled = True
    op = 'broadcast_minimum'

    @classmethod
    def extract(cls, node):
        """Apply ``Minimum.update_node_stat`` to ``node``.

        :param node: graph node being extracted.
        :return: the class-level ``enabled`` flag.
        """
        Minimum.update_node_stat(node)
        return cls.enabled
class BroadcastGreaterFrontExtractor(FrontExtractorOp):
    """Front extractor declared for the 'broadcast_greater' op."""

    enabled = True
    op = 'broadcast_greater'

    @classmethod
    def extract(cls, node):
        """Apply ``Greater.update_node_stat`` to ``node``.

        :param node: graph node being extracted.
        :return: the class-level ``enabled`` flag.
        """
        Greater.update_node_stat(node)
        return cls.enabled
class BroadcastGreaterEqualFrontExtractor(FrontExtractorOp):
    """Front extractor declared for the 'broadcast_greater_equal' op."""

    enabled = True
    op = 'broadcast_greater_equal'

    @classmethod
    def extract(cls, node):
        """Apply ``GreaterEqual.update_node_stat`` to ``node``.

        :param node: graph node being extracted.
        :return: the class-level ``enabled`` flag.
        """
        GreaterEqual.update_node_stat(node)
        return cls.enabled
class BroadcastEqualFrontExtractor(FrontExtractorOp):
    """Front extractor declared for the 'broadcast_equal' op."""

    enabled = True
    op = 'broadcast_equal'

    @classmethod
    def extract(cls, node):
        """Apply ``Equal.update_node_stat`` to ``node``.

        :param node: graph node being extracted.
        :return: the class-level ``enabled`` flag.
        """
        Equal.update_node_stat(node)
        return cls.enabled
class BroadcastNotEqualFrontExtractor(FrontExtractorOp):
    """Front extractor declared for the 'broadcast_not_equal' op."""

    enabled = True
    op = 'broadcast_not_equal'

    @classmethod
    def extract(cls, node):
        """Apply ``NotEqual.update_node_stat`` to ``node``.

        :param node: graph node being extracted.
        :return: the class-level ``enabled`` flag.
        """
        NotEqual.update_node_stat(node)
        return cls.enabled
class BroadcastLesserFrontExtractor(FrontExtractorOp):
    """Front extractor declared for the 'broadcast_lesser' op."""

    enabled = True
    op = 'broadcast_lesser'

    @classmethod
    def extract(cls, node):
        """Apply ``Less.update_node_stat`` to ``node``.

        :param node: graph node being extracted.
        :return: the class-level ``enabled`` flag.
        """
        Less.update_node_stat(node)
        return cls.enabled
class BroadcastLesserEqualFrontExtractor(FrontExtractorOp):
    """Front extractor declared for the 'broadcast_lesser_equal' op."""

    enabled = True
    op = 'broadcast_lesser_equal'

    @classmethod
    def extract(cls, node):
        """Apply ``LessEqual.update_node_stat`` to ``node``.

        :param node: graph node being extracted.
        :return: the class-level ``enabled`` flag.
        """
        LessEqual.update_node_stat(node)
        return cls.enabled
class BroadcastPowerFrontExtractor(FrontExtractorOp):
    """Front extractor declared for the 'broadcast_power' op."""

    enabled = True
    op = 'broadcast_power'

    @classmethod
    def extract(cls, node):
        """Apply ``Pow.update_node_stat`` to ``node``.

        :param node: graph node being extracted.
        :return: the class-level ``enabled`` flag.
        """
        Pow.update_node_stat(node)
        return cls.enabled
class BroadcastLogicalAndFrontExtractor(FrontExtractorOp):
    """Front extractor declared for the 'broadcast_logical_and' op."""

    enabled = True
    op = 'broadcast_logical_and'

    @classmethod
    def extract(cls, node):
        """Apply ``LogicalAnd.update_node_stat`` to ``node``.

        :param node: graph node being extracted.
        :return: the class-level ``enabled`` flag.
        """
        LogicalAnd.update_node_stat(node)
        return cls.enabled
class BroadcastLogicalOrFrontExtractor(FrontExtractorOp):
    """Front extractor declared for the 'broadcast_logical_or' op."""

    enabled = True
    op = 'broadcast_logical_or'

    @classmethod
    def extract(cls, node):
        """Apply ``LogicalOr.update_node_stat`` to ``node``.

        :param node: graph node being extracted.
        :return: the class-level ``enabled`` flag.
        """
        LogicalOr.update_node_stat(node)
        return cls.enabled
class MaximumFrontExtractor(FrontExtractorOp):
    """Front extractor declared for the '_maximum' op."""

    enabled = True
    op = '_maximum'

    @classmethod
    def extract(cls, node):
        """Apply ``Maximum.update_node_stat`` to ``node``.

        :param node: graph node being extracted.
        :return: the class-level ``enabled`` flag.
        """
        Maximum.update_node_stat(node)
        return cls.enabled
class MinimumFrontExtractor(FrontExtractorOp):
    """Front extractor declared for the '_minimum' op."""

    enabled = True
    op = '_minimum'

    @classmethod
    def extract(cls, node):
        """Apply ``Minimum.update_node_stat`` to ``node``.

        :param node: graph node being extracted.
        :return: the class-level ``enabled`` flag.
        """
        Minimum.update_node_stat(node)
        return cls.enabled
class PlusScalarFrontExtractor(FrontExtractorOp):
    """Front extractor declared for the '_plus_scalar' op."""

    enabled = True
    op = '_plus_scalar'

    @classmethod
    def extract(cls, node):
        """Store the layer's 'scalar' attribute on ``node``.

        The attribute is read with a default of 0.0 and saved under
        ``node['scalar']`` as a one-element float32 NumPy array.

        :param node: graph node; its ``symbol_dict`` supplies the layer attributes.
        :return: the class-level ``enabled`` flag.
        """
        layer_attrs = get_mxnet_layer_attrs(node.symbol_dict)
        scalar_value = layer_attrs.float('scalar', 0.0)
        node['scalar'] = np.array([scalar_value], dtype=np.float32)
        return cls.enabled
class MinusScalarFrontExtractor(FrontExtractorOp):
    """Front extractor declared for the '_minus_scalar' op."""

    enabled = True
    op = '_minus_scalar'

    @classmethod
    def extract(cls, node):
        """Store the layer's 'scalar' attribute on ``node``.

        The attribute is read with a default of 0.0 and saved under
        ``node['scalar']`` as a one-element NumPy array. No dtype is passed,
        so NumPy infers it from the value.

        :param node: graph node; its ``symbol_dict`` supplies the layer attributes.
        :return: the class-level ``enabled`` flag.
        """
        layer_attrs = get_mxnet_layer_attrs(node.symbol_dict)
        scalar_value = layer_attrs.float('scalar', 0.0)
        # NOTE(review): '_plus_scalar' and '_mul_scalar' force float32 here;
        # this extractor does not — confirm whether that is intentional.
        node['scalar'] = np.array([scalar_value])
        return cls.enabled
class MulScalarFrontExtractor(FrontExtractorOp):
    """Front extractor declared for the '_mul_scalar' op."""

    enabled = True
    op = '_mul_scalar'

    @classmethod
    def extract(cls, node):
        """Store the layer's 'scalar' attribute on ``node``.

        The attribute is read with a default of 1.0 and saved under
        ``node['scalar']`` as a one-element float32 NumPy array.

        :param node: graph node; its ``symbol_dict`` supplies the layer attributes.
        :return: the class-level ``enabled`` flag.
        """
        layer_attrs = get_mxnet_layer_attrs(node.symbol_dict)
        scalar_value = layer_attrs.float('scalar', 1.0)
        node['scalar'] = np.array([scalar_value], dtype=np.float32)
        return cls.enabled
class DivScalarFrontExtractor(FrontExtractorOp):
    """Front extractor declared for the '_div_scalar' op."""

    enabled = True
    op = '_div_scalar'

    @classmethod
    def extract(cls, node):
        """Store the layer's 'scalar' attribute on ``node``.

        The attribute is read with a default of 1.0 and saved under
        ``node['scalar']`` exactly as returned by ``attrs.float`` (it is not
        wrapped in a NumPy array).

        :param node: graph node; its ``symbol_dict`` supplies the layer attributes.
        :return: the class-level ``enabled`` flag.
        """
        layer_attrs = get_mxnet_layer_attrs(node.symbol_dict)
        scalar_value = layer_attrs.float('scalar', 1.0)
        node['scalar'] = scalar_value
        return cls.enabled
class GreaterScalarFrontExtractor(FrontExtractorOp):
    """Front extractor declared for the '_greater_scalar' op."""

    enabled = True
    op = '_greater_scalar'

    @classmethod
    def extract(cls, node):
        """Store the layer's 'scalar' attribute on ``node``.

        The attribute is read with a default of 1.0 and saved under
        ``node['scalar']`` as a one-element NumPy array (dtype inferred by NumPy).

        :param node: graph node; its ``symbol_dict`` supplies the layer attributes.
        :return: the class-level ``enabled`` flag.
        """
        layer_attrs = get_mxnet_layer_attrs(node.symbol_dict)
        scalar_value = layer_attrs.float('scalar', 1.0)
        node['scalar'] = np.array([scalar_value])
        return cls.enabled
class GreaterEqualScalarFrontExtractor(FrontExtractorOp):
    """Front extractor declared for the '_greater_equal_scalar' op."""

    enabled = True
    op = '_greater_equal_scalar'

    @classmethod
    def extract(cls, node):
        """Store the layer's 'scalar' attribute on ``node``.

        The attribute is read with a default of 1.0 and saved under
        ``node['scalar']`` as a one-element NumPy array (dtype inferred by NumPy).

        :param node: graph node; its ``symbol_dict`` supplies the layer attributes.
        :return: the class-level ``enabled`` flag.
        """
        layer_attrs = get_mxnet_layer_attrs(node.symbol_dict)
        scalar_value = layer_attrs.float('scalar', 1.0)
        node['scalar'] = np.array([scalar_value])
        return cls.enabled
class EqualScalarFrontExtractor(FrontExtractorOp):
    """Front extractor declared for the '_equal_scalar' op."""

    enabled = True
    op = '_equal_scalar'

    @classmethod
    def extract(cls, node):
        """Store the layer's 'scalar' attribute on ``node``.

        The attribute is read with a default of 1.0 and saved under
        ``node['scalar']`` as a one-element NumPy array (dtype inferred by NumPy).

        :param node: graph node; its ``symbol_dict`` supplies the layer attributes.
        :return: the class-level ``enabled`` flag.
        """
        layer_attrs = get_mxnet_layer_attrs(node.symbol_dict)
        scalar_value = layer_attrs.float('scalar', 1.0)
        node['scalar'] = np.array([scalar_value])
        return cls.enabled
class NotEqualScalarFrontExtractor(FrontExtractorOp):
    """Front extractor declared for the '_not_equal_scalar' op."""

    enabled = True
    op = '_not_equal_scalar'

    @classmethod
    def extract(cls, node):
        """Store the layer's 'scalar' attribute on ``node``.

        The attribute is read with a default of 1.0 and saved under
        ``node['scalar']`` as a one-element NumPy array (dtype inferred by NumPy).

        :param node: graph node; its ``symbol_dict`` supplies the layer attributes.
        :return: the class-level ``enabled`` flag.
        """
        layer_attrs = get_mxnet_layer_attrs(node.symbol_dict)
        scalar_value = layer_attrs.float('scalar', 1.0)
        node['scalar'] = np.array([scalar_value])
        return cls.enabled
class LesserScalarFrontExtractor(FrontExtractorOp):
    """Front extractor declared for the '_lesser_scalar' op."""

    enabled = True
    op = '_lesser_scalar'

    @classmethod
    def extract(cls, node):
        """Store the layer's 'scalar' attribute on ``node``.

        The attribute is read with a default of 1.0 and saved under
        ``node['scalar']`` as a one-element NumPy array (dtype inferred by NumPy).

        :param node: graph node; its ``symbol_dict`` supplies the layer attributes.
        :return: the class-level ``enabled`` flag.
        """
        layer_attrs = get_mxnet_layer_attrs(node.symbol_dict)
        scalar_value = layer_attrs.float('scalar', 1.0)
        node['scalar'] = np.array([scalar_value])
        return cls.enabled
class LesserEqualScalarFrontExtractor(FrontExtractorOp):
    """Front extractor declared for the '_lesser_equal_scalar' op."""

    enabled = True
    op = '_lesser_equal_scalar'

    @classmethod
    def extract(cls, node):
        """Store the layer's 'scalar' attribute on ``node``.

        The attribute is read with a default of 1.0 and saved under
        ``node['scalar']`` as a one-element NumPy array (dtype inferred by NumPy).

        :param node: graph node; its ``symbol_dict`` supplies the layer attributes.
        :return: the class-level ``enabled`` flag.
        """
        layer_attrs = get_mxnet_layer_attrs(node.symbol_dict)
        scalar_value = layer_attrs.float('scalar', 1.0)
        node['scalar'] = np.array([scalar_value])
        return cls.enabled
class MinimumScalarFrontExtractor(FrontExtractorOp):
    """Front extractor declared for the '_minimum_scalar' op."""

    enabled = True
    op = '_minimum_scalar'

    @classmethod
    def extract(cls, node):
        """Store the layer's 'scalar' attribute on ``node``.

        The attribute is read with a default of 1.0 and saved under
        ``node['scalar']`` exactly as returned by ``attrs.float`` (it is not
        wrapped in a NumPy array).

        :param node: graph node; its ``symbol_dict`` supplies the layer attributes.
        :return: the class-level ``enabled`` flag.
        """
        layer_attrs = get_mxnet_layer_attrs(node.symbol_dict)
        scalar_value = layer_attrs.float('scalar', 1.0)
        node['scalar'] = scalar_value
        return cls.enabled
class MaximumScalarFrontExtractor(FrontExtractorOp):
    """Front extractor declared for the '_maximum_scalar' op."""

    enabled = True
    op = '_maximum_scalar'

    @classmethod
    def extract(cls, node):
        """Store the layer's 'scalar' attribute on ``node``.

        The attribute is read with a default of 1.0 and saved under
        ``node['scalar']`` exactly as returned by ``attrs.float`` (it is not
        wrapped in a NumPy array).

        :param node: graph node; its ``symbol_dict`` supplies the layer attributes.
        :return: the class-level ``enabled`` flag.
        """
        layer_attrs = get_mxnet_layer_attrs(node.symbol_dict)
        scalar_value = layer_attrs.float('scalar', 1.0)
        node['scalar'] = scalar_value
        return cls.enabled
class ZerosFrontExtractor(FrontExtractorOp):
    """Front extractor declared for the 'zeros_like' op."""

    enabled = True
    op = 'zeros_like'

    @classmethod
    def extract(cls, node):
        """Apply ``AttributedPower.update_node_stat`` to ``node`` with ``scale`` 0.

        :param node: graph node being extracted.
        :return: the class-level ``enabled`` flag.
        """
        # Presumably a zero scale makes the power op emit zeros in the input's
        # shape — confirm against AttributedPower's semantics.
        power_attrs = {'scale': 0}
        AttributedPower.update_node_stat(node, power_attrs)
        return cls.enabled
class OnesFrontExtractor(FrontExtractorOp):
    """Front extractor declared for the 'ones_like' op."""

    enabled = True
    op = 'ones_like'

    @classmethod
    def extract(cls, node):
        """Apply ``AttributedPower.update_node_stat`` to ``node`` with ``scale`` 0 and ``shift`` 1.

        :param node: graph node being extracted.
        :return: the class-level ``enabled`` flag.
        """
        # Presumably zero scale plus a shift of one makes the power op emit
        # ones in the input's shape — confirm against AttributedPower's semantics.
        power_attrs = {'scale': 0, 'shift': 1}
        AttributedPower.update_node_stat(node, power_attrs)
        return cls.enabled