Re-implement onnx old-style extractors with extractor extensions (#3459)

* add class ConcatFrontExtractor for onnx

* add class ConcatFrontExtractor for onnx

* Delete concat from mo/front/onnx/extractor.py

* Add identity, reshape extractors classes for onnx

* import FrontExtractorOp

* Added BatchNormalizationExtractor

* Added extra line

* Fix import modules

* fix caffe bn.py and bn_test.py

* Fix BatchNormInference

* Modify convert_batch_norm

* Modify convert_batch_norm

* Modify bn_test.py

* Fix old comments BN->batchNormInference
This commit is contained in:
Eugeny Volosenkov 2020-12-10 09:24:24 +03:00 committed by GitHub
parent 6254b150c3
commit 6f512142b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 72 additions and 77 deletions

View File

@ -241,6 +241,7 @@ extensions/front/onnx/aten_ext.py
extensions/front/onnx/AttributedSliceToSlice.py extensions/front/onnx/AttributedSliceToSlice.py
extensions/front/onnx/cast_ext.py extensions/front/onnx/cast_ext.py
extensions/front/onnx/clip_ext.py extensions/front/onnx/clip_ext.py
extensions/front/onnx/concat_ext.py
extensions/front/onnx/const_ext.py extensions/front/onnx/const_ext.py
extensions/front/onnx/constant_fill_ext.py extensions/front/onnx/constant_fill_ext.py
extensions/front/onnx/constant_of_shape_ext.py extensions/front/onnx/constant_of_shape_ext.py
@ -260,12 +261,14 @@ extensions/front/onnx/expand_ext.py
extensions/front/onnx/faster_rcnn.json extensions/front/onnx/faster_rcnn.json
extensions/front/onnx/flatten_ext.py extensions/front/onnx/flatten_ext.py
extensions/front/onnx/flattenONNX_to_reshape.py extensions/front/onnx/flattenONNX_to_reshape.py
extensions/front/onnx/fused_bn_ext.py
extensions/front/onnx/gather_ext.py extensions/front/onnx/gather_ext.py
extensions/front/onnx/gathernd_ext.py extensions/front/onnx/gathernd_ext.py
extensions/front/onnx/gemm_ext.py extensions/front/onnx/gemm_ext.py
extensions/front/onnx/group_norm_ext.py extensions/front/onnx/group_norm_ext.py
extensions/front/onnx/gru_ext.py extensions/front/onnx/gru_ext.py
extensions/front/onnx/hard_sigmoid_ext.py extensions/front/onnx/hard_sigmoid_ext.py
extensions/front/onnx/identity_ext.py
extensions/front/onnx/image_scaler_ext.py extensions/front/onnx/image_scaler_ext.py
extensions/front/onnx/instance_normalization_ext.py extensions/front/onnx/instance_normalization_ext.py
extensions/front/onnx/logsoftmaxONNX_to_logsoftmax.py extensions/front/onnx/logsoftmaxONNX_to_logsoftmax.py
@ -301,6 +304,7 @@ extensions/front/onnx/quantize_linear_resolver.py
extensions/front/onnx/range_ext.py extensions/front/onnx/range_ext.py
extensions/front/onnx/reduce_ext.py extensions/front/onnx/reduce_ext.py
extensions/front/onnx/remove_filtering_boxes_by_size.py extensions/front/onnx/remove_filtering_boxes_by_size.py
extensions/front/onnx/reshape_ext.py
extensions/front/onnx/resize_ext.py extensions/front/onnx/resize_ext.py
extensions/front/onnx/resize_to_interpolate.py extensions/front/onnx/resize_to_interpolate.py
extensions/front/onnx/reverse_sequence_ext.py extensions/front/onnx/reverse_sequence_ext.py
@ -602,9 +606,9 @@ extensions/ops/argmax.py
extensions/ops/assert_op.py extensions/ops/assert_op.py
extensions/ops/aten.py extensions/ops/aten.py
extensions/ops/axpy.py extensions/ops/axpy.py
extensions/ops/BatchNormInference.py
extensions/ops/binarization.py extensions/ops/binarization.py
extensions/ops/BlockLSTM.py extensions/ops/BlockLSTM.py
extensions/ops/bn.py
extensions/ops/box_nms.py extensions/ops/box_nms.py
extensions/ops/bucketize.py extensions/ops/bucketize.py
extensions/ops/Cast.py extensions/ops/Cast.py
@ -843,10 +847,6 @@ mo/front/mxnet/register_custom_ops.py
mo/front/onnx/__init__.py mo/front/onnx/__init__.py
mo/front/onnx/extractor.py mo/front/onnx/extractor.py
mo/front/onnx/extractors/__init__.py mo/front/onnx/extractors/__init__.py
mo/front/onnx/extractors/concat.py
mo/front/onnx/extractors/eltwise.py
mo/front/onnx/extractors/fused_bn.py
mo/front/onnx/extractors/reshape.py
mo/front/onnx/extractors/utils.py mo/front/onnx/extractors/utils.py
mo/front/onnx/loader.py mo/front/onnx/loader.py
mo/front/onnx/register_custom_ops.py mo/front/onnx/register_custom_ops.py
@ -1011,4 +1011,4 @@ requirements_kaldi.txt
requirements_mxnet.txt requirements_mxnet.txt
requirements_onnx.txt requirements_onnx.txt
requirements_tf.txt requirements_tf.txt
requirements_tf2.txt requirements_tf2.txt

View File

@ -27,7 +27,7 @@ class BNToScaleShift(FrontReplacementOp):
""" """
Replaces BN layer with ScaleShift. Replaces BN layer with ScaleShift.
""" """
op = "BN" op = "batchNormInference"
enabled = True enabled = True
def replace_op(self, graph: Graph, node: Node): def replace_op(self, graph: Graph, node: Node):

View File

@ -47,7 +47,7 @@ class TestBNReplacer(unittest.TestCase):
FakeParam('data', shift)]) FakeParam('data', shift)])
nodes = [ nodes = [
('input', {'kind': 'op', 'type': 'Identity', 'op': 'Identity'}), ('input', {'kind': 'op', 'type': 'Identity', 'op': 'Identity'}),
('bn', {'type': 'BN', 'kind': 'op', 'op': 'BN', 'pb': bn_pb, 'model_pb': bn_bin}), ('bn', {'type': None, 'kind': 'op', 'op': 'batchNormInference', 'pb': bn_pb, 'model_pb': bn_bin}),
('output', {'kind': 'op', 'type': 'Identity', 'op': 'Identity'}), ('output', {'kind': 'op', 'type': 'Identity', 'op': 'Identity'}),
] ]
edges = [ edges = [

View File

@ -14,21 +14,18 @@
limitations under the License. limitations under the License.
""" """
from mo.front.common.partial_infer.eltwise import eltwise_infer from mo.front.onnx.extractors.utils import onnx_attr
from mo.front.extractor import FrontExtractorOp
from mo.ops.concat import Concat
class ConcatFrontExtractor(FrontExtractorOp):
op = 'Concat'
enabled = True
@classmethod
def tf_eltwise_ext(pb, op=None, attrs=None): def extract(cls, node):
""" mapping_rule = {
Generic eltwise extractor that supports n-ary operations. 'axis': onnx_attr(node, 'axis', 'i', default=0)
It supports reasonable broadcast semantics from TF/NumPy }
""" Concat.update_node_stat(node, mapping_rule)
res = { return cls.enabled
'infer': lambda node: eltwise_infer(node, op)
}
if attrs is not None:
res.update(attrs)
return res
def make_tf_eltwise(op, attrs=None):
return lambda node: tf_eltwise_ext(node, op, attrs)

View File

@ -15,19 +15,21 @@
""" """
import logging as log import logging as log
from mo.front.common.partial_infer.elemental import copy_shape_infer from extensions.ops.BatchNormInference import BatchNormInference
from mo.front.extractor import FrontExtractorOp
from mo.front.onnx.extractors.utils import onnx_attr from mo.front.onnx.extractors.utils import onnx_attr
def tf_fused_bn_extractor(node):
pb = node.pb
# This statement covers different opset versions
if onnx_attr(node, 'is_test', 'i', None) == 0:
log.error('FusedBatchNorm doesn\'t support is_test=False')
return None
return { class BatchNormalizationExtractor(FrontExtractorOp):
'data_format': 'NCHW', op = 'BatchNormalization'
'eps': onnx_attr(node, 'epsilon', 'f', 1e-5), enabled = True
'infer': copy_shape_infer,
} @classmethod
def extract(cls, node):
attr_dict = {
'data_format': 'NCHW',
'eps': onnx_attr(node, 'epsilon', 'f', 1e-5),
}
BatchNormInference.update_node_stat(node, attr_dict)
return cls.enabled

View File

@ -13,14 +13,15 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
from mo.front.extractor import FrontExtractorOp
from mo.front.common.partial_infer.concat import concat_infer from extensions.ops.identity import Identity
from mo.front.onnx.extractors.utils import onnx_attr
def concat_ext(node): class IdentityFrontExtractor(FrontExtractorOp):
return { op = 'Identity'
'type': "Concat", enabled = True
'axis': onnx_attr(node, 'axis', 'i', default=0),
'infer': concat_infer @classmethod
} def extract(cls, node):
Identity.update_node_stat(node)
return cls.enabled

View File

@ -16,19 +16,20 @@
import numpy as np import numpy as np
from mo.front.extractor import FrontExtractorOp
from mo.front.onnx.extractors.utils import onnx_attr from mo.front.onnx.extractors.utils import onnx_attr
from mo.ops.reshape import Reshape from mo.ops.reshape import Reshape
class ReshapeFrontExtractor(FrontExtractorOp):
op = 'Reshape'
enabled = True
def onnx_reshape_ext(node): @classmethod
''' Extract ONNX Reshape op of different versions. def extract(cls, node):
Support both latest Reshape and Reshape-1. dim = onnx_attr(node, 'shape', 'ints', None)
The first one has 2 arguments, Reshape-1 has one input and shape is coded in attribute. if dim is not None:
''' dim = np.array(dim, dtype=np.int64)
dim = onnx_attr(node, 'shape', 'ints', None) Reshape.update_node_stat(node, {'dim': dim})
if dim is not None: else:
dim = np.array(dim, dtype=np.int64) Reshape.update_node_stat(node)
Reshape.update_node_stat(node, {'dim': dim}) return cls.enabled
else:
Reshape.update_node_stat(node)
return node.graph.node[node.id]

View File

@ -61,7 +61,7 @@ class Fusing(MiddleReplacementPattern):
for_graph_and_each_sub_graph_recursively(graph, lambda graph: mark_unfused_nodes(graph, argv.finegrain_fusing)) for_graph_and_each_sub_graph_recursively(graph, lambda graph: mark_unfused_nodes(graph, argv.finegrain_fusing))
# Converting FusedBatchNorm layer to Mul->Add->Mul->Add sequence # Converting FusedBatchNorm layer to Mul->Add->Mul->Add sequence
# IE doesn't support BN with 4 inputs, so we have to split it to two ScaleShift # IE doesn't support batchNormInference with 4 inputs, so we have to split it to two ScaleShift
for_graph_and_each_sub_graph_recursively(graph, convert_batch_norm) for_graph_and_each_sub_graph_recursively(graph, convert_batch_norm)
if fw == 'caffe': if fw == 'caffe':

View File

@ -18,18 +18,22 @@ from mo.graph.graph import Graph
from mo.ops.op import Op from mo.ops.op import Op
class BNOp(Op): class BatchNormInference(Op):
""" """
Empty Op for BN layer. It will be replaced by BNToScaleShift FrontReplacer BatchNormInference will be replaced by BNToScaleShift FrontReplacer for Caffe or convert_batch_norm
function for other frameworks
""" """
op = 'BN' op = 'batchNormInference'
enabled = True enabled = False
def __init__(self, graph: Graph, attrs: dict): def __init__(self, graph: Graph, attrs: dict):
super().__init__(graph, { super().__init__(graph, {
'type': None, 'type': None,
'op': __class__.op, 'op': self.op,
'in_ports_count': 5, 'in_ports_count': 5,
'out_ports_count': 1, 'out_ports_count': 1,
'infer': None 'infer': self.infer
}, attrs) }, attrs)
@staticmethod
def infer(node):
node.out_port(0).data.set_shape(node.in_port(0).data.get_shape())

View File

@ -14,11 +14,6 @@
limitations under the License. limitations under the License.
""" """
from mo.front.onnx.extractors.concat import concat_ext
from mo.front.onnx.extractors.eltwise import make_tf_eltwise
from mo.front.onnx.extractors.fused_bn import tf_fused_bn_extractor
from mo.front.onnx.extractors.reshape import onnx_reshape_ext
from mo.graph.graph import Node from mo.graph.graph import Node
@ -26,12 +21,7 @@ def node_pb_arg(pb_extractor: callable):
return lambda node: pb_extractor(node.pb) return lambda node: pb_extractor(node.pb)
onnx_op_extractors = { onnx_op_extractors = {}
'BatchNormalization': tf_fused_bn_extractor,
'Concat': concat_ext,
'Identity': node_pb_arg(make_tf_eltwise(lambda v: v, attrs={'identity': True})),
'Reshape': onnx_reshape_ext,
}
def common_onnx_fields(node: Node): def common_onnx_fields(node: Node):

View File

@ -40,7 +40,7 @@ def convert_batch_norm(graph: Graph):
nodes = graph.get_op_nodes() nodes = graph.get_op_nodes()
for node in nodes: for node in nodes:
if node.has_valid('op') and (node.op in ['FusedBatchNorm', 'FusedBatchNormV2', 'FusedBatchNormV3', if node.has_valid('op') and (node.op in ['FusedBatchNorm', 'FusedBatchNormV2', 'FusedBatchNormV3',
'BatchNorm', 'BatchNormalization']): 'BatchNorm', 'BatchNormalization', 'batchNormInference']):
if any([node.in_port(i).data.get_value() is None for i in range(1, len(node.in_ports()))]): if any([node.in_port(i).data.get_value() is None for i in range(1, len(node.in_ports()))]):
log.warning('Cannot translate FusedBatchNorm {} node with non-constant weights'.format( log.warning('Cannot translate FusedBatchNorm {} node with non-constant weights'.format(