Re-implement onnx old-style extractors with extractor extensions (#3459)
* add class ConcatFrontExtractor for onnx
* add class ConcatFrontExtractor for onnx
* Delete concat from mo/front/onnx/extractor.py
* Add identity, reshape extractors classes for onnx
* import FrontExtractorOp
* Added BatchNormalizationExtractor
* Added extra line
* Fix import modules
* fix caffe bn.py and bn_test.py
* Fix BatchNormInference
* Modify convert_batch_norm
* Modify convert_batch_norm
* Modify bn_test.py
* Fix old comments BN->batchNormInference
parent 6254b150c3
commit 6f512142b6
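The commit follows one pattern throughout: each old-style extractor under mo/front/onnx/extractors/ (a plain function returning an attribute dict) is replaced by a FrontExtractorOp subclass under extensions/front/onnx/ that pushes its attributes onto the node through the corresponding Op class. A minimal sketch of that pattern, assembled from the Concat pieces visible in the diff below (only the comments are added here):

# Sketch of the new-style extractor extension pattern used by every *_ext.py file in this commit.
from mo.front.extractor import FrontExtractorOp
from mo.front.onnx.extractors.utils import onnx_attr
from mo.ops.concat import Concat


class ConcatFrontExtractor(FrontExtractorOp):
    op = 'Concat'      # the ONNX operation type this extractor handles
    enabled = True     # picked up by the extension registration mechanism

    @classmethod
    def extract(cls, node):
        # Read the ONNX 'axis' attribute and record it on the node via the Concat op class.
        Concat.update_node_stat(node, {'axis': onnx_attr(node, 'axis', 'i', default=0)})
        return cls.enabled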
@ -241,6 +241,7 @@ extensions/front/onnx/aten_ext.py
extensions/front/onnx/AttributedSliceToSlice.py
extensions/front/onnx/cast_ext.py
extensions/front/onnx/clip_ext.py
extensions/front/onnx/concat_ext.py
extensions/front/onnx/const_ext.py
extensions/front/onnx/constant_fill_ext.py
extensions/front/onnx/constant_of_shape_ext.py
@ -260,12 +261,14 @@ extensions/front/onnx/expand_ext.py
extensions/front/onnx/faster_rcnn.json
extensions/front/onnx/flatten_ext.py
extensions/front/onnx/flattenONNX_to_reshape.py
extensions/front/onnx/fused_bn_ext.py
extensions/front/onnx/gather_ext.py
extensions/front/onnx/gathernd_ext.py
extensions/front/onnx/gemm_ext.py
extensions/front/onnx/group_norm_ext.py
extensions/front/onnx/gru_ext.py
extensions/front/onnx/hard_sigmoid_ext.py
extensions/front/onnx/identity_ext.py
extensions/front/onnx/image_scaler_ext.py
extensions/front/onnx/instance_normalization_ext.py
extensions/front/onnx/logsoftmaxONNX_to_logsoftmax.py
@ -301,6 +304,7 @@ extensions/front/onnx/quantize_linear_resolver.py
extensions/front/onnx/range_ext.py
extensions/front/onnx/reduce_ext.py
extensions/front/onnx/remove_filtering_boxes_by_size.py
extensions/front/onnx/reshape_ext.py
extensions/front/onnx/resize_ext.py
extensions/front/onnx/resize_to_interpolate.py
extensions/front/onnx/reverse_sequence_ext.py
@ -602,9 +606,9 @@ extensions/ops/argmax.py
extensions/ops/assert_op.py
extensions/ops/aten.py
extensions/ops/axpy.py
extensions/ops/BatchNormInference.py
extensions/ops/binarization.py
extensions/ops/BlockLSTM.py
extensions/ops/bn.py
extensions/ops/box_nms.py
extensions/ops/bucketize.py
extensions/ops/Cast.py
@ -843,10 +847,6 @@ mo/front/mxnet/register_custom_ops.py
mo/front/onnx/__init__.py
mo/front/onnx/extractor.py
mo/front/onnx/extractors/__init__.py
mo/front/onnx/extractors/concat.py
mo/front/onnx/extractors/eltwise.py
mo/front/onnx/extractors/fused_bn.py
mo/front/onnx/extractors/reshape.py
mo/front/onnx/extractors/utils.py
mo/front/onnx/loader.py
mo/front/onnx/register_custom_ops.py
@ -27,7 +27,7 @@ class BNToScaleShift(FrontReplacementOp):
     """
     Replaces BN layer with ScaleShift.
     """
-    op = "BN"
+    op = "batchNormInference"
     enabled = True
 
     def replace_op(self, graph: Graph, node: Node):
@ -47,7 +47,7 @@ class TestBNReplacer(unittest.TestCase):
                                          FakeParam('data', shift)])
         nodes = [
             ('input', {'kind': 'op', 'type': 'Identity', 'op': 'Identity'}),
-            ('bn', {'type': 'BN', 'kind': 'op', 'op': 'BN', 'pb': bn_pb, 'model_pb': bn_bin}),
+            ('bn', {'type': None, 'kind': 'op', 'op': 'batchNormInference', 'pb': bn_pb, 'model_pb': bn_bin}),
             ('output', {'kind': 'op', 'type': 'Identity', 'op': 'Identity'}),
         ]
         edges = [
@ -14,21 +14,18 @@
 limitations under the License.
 """
 
-from mo.front.common.partial_infer.eltwise import eltwise_infer
+from mo.front.onnx.extractors.utils import onnx_attr
+from mo.front.extractor import FrontExtractorOp
+from mo.ops.concat import Concat
 
 
+class ConcatFrontExtractor(FrontExtractorOp):
+    op = 'Concat'
+    enabled = True
+
-def tf_eltwise_ext(pb, op=None, attrs=None):
-    """
-    Generic eltwise extractor that supports n-ary operations.
-    It supports reasonable broadcast semantics from TF/NumPy
-    """
-    res = {
-        'infer': lambda node: eltwise_infer(node, op)
-    }
-    if attrs is not None:
-        res.update(attrs)
-    return res
-
-
-def make_tf_eltwise(op, attrs=None):
-    return lambda node: tf_eltwise_ext(node, op, attrs)
+    @classmethod
+    def extract(cls, node):
+        mapping_rule = {
+            'axis': onnx_attr(node, 'axis', 'i', default=0)
+        }
+        Concat.update_node_stat(node, mapping_rule)
+        return cls.enabled
@ -15,19 +15,21 @@
 """
-import logging as log
 
-from mo.front.common.partial_infer.elemental import copy_shape_infer
+from extensions.ops.BatchNormInference import BatchNormInference
+from mo.front.extractor import FrontExtractorOp
 from mo.front.onnx.extractors.utils import onnx_attr
 
 
-def tf_fused_bn_extractor(node):
-    pb = node.pb
-    # This statement covers different opset versions
-    if onnx_attr(node, 'is_test', 'i', None) == 0:
-        log.error('FusedBatchNorm doesn\'t support is_test=False')
-        return None
-
-    return {
-        'data_format': 'NCHW',
-        'eps': onnx_attr(node, 'epsilon', 'f', 1e-5),
-        'infer': copy_shape_infer,
-    }
+class BatchNormalizationExtractor(FrontExtractorOp):
+    op = 'BatchNormalization'
+    enabled = True
+
+    @classmethod
+    def extract(cls, node):
+        attr_dict = {
+            'data_format': 'NCHW',
+            'eps': onnx_attr(node, 'epsilon', 'f', 1e-5),
+        }
+        BatchNormInference.update_node_stat(node, attr_dict)
+        return cls.enabled
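For illustration only (not part of this commit), this is roughly what the ONNX node consumed by the new extractor looks like when built with the onnx helper API; the 'epsilon' attribute is the value read above with a 1e-5 default, and the four parameter inputs besides the data are why the node later has to be decomposed:

from onnx import helper

# Hypothetical BatchNormalization node: data plus four parameter inputs
# (scale, bias, mean, var) and an 'epsilon' attribute, per the ONNX operator set.
bn_node = helper.make_node(
    'BatchNormalization',
    inputs=['x', 'scale', 'bias', 'mean', 'var'],
    outputs=['y'],
    epsilon=1e-3,
)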
@ -13,14 +13,15 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-from mo.front.common.partial_infer.concat import concat_infer
-from mo.front.onnx.extractors.utils import onnx_attr
+from mo.front.extractor import FrontExtractorOp
+from extensions.ops.identity import Identity
 
 
-def concat_ext(node):
-    return {
-        'type': "Concat",
-        'axis': onnx_attr(node, 'axis', 'i', default=0),
-        'infer': concat_infer
-    }
+class IdentityFrontExtractor(FrontExtractorOp):
+    op = 'Identity'
+    enabled = True
+
+    @classmethod
+    def extract(cls, node):
+        Identity.update_node_stat(node)
+        return cls.enabled
@ -16,19 +16,20 @@
 
 import numpy as np
 
+from mo.front.extractor import FrontExtractorOp
 from mo.front.onnx.extractors.utils import onnx_attr
 from mo.ops.reshape import Reshape
 
 
+class ReshapeFrontExtractor(FrontExtractorOp):
+    op = 'Reshape'
+    enabled = True
+
-def onnx_reshape_ext(node):
-    ''' Extract ONNX Reshape op of different versions.
-        Support both latest Reshape and Reshape-1.
-        The first one has 2 arguments, Reshape-1 has one input and shape is coded in attribute.
-    '''
-    dim = onnx_attr(node, 'shape', 'ints', None)
-    if dim is not None:
-        dim = np.array(dim, dtype=np.int64)
-        Reshape.update_node_stat(node, {'dim': dim})
-    else:
-        Reshape.update_node_stat(node)
-    return node.graph.node[node.id]
+    @classmethod
+    def extract(cls, node):
+        dim = onnx_attr(node, 'shape', 'ints', None)
+        if dim is not None:
+            dim = np.array(dim, dtype=np.int64)
+            Reshape.update_node_stat(node, {'dim': dim})
+        else:
+            Reshape.update_node_stat(node)
+        return cls.enabled
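The docstring removed above spells out the two Reshape flavours the extractor still distinguishes: the old Reshape-1 carries the target shape as an attribute, while newer opsets pass it as a second input. For illustration only (node and tensor names are made up), the two forms built with the onnx helper API:

from onnx import helper

# Opset-1 style: target shape lives in the 'shape' attribute, which is what
# onnx_attr(node, 'shape', 'ints', None) picks up.
reshape_v1 = helper.make_node('Reshape', inputs=['data'], outputs=['out'], shape=[0, -1])

# Current style: target shape arrives as a second input tensor, so the attribute
# lookup returns None and the plain Reshape.update_node_stat(node) branch is taken.
reshape_v5 = helper.make_node('Reshape', inputs=['data', 'shape'], outputs=['out'])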
@ -61,7 +61,7 @@ class Fusing(MiddleReplacementPattern):
         for_graph_and_each_sub_graph_recursively(graph, lambda graph: mark_unfused_nodes(graph, argv.finegrain_fusing))
 
         # Converting FusedBatchNorm layer to Mul->Add->Mul->Add sequence
-        # IE doesn't support BN with 4 inputs, so we have to split it to two ScaleShift
+        # IE doesn't support batchNormInference with 4 inputs, so we have to split it to two ScaleShift
         for_graph_and_each_sub_graph_recursively(graph, convert_batch_norm)
 
         if fw == 'caffe':
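The comment above is the motivation for convert_batch_norm (touched at the end of this diff): Inference Engine cannot take the batch-norm node with its four parameter inputs, so the pass unfolds it into two ScaleShift (Mul->Add) pairs. A rough NumPy sketch of the arithmetic being implemented, not Model Optimizer code, with all arrays assumed broadcast-compatible:

import numpy as np

def batch_norm_as_two_scale_shifts(x, gamma, beta, mean, var, eps=1e-5):
    # First Mul->Add: normalize with the stored statistics.
    scale_1 = 1.0 / np.sqrt(var + eps)
    shift_1 = -mean * scale_1
    normalized = x * scale_1 + shift_1
    # Second Mul->Add: apply the learned affine parameters.
    return normalized * gamma + beta
    # Equivalent closed form: gamma * (x - mean) / np.sqrt(var + eps) + beta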
@ -18,18 +18,22 @@ from mo.graph.graph import Graph
 from mo.ops.op import Op
 
 
-class BNOp(Op):
+class BatchNormInference(Op):
     """
-    Empty Op for BN layer. It will be replaced by BNToScaleShift FrontReplacer
+    BatchNormInference will be replaced by BNToScaleShift FrontReplacer for Caffe or convert_batch_norm
+    function for other frameworks
     """
-    op = 'BN'
-    enabled = True
+    op = 'batchNormInference'
+    enabled = False
 
     def __init__(self, graph: Graph, attrs: dict):
         super().__init__(graph, {
             'type': None,
-            'op': __class__.op,
+            'op': self.op,
             'in_ports_count': 5,
             'out_ports_count': 1,
-            'infer': None
+            'infer': self.infer
         }, attrs)
+
+    @staticmethod
+    def infer(node):
+        node.out_port(0).data.set_shape(node.in_port(0).data.get_shape())
@ -14,11 +14,6 @@
 limitations under the License.
 """
 
 
-from mo.front.onnx.extractors.concat import concat_ext
-from mo.front.onnx.extractors.eltwise import make_tf_eltwise
-from mo.front.onnx.extractors.fused_bn import tf_fused_bn_extractor
-from mo.front.onnx.extractors.reshape import onnx_reshape_ext
 from mo.graph.graph import Node
 
 
@ -26,12 +21,7 @@ def node_pb_arg(pb_extractor: callable):
     return lambda node: pb_extractor(node.pb)
 
 
-onnx_op_extractors = {
-    'BatchNormalization': tf_fused_bn_extractor,
-    'Concat': concat_ext,
-    'Identity': node_pb_arg(make_tf_eltwise(lambda v: v, attrs={'identity': True})),
-    'Reshape': onnx_reshape_ext,
-}
+onnx_op_extractors = {}
 
 
 def common_onnx_fields(node: Node):
@ -40,7 +40,7 @@ def convert_batch_norm(graph: Graph):
     nodes = graph.get_op_nodes()
     for node in nodes:
         if node.has_valid('op') and (node.op in ['FusedBatchNorm', 'FusedBatchNormV2', 'FusedBatchNormV3',
-                                                 'BatchNorm', 'BatchNormalization']):
+                                                 'BatchNorm', 'BatchNormalization', 'batchNormInference']):
 
             if any([node.in_port(i).data.get_value() is None for i in range(1, len(node.in_ports()))]):
                 log.warning('Cannot translate FusedBatchNorm {} node with non-constant weights'.format(