diff --git a/model-optimizer/automation/package_BOM.txt b/model-optimizer/automation/package_BOM.txt
index 3fcebe98abd..34a3f8ffbd4 100644
--- a/model-optimizer/automation/package_BOM.txt
+++ b/model-optimizer/automation/package_BOM.txt
@@ -241,6 +241,7 @@ extensions/front/onnx/aten_ext.py
 extensions/front/onnx/AttributedSliceToSlice.py
 extensions/front/onnx/cast_ext.py
 extensions/front/onnx/clip_ext.py
+extensions/front/onnx/concat_ext.py
 extensions/front/onnx/const_ext.py
 extensions/front/onnx/constant_fill_ext.py
 extensions/front/onnx/constant_of_shape_ext.py
@@ -260,12 +261,14 @@ extensions/front/onnx/expand_ext.py
 extensions/front/onnx/faster_rcnn.json
 extensions/front/onnx/flatten_ext.py
 extensions/front/onnx/flattenONNX_to_reshape.py
+extensions/front/onnx/fused_bn_ext.py
 extensions/front/onnx/gather_ext.py
 extensions/front/onnx/gathernd_ext.py
 extensions/front/onnx/gemm_ext.py
 extensions/front/onnx/group_norm_ext.py
 extensions/front/onnx/gru_ext.py
 extensions/front/onnx/hard_sigmoid_ext.py
+extensions/front/onnx/identity_ext.py
 extensions/front/onnx/image_scaler_ext.py
 extensions/front/onnx/instance_normalization_ext.py
 extensions/front/onnx/logsoftmaxONNX_to_logsoftmax.py
@@ -301,6 +304,7 @@ extensions/front/onnx/quantize_linear_resolver.py
 extensions/front/onnx/range_ext.py
 extensions/front/onnx/reduce_ext.py
 extensions/front/onnx/remove_filtering_boxes_by_size.py
+extensions/front/onnx/reshape_ext.py
 extensions/front/onnx/resize_ext.py
 extensions/front/onnx/resize_to_interpolate.py
 extensions/front/onnx/reverse_sequence_ext.py
@@ -602,9 +606,9 @@ extensions/ops/argmax.py
 extensions/ops/assert_op.py
 extensions/ops/aten.py
 extensions/ops/axpy.py
+extensions/ops/BatchNormInference.py
 extensions/ops/binarization.py
 extensions/ops/BlockLSTM.py
-extensions/ops/bn.py
 extensions/ops/box_nms.py
 extensions/ops/bucketize.py
 extensions/ops/Cast.py
@@ -843,10 +847,6 @@ mo/front/mxnet/register_custom_ops.py
 mo/front/onnx/__init__.py
 mo/front/onnx/extractor.py
 mo/front/onnx/extractors/__init__.py
-mo/front/onnx/extractors/concat.py
-mo/front/onnx/extractors/eltwise.py
-mo/front/onnx/extractors/fused_bn.py
-mo/front/onnx/extractors/reshape.py
 mo/front/onnx/extractors/utils.py
 mo/front/onnx/loader.py
 mo/front/onnx/register_custom_ops.py
@@ -1011,4 +1011,4 @@ requirements_kaldi.txt
 requirements_mxnet.txt
 requirements_onnx.txt
 requirements_tf.txt
-requirements_tf2.txt
+requirements_tf2.txt
\ No newline at end of file
diff --git a/model-optimizer/extensions/front/caffe/bn.py b/model-optimizer/extensions/front/caffe/bn.py
index 4f7b33a39b2..3ad77c441c5 100644
--- a/model-optimizer/extensions/front/caffe/bn.py
+++ b/model-optimizer/extensions/front/caffe/bn.py
@@ -27,7 +27,7 @@ class BNToScaleShift(FrontReplacementOp):
     """
     Replaces BN layer with ScaleShift.
     """
-    op = "BN"
+    op = "batchNormInference"
     enabled = True
 
     def replace_op(self, graph: Graph, node: Node):
diff --git a/model-optimizer/extensions/front/caffe/bn_test.py b/model-optimizer/extensions/front/caffe/bn_test.py
index fd899d8ca63..37c0c3fa5b5 100644
--- a/model-optimizer/extensions/front/caffe/bn_test.py
+++ b/model-optimizer/extensions/front/caffe/bn_test.py
@@ -47,7 +47,7 @@ class TestBNReplacer(unittest.TestCase):
                                  FakeParam('data', shift)])
         nodes = [
            ('input', {'kind': 'op', 'type': 'Identity', 'op': 'Identity'}),
-           ('bn', {'type': 'BN', 'kind': 'op', 'op': 'BN', 'pb': bn_pb, 'model_pb': bn_bin}),
+           ('bn', {'type': None, 'kind': 'op', 'op': 'batchNormInference', 'pb': bn_pb, 'model_pb': bn_bin}),
            ('output', {'kind': 'op', 'type': 'Identity', 'op': 'Identity'}),
         ]
         edges = [
diff --git a/model-optimizer/mo/front/onnx/extractors/eltwise.py b/model-optimizer/extensions/front/onnx/concat_ext.py
similarity index 54%
rename from model-optimizer/mo/front/onnx/extractors/eltwise.py
rename to model-optimizer/extensions/front/onnx/concat_ext.py
index 55df7acec61..287de77e082 100644
--- a/model-optimizer/mo/front/onnx/extractors/eltwise.py
+++ b/model-optimizer/extensions/front/onnx/concat_ext.py
@@ -14,21 +14,18 @@
 limitations under the License.
 """
 
-from mo.front.common.partial_infer.eltwise import eltwise_infer
+from mo.front.onnx.extractors.utils import onnx_attr
+from mo.front.extractor import FrontExtractorOp
+from mo.ops.concat import Concat
+
+class ConcatFrontExtractor(FrontExtractorOp):
+    op = 'Concat'
+    enabled = True
 
-
-def tf_eltwise_ext(pb, op=None, attrs=None):
-    """
-    Generic eltwise extractor that supports n-ary operations.
-    It supports reasonable broadcast semantics from TF/NumPy
-    """
-    res = {
-        'infer': lambda node: eltwise_infer(node, op)
-    }
-    if attrs is not None:
-        res.update(attrs)
-    return res
-
-
-def make_tf_eltwise(op, attrs=None):
-    return lambda node: tf_eltwise_ext(node, op, attrs)
+    @classmethod
+    def extract(cls, node):
+        mapping_rule = {
+            'axis': onnx_attr(node, 'axis', 'i', default=0)
+        }
+        Concat.update_node_stat(node, mapping_rule)
+        return cls.enabled
diff --git a/model-optimizer/mo/front/onnx/extractors/fused_bn.py b/model-optimizer/extensions/front/onnx/fused_bn_ext.py
similarity index 58%
rename from model-optimizer/mo/front/onnx/extractors/fused_bn.py
rename to model-optimizer/extensions/front/onnx/fused_bn_ext.py
index d4e0a3c788d..7493210d53f 100644
--- a/model-optimizer/mo/front/onnx/extractors/fused_bn.py
+++ b/model-optimizer/extensions/front/onnx/fused_bn_ext.py
@@ -15,19 +15,21 @@
 """
 import logging as log
 
-from mo.front.common.partial_infer.elemental import copy_shape_infer
+from extensions.ops.BatchNormInference import BatchNormInference
+from mo.front.extractor import FrontExtractorOp
 from mo.front.onnx.extractors.utils import onnx_attr
 
 
-def tf_fused_bn_extractor(node):
-    pb = node.pb
-    # This statement covers different opset versions
-    if onnx_attr(node, 'is_test', 'i', None) == 0:
-        log.error('FusedBatchNorm doesn\'t support is_test=False')
-        return None
-    return {
-        'data_format': 'NCHW',
-        'eps': onnx_attr(node, 'epsilon', 'f', 1e-5),
-        'infer': copy_shape_infer,
-    }
+class BatchNormalizationExtractor(FrontExtractorOp):
+    op = 'BatchNormalization'
+    enabled = True
+
+    @classmethod
+    def extract(cls, node):
+        attr_dict = {
+            'data_format': 'NCHW',
+            'eps': onnx_attr(node, 'epsilon', 'f', 1e-5),
+        }
+        BatchNormInference.update_node_stat(node, attr_dict)
+        return cls.enabled
diff --git a/model-optimizer/mo/front/onnx/extractors/concat.py b/model-optimizer/extensions/front/onnx/identity_ext.py
similarity index 66%
rename from model-optimizer/mo/front/onnx/extractors/concat.py
rename to model-optimizer/extensions/front/onnx/identity_ext.py
index d5c5e894dba..5ec3eb3f866 100644
--- a/model-optimizer/mo/front/onnx/extractors/concat.py
+++ b/model-optimizer/extensions/front/onnx/identity_ext.py
@@ -13,14 +13,15 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
-
-from mo.front.common.partial_infer.concat import concat_infer
-from mo.front.onnx.extractors.utils import onnx_attr
+from mo.front.extractor import FrontExtractorOp
+from extensions.ops.identity import Identity
 
 
-def concat_ext(node):
-    return {
-        'type': "Concat",
-        'axis': onnx_attr(node, 'axis', 'i', default=0),
-        'infer': concat_infer
-    }
+class IdentityFrontExtractor(FrontExtractorOp):
+    op = 'Identity'
+    enabled = True
+
+    @classmethod
+    def extract(cls, node):
+        Identity.update_node_stat(node)
+        return cls.enabled
diff --git a/model-optimizer/mo/front/onnx/extractors/reshape.py b/model-optimizer/extensions/front/onnx/reshape_ext.py
similarity index 58%
rename from model-optimizer/mo/front/onnx/extractors/reshape.py
rename to model-optimizer/extensions/front/onnx/reshape_ext.py
index 089566cc858..9c79f9d9783 100644
--- a/model-optimizer/mo/front/onnx/extractors/reshape.py
+++ b/model-optimizer/extensions/front/onnx/reshape_ext.py
@@ -16,19 +16,20 @@
 
 import numpy as np
 
+from mo.front.extractor import FrontExtractorOp
 from mo.front.onnx.extractors.utils import onnx_attr
 from mo.ops.reshape import Reshape
 
+class ReshapeFrontExtractor(FrontExtractorOp):
+    op = 'Reshape'
+    enabled = True
 
-def onnx_reshape_ext(node):
-    ''' Extract ONNX Reshape op of different versions.
-        Support both latest Reshape and Reshape-1.
-        The first one has 2 arguments, Reshape-1 has one input and shape is coded in attribute.
-    '''
-    dim = onnx_attr(node, 'shape', 'ints', None)
-    if dim is not None:
-        dim = np.array(dim, dtype=np.int64)
-        Reshape.update_node_stat(node, {'dim': dim})
-    else:
-        Reshape.update_node_stat(node)
-    return node.graph.node[node.id]
+    @classmethod
+    def extract(cls, node):
+        dim = onnx_attr(node, 'shape', 'ints', None)
+        if dim is not None:
+            dim = np.array(dim, dtype=np.int64)
+            Reshape.update_node_stat(node, {'dim': dim})
+        else:
+            Reshape.update_node_stat(node)
+        return cls.enabled
diff --git a/model-optimizer/extensions/middle/fusings.py b/model-optimizer/extensions/middle/fusings.py
index 40ce7b2bda7..55abb15499a 100644
--- a/model-optimizer/extensions/middle/fusings.py
+++ b/model-optimizer/extensions/middle/fusings.py
@@ -61,7 +61,7 @@ class Fusing(MiddleReplacementPattern):
         for_graph_and_each_sub_graph_recursively(graph, lambda graph: mark_unfused_nodes(graph, argv.finegrain_fusing))
 
         # Converting FusedBatchNorm layer to Mul->Add->Mul->Add sequence
-        # IE doesn't support BN with 4 inputs, so we have to split it to two ScaleShift
+        # IE doesn't support batchNormInference with 4 inputs, so we have to split it to two ScaleShift
         for_graph_and_each_sub_graph_recursively(graph, convert_batch_norm)
 
         if fw == 'caffe':
diff --git a/model-optimizer/extensions/ops/bn.py b/model-optimizer/extensions/ops/BatchNormInference.py
similarity index 68%
rename from model-optimizer/extensions/ops/bn.py
rename to model-optimizer/extensions/ops/BatchNormInference.py
index ae3edc02bc5..307a2ed9af3 100644
--- a/model-optimizer/extensions/ops/bn.py
+++ b/model-optimizer/extensions/ops/BatchNormInference.py
@@ -18,18 +18,22 @@
 from mo.graph.graph import Graph
 from mo.ops.op import Op
 
 
-class BNOp(Op):
+class BatchNormInference(Op):
     """
-    Empty Op for BN layer. It will be replaced by BNToScaleShift FrontReplacer
+    BatchNormInference will be replaced by BNToScaleShift FrontReplacer for Caffe or convert_batch_norm
+    function for other frameworks
     """
-    op = 'BN'
-    enabled = True
+    op = 'batchNormInference'
+    enabled = False
 
     def __init__(self, graph: Graph, attrs: dict):
         super().__init__(graph, {
             'type': None,
-            'op': __class__.op,
+            'op': self.op,
             'in_ports_count': 5,
             'out_ports_count': 1,
-            'infer': None
+            'infer': self.infer
         }, attrs)
+
+    @staticmethod
+    def infer(node):
+        node.out_port(0).data.set_shape(node.in_port(0).data.get_shape())
diff --git a/model-optimizer/mo/front/onnx/extractor.py b/model-optimizer/mo/front/onnx/extractor.py
index 05e524de046..48abd3e0e12 100644
--- a/model-optimizer/mo/front/onnx/extractor.py
+++ b/model-optimizer/mo/front/onnx/extractor.py
@@ -14,11 +14,6 @@
 limitations under the License.
 """
-
-from mo.front.onnx.extractors.concat import concat_ext
-from mo.front.onnx.extractors.eltwise import make_tf_eltwise
-from mo.front.onnx.extractors.fused_bn import tf_fused_bn_extractor
-from mo.front.onnx.extractors.reshape import onnx_reshape_ext
 from mo.graph.graph import Node
 
 
@@ -26,12 +21,7 @@ def node_pb_arg(pb_extractor: callable):
     return lambda node: pb_extractor(node.pb)
 
 
-onnx_op_extractors = {
-    'BatchNormalization': tf_fused_bn_extractor,
-    'Concat': concat_ext,
-    'Identity': node_pb_arg(make_tf_eltwise(lambda v: v, attrs={'identity': True})),
-    'Reshape': onnx_reshape_ext,
-}
+onnx_op_extractors = {}
 
 
 def common_onnx_fields(node: Node):
diff --git a/model-optimizer/mo/middle/passes/fusing/decomposition.py b/model-optimizer/mo/middle/passes/fusing/decomposition.py
index 365b2d4954f..36ac103b9a0 100644
--- a/model-optimizer/mo/middle/passes/fusing/decomposition.py
+++ b/model-optimizer/mo/middle/passes/fusing/decomposition.py
@@ -40,7 +40,7 @@ def convert_batch_norm(graph: Graph):
     nodes = graph.get_op_nodes()
     for node in nodes:
         if node.has_valid('op') and (node.op in ['FusedBatchNorm', 'FusedBatchNormV2', 'FusedBatchNormV3',
-                                                 'BatchNorm', 'BatchNormalization']):
+                                                 'BatchNorm', 'BatchNormalization', 'batchNormInference']):
             if any([node.in_port(i).data.get_value() is None for i in range(1, len(node.in_ports()))]):
                 log.warning('Cannot translate FusedBatchNorm {} node with non-constant weights'.format(