Re-implement onnx old-style extractors with extractor extensions (#3459)
* add class ConcatFrontExtractor for onnx * add class ConcatFrontExtractor for onnx * Delete concat from mo/front/onnx/extractor.py * Add identity, reshape extractors classes for onnx * import FrontExtractorOp * Added BatchNormalizationExtractor * Added extra line * Fix import modules * fix caffe bn.py and bn_test.py * Fix BatchNormInference * Modify convert_batch_norm * Modify convert_batch_norm * Modify bn_test.py * Fix old comments BN->batchNormInference
This commit is contained in:
parent
6254b150c3
commit
6f512142b6
@ -241,6 +241,7 @@ extensions/front/onnx/aten_ext.py
|
|||||||
extensions/front/onnx/AttributedSliceToSlice.py
|
extensions/front/onnx/AttributedSliceToSlice.py
|
||||||
extensions/front/onnx/cast_ext.py
|
extensions/front/onnx/cast_ext.py
|
||||||
extensions/front/onnx/clip_ext.py
|
extensions/front/onnx/clip_ext.py
|
||||||
|
extensions/front/onnx/concat_ext.py
|
||||||
extensions/front/onnx/const_ext.py
|
extensions/front/onnx/const_ext.py
|
||||||
extensions/front/onnx/constant_fill_ext.py
|
extensions/front/onnx/constant_fill_ext.py
|
||||||
extensions/front/onnx/constant_of_shape_ext.py
|
extensions/front/onnx/constant_of_shape_ext.py
|
||||||
@ -260,12 +261,14 @@ extensions/front/onnx/expand_ext.py
|
|||||||
extensions/front/onnx/faster_rcnn.json
|
extensions/front/onnx/faster_rcnn.json
|
||||||
extensions/front/onnx/flatten_ext.py
|
extensions/front/onnx/flatten_ext.py
|
||||||
extensions/front/onnx/flattenONNX_to_reshape.py
|
extensions/front/onnx/flattenONNX_to_reshape.py
|
||||||
|
extensions/front/onnx/fused_bn_ext.py
|
||||||
extensions/front/onnx/gather_ext.py
|
extensions/front/onnx/gather_ext.py
|
||||||
extensions/front/onnx/gathernd_ext.py
|
extensions/front/onnx/gathernd_ext.py
|
||||||
extensions/front/onnx/gemm_ext.py
|
extensions/front/onnx/gemm_ext.py
|
||||||
extensions/front/onnx/group_norm_ext.py
|
extensions/front/onnx/group_norm_ext.py
|
||||||
extensions/front/onnx/gru_ext.py
|
extensions/front/onnx/gru_ext.py
|
||||||
extensions/front/onnx/hard_sigmoid_ext.py
|
extensions/front/onnx/hard_sigmoid_ext.py
|
||||||
|
extensions/front/onnx/identity_ext.py
|
||||||
extensions/front/onnx/image_scaler_ext.py
|
extensions/front/onnx/image_scaler_ext.py
|
||||||
extensions/front/onnx/instance_normalization_ext.py
|
extensions/front/onnx/instance_normalization_ext.py
|
||||||
extensions/front/onnx/logsoftmaxONNX_to_logsoftmax.py
|
extensions/front/onnx/logsoftmaxONNX_to_logsoftmax.py
|
||||||
@ -301,6 +304,7 @@ extensions/front/onnx/quantize_linear_resolver.py
|
|||||||
extensions/front/onnx/range_ext.py
|
extensions/front/onnx/range_ext.py
|
||||||
extensions/front/onnx/reduce_ext.py
|
extensions/front/onnx/reduce_ext.py
|
||||||
extensions/front/onnx/remove_filtering_boxes_by_size.py
|
extensions/front/onnx/remove_filtering_boxes_by_size.py
|
||||||
|
extensions/front/onnx/reshape_ext.py
|
||||||
extensions/front/onnx/resize_ext.py
|
extensions/front/onnx/resize_ext.py
|
||||||
extensions/front/onnx/resize_to_interpolate.py
|
extensions/front/onnx/resize_to_interpolate.py
|
||||||
extensions/front/onnx/reverse_sequence_ext.py
|
extensions/front/onnx/reverse_sequence_ext.py
|
||||||
@ -602,9 +606,9 @@ extensions/ops/argmax.py
|
|||||||
extensions/ops/assert_op.py
|
extensions/ops/assert_op.py
|
||||||
extensions/ops/aten.py
|
extensions/ops/aten.py
|
||||||
extensions/ops/axpy.py
|
extensions/ops/axpy.py
|
||||||
|
extensions/ops/BatchNormInference.py
|
||||||
extensions/ops/binarization.py
|
extensions/ops/binarization.py
|
||||||
extensions/ops/BlockLSTM.py
|
extensions/ops/BlockLSTM.py
|
||||||
extensions/ops/bn.py
|
|
||||||
extensions/ops/box_nms.py
|
extensions/ops/box_nms.py
|
||||||
extensions/ops/bucketize.py
|
extensions/ops/bucketize.py
|
||||||
extensions/ops/Cast.py
|
extensions/ops/Cast.py
|
||||||
@ -843,10 +847,6 @@ mo/front/mxnet/register_custom_ops.py
|
|||||||
mo/front/onnx/__init__.py
|
mo/front/onnx/__init__.py
|
||||||
mo/front/onnx/extractor.py
|
mo/front/onnx/extractor.py
|
||||||
mo/front/onnx/extractors/__init__.py
|
mo/front/onnx/extractors/__init__.py
|
||||||
mo/front/onnx/extractors/concat.py
|
|
||||||
mo/front/onnx/extractors/eltwise.py
|
|
||||||
mo/front/onnx/extractors/fused_bn.py
|
|
||||||
mo/front/onnx/extractors/reshape.py
|
|
||||||
mo/front/onnx/extractors/utils.py
|
mo/front/onnx/extractors/utils.py
|
||||||
mo/front/onnx/loader.py
|
mo/front/onnx/loader.py
|
||||||
mo/front/onnx/register_custom_ops.py
|
mo/front/onnx/register_custom_ops.py
|
||||||
|
@ -27,7 +27,7 @@ class BNToScaleShift(FrontReplacementOp):
|
|||||||
"""
|
"""
|
||||||
Replaces BN layer with ScaleShift.
|
Replaces BN layer with ScaleShift.
|
||||||
"""
|
"""
|
||||||
op = "BN"
|
op = "batchNormInference"
|
||||||
enabled = True
|
enabled = True
|
||||||
|
|
||||||
def replace_op(self, graph: Graph, node: Node):
|
def replace_op(self, graph: Graph, node: Node):
|
||||||
|
@ -47,7 +47,7 @@ class TestBNReplacer(unittest.TestCase):
|
|||||||
FakeParam('data', shift)])
|
FakeParam('data', shift)])
|
||||||
nodes = [
|
nodes = [
|
||||||
('input', {'kind': 'op', 'type': 'Identity', 'op': 'Identity'}),
|
('input', {'kind': 'op', 'type': 'Identity', 'op': 'Identity'}),
|
||||||
('bn', {'type': 'BN', 'kind': 'op', 'op': 'BN', 'pb': bn_pb, 'model_pb': bn_bin}),
|
('bn', {'type': None, 'kind': 'op', 'op': 'batchNormInference', 'pb': bn_pb, 'model_pb': bn_bin}),
|
||||||
('output', {'kind': 'op', 'type': 'Identity', 'op': 'Identity'}),
|
('output', {'kind': 'op', 'type': 'Identity', 'op': 'Identity'}),
|
||||||
]
|
]
|
||||||
edges = [
|
edges = [
|
||||||
|
@ -14,21 +14,18 @@
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from mo.front.common.partial_infer.eltwise import eltwise_infer
|
from mo.front.onnx.extractors.utils import onnx_attr
|
||||||
|
from mo.front.extractor import FrontExtractorOp
|
||||||
|
from mo.ops.concat import Concat
|
||||||
|
|
||||||
|
class ConcatFrontExtractor(FrontExtractorOp):
|
||||||
|
op = 'Concat'
|
||||||
|
enabled = True
|
||||||
|
|
||||||
def tf_eltwise_ext(pb, op=None, attrs=None):
|
@classmethod
|
||||||
"""
|
def extract(cls, node):
|
||||||
Generic eltwise extractor that supports n-ary operations.
|
mapping_rule = {
|
||||||
It supports reasonable broadcast semantics from TF/NumPy
|
'axis': onnx_attr(node, 'axis', 'i', default=0)
|
||||||
"""
|
|
||||||
res = {
|
|
||||||
'infer': lambda node: eltwise_infer(node, op)
|
|
||||||
}
|
}
|
||||||
if attrs is not None:
|
Concat.update_node_stat(node, mapping_rule)
|
||||||
res.update(attrs)
|
return cls.enabled
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
def make_tf_eltwise(op, attrs=None):
|
|
||||||
return lambda node: tf_eltwise_ext(node, op, attrs)
|
|
@ -15,19 +15,21 @@
|
|||||||
"""
|
"""
|
||||||
import logging as log
|
import logging as log
|
||||||
|
|
||||||
from mo.front.common.partial_infer.elemental import copy_shape_infer
|
from extensions.ops.BatchNormInference import BatchNormInference
|
||||||
|
from mo.front.extractor import FrontExtractorOp
|
||||||
from mo.front.onnx.extractors.utils import onnx_attr
|
from mo.front.onnx.extractors.utils import onnx_attr
|
||||||
|
|
||||||
|
|
||||||
def tf_fused_bn_extractor(node):
|
|
||||||
pb = node.pb
|
|
||||||
# This statement covers different opset versions
|
|
||||||
if onnx_attr(node, 'is_test', 'i', None) == 0:
|
|
||||||
log.error('FusedBatchNorm doesn\'t support is_test=False')
|
|
||||||
return None
|
|
||||||
|
|
||||||
return {
|
class BatchNormalizationExtractor(FrontExtractorOp):
|
||||||
|
op = 'BatchNormalization'
|
||||||
|
enabled = True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def extract(cls, node):
|
||||||
|
attr_dict = {
|
||||||
'data_format': 'NCHW',
|
'data_format': 'NCHW',
|
||||||
'eps': onnx_attr(node, 'epsilon', 'f', 1e-5),
|
'eps': onnx_attr(node, 'epsilon', 'f', 1e-5),
|
||||||
'infer': copy_shape_infer,
|
|
||||||
}
|
}
|
||||||
|
BatchNormInference.update_node_stat(node, attr_dict)
|
||||||
|
return cls.enabled
|
@ -13,14 +13,15 @@
|
|||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
from mo.front.extractor import FrontExtractorOp
|
||||||
from mo.front.common.partial_infer.concat import concat_infer
|
from extensions.ops.identity import Identity
|
||||||
from mo.front.onnx.extractors.utils import onnx_attr
|
|
||||||
|
|
||||||
|
|
||||||
def concat_ext(node):
|
class IdentityFrontExtractor(FrontExtractorOp):
|
||||||
return {
|
op = 'Identity'
|
||||||
'type': "Concat",
|
enabled = True
|
||||||
'axis': onnx_attr(node, 'axis', 'i', default=0),
|
|
||||||
'infer': concat_infer
|
@classmethod
|
||||||
}
|
def extract(cls, node):
|
||||||
|
Identity.update_node_stat(node)
|
||||||
|
return cls.enabled
|
@ -16,19 +16,20 @@
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from mo.front.extractor import FrontExtractorOp
|
||||||
from mo.front.onnx.extractors.utils import onnx_attr
|
from mo.front.onnx.extractors.utils import onnx_attr
|
||||||
from mo.ops.reshape import Reshape
|
from mo.ops.reshape import Reshape
|
||||||
|
|
||||||
|
class ReshapeFrontExtractor(FrontExtractorOp):
|
||||||
|
op = 'Reshape'
|
||||||
|
enabled = True
|
||||||
|
|
||||||
def onnx_reshape_ext(node):
|
@classmethod
|
||||||
''' Extract ONNX Reshape op of different versions.
|
def extract(cls, node):
|
||||||
Support both latest Reshape and Reshape-1.
|
|
||||||
The first one has 2 arguments, Reshape-1 has one input and shape is coded in attribute.
|
|
||||||
'''
|
|
||||||
dim = onnx_attr(node, 'shape', 'ints', None)
|
dim = onnx_attr(node, 'shape', 'ints', None)
|
||||||
if dim is not None:
|
if dim is not None:
|
||||||
dim = np.array(dim, dtype=np.int64)
|
dim = np.array(dim, dtype=np.int64)
|
||||||
Reshape.update_node_stat(node, {'dim': dim})
|
Reshape.update_node_stat(node, {'dim': dim})
|
||||||
else:
|
else:
|
||||||
Reshape.update_node_stat(node)
|
Reshape.update_node_stat(node)
|
||||||
return node.graph.node[node.id]
|
return cls.enabled
|
@ -61,7 +61,7 @@ class Fusing(MiddleReplacementPattern):
|
|||||||
for_graph_and_each_sub_graph_recursively(graph, lambda graph: mark_unfused_nodes(graph, argv.finegrain_fusing))
|
for_graph_and_each_sub_graph_recursively(graph, lambda graph: mark_unfused_nodes(graph, argv.finegrain_fusing))
|
||||||
|
|
||||||
# Converting FusedBatchNorm layer to Mul->Add->Mul->Add sequence
|
# Converting FusedBatchNorm layer to Mul->Add->Mul->Add sequence
|
||||||
# IE doesn't support BN with 4 inputs, so we have to split it to two ScaleShift
|
# IE doesn't support batchNormInference with 4 inputs, so we have to split it to two ScaleShift
|
||||||
for_graph_and_each_sub_graph_recursively(graph, convert_batch_norm)
|
for_graph_and_each_sub_graph_recursively(graph, convert_batch_norm)
|
||||||
|
|
||||||
if fw == 'caffe':
|
if fw == 'caffe':
|
||||||
|
@ -18,18 +18,22 @@ from mo.graph.graph import Graph
|
|||||||
from mo.ops.op import Op
|
from mo.ops.op import Op
|
||||||
|
|
||||||
|
|
||||||
class BNOp(Op):
|
class BatchNormInference(Op):
|
||||||
"""
|
"""
|
||||||
Empty Op for BN layer. It will be replaced by BNToScaleShift FrontReplacer
|
BatchNormInference will be replaced by BNToScaleShift FrontReplacer for Caffe or convert_batch_norm
|
||||||
|
function for other frameworks
|
||||||
"""
|
"""
|
||||||
op = 'BN'
|
op = 'batchNormInference'
|
||||||
enabled = True
|
enabled = False
|
||||||
|
|
||||||
def __init__(self, graph: Graph, attrs: dict):
|
def __init__(self, graph: Graph, attrs: dict):
|
||||||
super().__init__(graph, {
|
super().__init__(graph, {
|
||||||
'type': None,
|
'type': None,
|
||||||
'op': __class__.op,
|
'op': self.op,
|
||||||
'in_ports_count': 5,
|
'in_ports_count': 5,
|
||||||
'out_ports_count': 1,
|
'out_ports_count': 1,
|
||||||
'infer': None
|
'infer': self.infer
|
||||||
}, attrs)
|
}, attrs)
|
||||||
|
@staticmethod
|
||||||
|
def infer(node):
|
||||||
|
node.out_port(0).data.set_shape(node.in_port(0).data.get_shape())
|
@ -14,11 +14,6 @@
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
from mo.front.onnx.extractors.concat import concat_ext
|
|
||||||
from mo.front.onnx.extractors.eltwise import make_tf_eltwise
|
|
||||||
from mo.front.onnx.extractors.fused_bn import tf_fused_bn_extractor
|
|
||||||
from mo.front.onnx.extractors.reshape import onnx_reshape_ext
|
|
||||||
from mo.graph.graph import Node
|
from mo.graph.graph import Node
|
||||||
|
|
||||||
|
|
||||||
@ -26,12 +21,7 @@ def node_pb_arg(pb_extractor: callable):
|
|||||||
return lambda node: pb_extractor(node.pb)
|
return lambda node: pb_extractor(node.pb)
|
||||||
|
|
||||||
|
|
||||||
onnx_op_extractors = {
|
onnx_op_extractors = {}
|
||||||
'BatchNormalization': tf_fused_bn_extractor,
|
|
||||||
'Concat': concat_ext,
|
|
||||||
'Identity': node_pb_arg(make_tf_eltwise(lambda v: v, attrs={'identity': True})),
|
|
||||||
'Reshape': onnx_reshape_ext,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def common_onnx_fields(node: Node):
|
def common_onnx_fields(node: Node):
|
||||||
|
@ -40,7 +40,7 @@ def convert_batch_norm(graph: Graph):
|
|||||||
nodes = graph.get_op_nodes()
|
nodes = graph.get_op_nodes()
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
if node.has_valid('op') and (node.op in ['FusedBatchNorm', 'FusedBatchNormV2', 'FusedBatchNormV3',
|
if node.has_valid('op') and (node.op in ['FusedBatchNorm', 'FusedBatchNormV2', 'FusedBatchNormV3',
|
||||||
'BatchNorm', 'BatchNormalization']):
|
'BatchNorm', 'BatchNormalization', 'batchNormInference']):
|
||||||
|
|
||||||
if any([node.in_port(i).data.get_value() is None for i in range(1, len(node.in_ports()))]):
|
if any([node.in_port(i).data.get_value() is None for i in range(1, len(node.in_ports()))]):
|
||||||
log.warning('Cannot translate FusedBatchNorm {} node with non-constant weights'.format(
|
log.warning('Cannot translate FusedBatchNorm {} node with non-constant weights'.format(
|
||||||
|
Loading…
Reference in New Issue
Block a user