Re-implement onnx old-style extractors with extractor extensions (#3459)

* add class ConcatFrontExtractor for onnx

* add class ConcatFrontExtractor for onnx

* Delete concat from mo/front/onnx/extractor.py

* Add identity, reshape extractors classes for onnx

* import FrontExtractorOp

* Added BatchNormalizationExtractor

* Added extra line

* Fix import modules

* fix caffe bn.py and bn_test.py

* Fix BatchNormInference

* Modify convert_batch_norm

* Modify convert_batch_norm

* Modify bn_test.py

* Fix old comments BN->batchNormInference
This commit is contained in:
Eugeny Volosenkov 2020-12-10 09:24:24 +03:00 committed by GitHub
parent 6254b150c3
commit 6f512142b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 72 additions and 77 deletions

View File

@ -241,6 +241,7 @@ extensions/front/onnx/aten_ext.py
extensions/front/onnx/AttributedSliceToSlice.py
extensions/front/onnx/cast_ext.py
extensions/front/onnx/clip_ext.py
extensions/front/onnx/concat_ext.py
extensions/front/onnx/const_ext.py
extensions/front/onnx/constant_fill_ext.py
extensions/front/onnx/constant_of_shape_ext.py
@ -260,12 +261,14 @@ extensions/front/onnx/expand_ext.py
extensions/front/onnx/faster_rcnn.json
extensions/front/onnx/flatten_ext.py
extensions/front/onnx/flattenONNX_to_reshape.py
extensions/front/onnx/fused_bn_ext.py
extensions/front/onnx/gather_ext.py
extensions/front/onnx/gathernd_ext.py
extensions/front/onnx/gemm_ext.py
extensions/front/onnx/group_norm_ext.py
extensions/front/onnx/gru_ext.py
extensions/front/onnx/hard_sigmoid_ext.py
extensions/front/onnx/identity_ext.py
extensions/front/onnx/image_scaler_ext.py
extensions/front/onnx/instance_normalization_ext.py
extensions/front/onnx/logsoftmaxONNX_to_logsoftmax.py
@ -301,6 +304,7 @@ extensions/front/onnx/quantize_linear_resolver.py
extensions/front/onnx/range_ext.py
extensions/front/onnx/reduce_ext.py
extensions/front/onnx/remove_filtering_boxes_by_size.py
extensions/front/onnx/reshape_ext.py
extensions/front/onnx/resize_ext.py
extensions/front/onnx/resize_to_interpolate.py
extensions/front/onnx/reverse_sequence_ext.py
@ -602,9 +606,9 @@ extensions/ops/argmax.py
extensions/ops/assert_op.py
extensions/ops/aten.py
extensions/ops/axpy.py
extensions/ops/BatchNormInference.py
extensions/ops/binarization.py
extensions/ops/BlockLSTM.py
extensions/ops/bn.py
extensions/ops/box_nms.py
extensions/ops/bucketize.py
extensions/ops/Cast.py
@ -843,10 +847,6 @@ mo/front/mxnet/register_custom_ops.py
mo/front/onnx/__init__.py
mo/front/onnx/extractor.py
mo/front/onnx/extractors/__init__.py
mo/front/onnx/extractors/concat.py
mo/front/onnx/extractors/eltwise.py
mo/front/onnx/extractors/fused_bn.py
mo/front/onnx/extractors/reshape.py
mo/front/onnx/extractors/utils.py
mo/front/onnx/loader.py
mo/front/onnx/register_custom_ops.py

View File

@ -27,7 +27,7 @@ class BNToScaleShift(FrontReplacementOp):
"""
Replaces BN layer with ScaleShift.
"""
op = "BN"
op = "batchNormInference"
enabled = True
def replace_op(self, graph: Graph, node: Node):

View File

@ -47,7 +47,7 @@ class TestBNReplacer(unittest.TestCase):
FakeParam('data', shift)])
nodes = [
('input', {'kind': 'op', 'type': 'Identity', 'op': 'Identity'}),
('bn', {'type': 'BN', 'kind': 'op', 'op': 'BN', 'pb': bn_pb, 'model_pb': bn_bin}),
('bn', {'type': None, 'kind': 'op', 'op': 'batchNormInference', 'pb': bn_pb, 'model_pb': bn_bin}),
('output', {'kind': 'op', 'type': 'Identity', 'op': 'Identity'}),
]
edges = [

View File

@ -14,21 +14,18 @@
limitations under the License.
"""
from mo.front.common.partial_infer.eltwise import eltwise_infer
from mo.front.onnx.extractors.utils import onnx_attr
from mo.front.extractor import FrontExtractorOp
from mo.ops.concat import Concat
class ConcatFrontExtractor(FrontExtractorOp):
op = 'Concat'
enabled = True
def tf_eltwise_ext(pb, op=None, attrs=None):
"""
Generic eltwise extractor that supports n-ary operations.
It supports reasonable broadcast semantics from TF/NumPy
"""
res = {
'infer': lambda node: eltwise_infer(node, op)
}
if attrs is not None:
res.update(attrs)
return res
def make_tf_eltwise(op, attrs=None):
return lambda node: tf_eltwise_ext(node, op, attrs)
@classmethod
def extract(cls, node):
mapping_rule = {
'axis': onnx_attr(node, 'axis', 'i', default=0)
}
Concat.update_node_stat(node, mapping_rule)
return cls.enabled

View File

@ -15,19 +15,21 @@
"""
import logging as log
from mo.front.common.partial_infer.elemental import copy_shape_infer
from extensions.ops.BatchNormInference import BatchNormInference
from mo.front.extractor import FrontExtractorOp
from mo.front.onnx.extractors.utils import onnx_attr
def tf_fused_bn_extractor(node):
pb = node.pb
# This statement covers different opset versions
if onnx_attr(node, 'is_test', 'i', None) == 0:
log.error('FusedBatchNorm doesn\'t support is_test=False')
return None
return {
'data_format': 'NCHW',
'eps': onnx_attr(node, 'epsilon', 'f', 1e-5),
'infer': copy_shape_infer,
}
class BatchNormalizationExtractor(FrontExtractorOp):
op = 'BatchNormalization'
enabled = True
@classmethod
def extract(cls, node):
attr_dict = {
'data_format': 'NCHW',
'eps': onnx_attr(node, 'epsilon', 'f', 1e-5),
}
BatchNormInference.update_node_stat(node, attr_dict)
return cls.enabled

View File

@ -13,14 +13,15 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
from mo.front.common.partial_infer.concat import concat_infer
from mo.front.onnx.extractors.utils import onnx_attr
from mo.front.extractor import FrontExtractorOp
from extensions.ops.identity import Identity
def concat_ext(node):
return {
'type': "Concat",
'axis': onnx_attr(node, 'axis', 'i', default=0),
'infer': concat_infer
}
class IdentityFrontExtractor(FrontExtractorOp):
op = 'Identity'
enabled = True
@classmethod
def extract(cls, node):
Identity.update_node_stat(node)
return cls.enabled

View File

@ -16,19 +16,20 @@
import numpy as np
from mo.front.extractor import FrontExtractorOp
from mo.front.onnx.extractors.utils import onnx_attr
from mo.ops.reshape import Reshape
class ReshapeFrontExtractor(FrontExtractorOp):
op = 'Reshape'
enabled = True
def onnx_reshape_ext(node):
''' Extract ONNX Reshape op of different versions.
Support both latest Reshape and Reshape-1.
The first one has 2 arguments, Reshape-1 has one input and shape is coded in attribute.
'''
dim = onnx_attr(node, 'shape', 'ints', None)
if dim is not None:
dim = np.array(dim, dtype=np.int64)
Reshape.update_node_stat(node, {'dim': dim})
else:
Reshape.update_node_stat(node)
return node.graph.node[node.id]
@classmethod
def extract(cls, node):
dim = onnx_attr(node, 'shape', 'ints', None)
if dim is not None:
dim = np.array(dim, dtype=np.int64)
Reshape.update_node_stat(node, {'dim': dim})
else:
Reshape.update_node_stat(node)
return cls.enabled

View File

@ -61,7 +61,7 @@ class Fusing(MiddleReplacementPattern):
for_graph_and_each_sub_graph_recursively(graph, lambda graph: mark_unfused_nodes(graph, argv.finegrain_fusing))
# Converting FusedBatchNorm layer to Mul->Add->Mul->Add sequence
# IE doesn't support BN with 4 inputs, so we have to split it to two ScaleShift
# IE doesn't support batchNormInference with 4 inputs, so we have to split it to two ScaleShift
for_graph_and_each_sub_graph_recursively(graph, convert_batch_norm)
if fw == 'caffe':

View File

@ -18,18 +18,22 @@ from mo.graph.graph import Graph
from mo.ops.op import Op
class BNOp(Op):
class BatchNormInference(Op):
"""
Empty Op for BN layer. It will be replaced by BNToScaleShift FrontReplacer
BatchNormInference will be replaced by BNToScaleShift FrontReplacer for Caffe or convert_batch_norm
function for other frameworks
"""
op = 'BN'
enabled = True
op = 'batchNormInference'
enabled = False
def __init__(self, graph: Graph, attrs: dict):
super().__init__(graph, {
'type': None,
'op': __class__.op,
'op': self.op,
'in_ports_count': 5,
'out_ports_count': 1,
'infer': None
'infer': self.infer
}, attrs)
@staticmethod
def infer(node):
node.out_port(0).data.set_shape(node.in_port(0).data.get_shape())

View File

@ -14,11 +14,6 @@
limitations under the License.
"""
from mo.front.onnx.extractors.concat import concat_ext
from mo.front.onnx.extractors.eltwise import make_tf_eltwise
from mo.front.onnx.extractors.fused_bn import tf_fused_bn_extractor
from mo.front.onnx.extractors.reshape import onnx_reshape_ext
from mo.graph.graph import Node
@ -26,12 +21,7 @@ def node_pb_arg(pb_extractor: callable):
return lambda node: pb_extractor(node.pb)
onnx_op_extractors = {
'BatchNormalization': tf_fused_bn_extractor,
'Concat': concat_ext,
'Identity': node_pb_arg(make_tf_eltwise(lambda v: v, attrs={'identity': True})),
'Reshape': onnx_reshape_ext,
}
onnx_op_extractors = {}
def common_onnx_fields(node: Node):

View File

@ -40,7 +40,7 @@ def convert_batch_norm(graph: Graph):
nodes = graph.get_op_nodes()
for node in nodes:
if node.has_valid('op') and (node.op in ['FusedBatchNorm', 'FusedBatchNormV2', 'FusedBatchNormV3',
'BatchNorm', 'BatchNormalization']):
'BatchNorm', 'BatchNormalization', 'batchNormInference']):
if any([node.in_port(i).data.get_value() is None for i in range(1, len(node.in_ports()))]):
log.warning('Cannot translate FusedBatchNorm {} node with non-constant weights'.format(