Re-implement caffe old-style extractors with extractor extensions (#3675)
* move crop extractor
* Add concat_ext.py
* Add roipooling_ext.py
* Add roipooling_ext
* Add scale extractor
* Add scale extractor
* Add bn_ext.py and dropout_ext.py
* Add bn_ext.py and dropout_ext.py
* Add bn_ext.py and dropout_ext.py
* Fix bn.ext.py
* Sort fix
* Fix bn_test.py
* rename to batchnorm_ext
* Add bn_ext
* Fix batchnorm_ext.py
* small fix
* Small fix
Parent: a6a5635a59
Commit: 1a787cb3ba
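In short, the old function-style extractors (standalone *_ext functions wired into a hard-coded caffe_type_extractors dictionary in mo/front/caffe/extractor.py) are replaced by FrontExtractorOp extension classes that declare the Caffe layer type they handle. As a minimal sketch of the new pattern, this is roughly what the Concat extractor added by this commit looks like (it mirrors extensions/front/caffe/concat_ext.py shown below, license header omitted):

from mo.front.extractor import FrontExtractorOp
from mo.ops.concat import Concat


class ConcatFrontExtractor(FrontExtractorOp):
    op = 'concat'    # Caffe layer type this extension handles
    enabled = True   # picked up automatically when front extensions are loaded

    @classmethod
    def extract(cls, node):
        # read the layer attributes from the protobuf message and attach them to the node
        Concat.update_node_stat(node, {'axis': node.pb.concat_param.axis})
        return cls.enabled

Because each extension registers itself through its op name, the hard-coded dictionary in extractor.py can be emptied, as the extractor.py hunk further below shows.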
@@ -71,15 +71,20 @@ extensions/front/caffe/accum_ext.py
extensions/front/caffe/argmax_ext.py
extensions/front/caffe/ArgMaxFlatten.py
extensions/front/caffe/axpy.py
+extensions/front/caffe/batchnorm_ext.py
extensions/front/caffe/binarization.py
extensions/front/caffe/binary_conv_ext.py
extensions/front/caffe/bn.py
+extensions/front/caffe/bn_ext.py
+extensions/front/caffe/concat_ext.py
extensions/front/caffe/conv_ext.py
extensions/front/caffe/correlation_ext.py
+extensions/front/caffe/crop_ext.py
extensions/front/caffe/ctcgreedydecoder_ext.py
extensions/front/caffe/CustomLayersMapping.xml.example
extensions/front/caffe/data_augmentation_ext.py
extensions/front/caffe/detection_output.py
+extensions/front/caffe/dropout_ext.py
extensions/front/caffe/elementwise_ext.py
extensions/front/caffe/eltwise_add_normalize.py
extensions/front/caffe/elu.py
@@ -106,6 +111,8 @@ extensions/front/caffe/relu_ext.py
extensions/front/caffe/reorgyolo_ext.py
extensions/front/caffe/resample_ext.py
extensions/front/caffe/reshape.py
+extensions/front/caffe/roipooling_ext.py
+extensions/front/caffe/scale_ext.py
extensions/front/caffe/shufflechannel_ext.py
extensions/front/caffe/sigmoid.py
extensions/front/caffe/simplernms_ext.py
@@ -618,6 +625,7 @@ extensions/ops/axpy.py
extensions/ops/BatchNormInference.py
extensions/ops/binarization.py
extensions/ops/BlockLSTM.py
+extensions/ops/BN.py
extensions/ops/box_nms.py
extensions/ops/bucketize.py
extensions/ops/Cast.py
@@ -758,12 +766,7 @@ mo/front/caffe/collect_attributes.py
mo/front/caffe/custom_layers_mapping.py
mo/front/caffe/extractor.py
mo/front/caffe/extractors/__init__.py
-mo/front/caffe/extractors/batchnorm.py
-mo/front/caffe/extractors/concat.py
-mo/front/caffe/extractors/crop.py
mo/front/caffe/extractors/native_caffe.py
-mo/front/caffe/extractors/roipooling.py
-mo/front/caffe/extractors/scale.py
mo/front/caffe/extractors/tile.py
mo/front/caffe/extractors/utils.py
mo/front/caffe/loader.py
model-optimizer/extensions/front/caffe/batchnorm_ext.py (new file, 53 lines)
@@ -0,0 +1,53 @@
"""
Copyright (C) 2018-2021 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np

from extensions.ops.BatchNormInference import BatchNormInference
from mo.front.caffe.extractors.utils import embed_input
from mo.front.extractor import FrontExtractorOp


class BatchNormalizationExtractor(FrontExtractorOp):
    op = 'batchnorm'
    enabled = True

    @classmethod
    def extract(cls, node):
        eps = node.pb.batch_norm_param.eps
        attrs = {
            'eps': eps
        }
        pb_model = None if not node.soft_get('model_pb', None) else node.model_pb
        if pb_model:
            blobs = pb_model.blobs
            assert len(blobs) >= 2, 'BatchNorm accepts not less then two input blobs'
            mean = np.array(blobs[0].data)
            variance = np.array(blobs[1].data)

            if len(blobs) == 3:
                scale = blobs[2].data[0]
                if scale != 0:
                    scale = 1.0 / scale
                mean *= scale
                variance *= scale

            embed_input(attrs, 1, 'gamma', np.ones(mean.shape), 'gamma')
            embed_input(attrs, 2, 'beta', np.zeros(variance.shape), 'beta')
            embed_input(attrs, 3, 'mean', mean, 'biases')
            embed_input(attrs, 4, 'variance', variance, 'weights')

        BatchNormInference.update_node_stat(node, attrs)
        return cls.enabled
@@ -1,5 +1,5 @@
"""
-Copyright (C) 2018-2020 Intel Corporation
+Copyright (C) 2018-2021 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -27,7 +27,7 @@ class BNToScaleShift(FrontReplacementOp):
    """
    Replaces BN layer with ScaleShift.
    """
-    op = "batchNormInference"
+    op = "BN"
    enabled = True

    def replace_op(self, graph: Graph, node: Node):
@@ -35,6 +35,7 @@ class BNToScaleShift(FrontReplacementOp):

        param = graph.node[node.id]['pb'].bn_param
        pb_model = graph.node[node.id]['model_pb']

        blobs = pb_model.blobs

        if len(blobs) != 4:
@@ -1,5 +1,5 @@
"""
-Copyright (C) 2018-2020 Intel Corporation
+Copyright (C) 2018-2021 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -14,12 +14,15 @@
limitations under the License.
"""

-from mo.front.common.partial_infer.concat import concat_infer
+from extensions.ops.BN import BN
+from mo.front.extractor import FrontExtractorOp


-def concat_ext(pb_layer, pb_model):
-    return {
-        'type': "Concat",
-        'axis': pb_layer.concat_param.axis,
-        'infer': concat_infer
-    }
+class BNExtractor(FrontExtractorOp):
+    op = 'BN'
+    enabled = True
+
+    @classmethod
+    def extract(cls, node):
+        BN.update_node_stat(node, {})
+        return cls.enabled
@@ -1,5 +1,5 @@
"""
-Copyright (C) 2018-2020 Intel Corporation
+Copyright (C) 2018-2021 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -13,9 +13,8 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
-import unittest
-
import numpy as np
+import unittest

from extensions.front.caffe.bn import BNToScaleShift
from mo.graph.graph import Node
@@ -47,7 +46,7 @@ class TestBNReplacer(unittest.TestCase):
                                     FakeParam('data', shift)])
        nodes = [
            ('input', {'kind': 'op', 'type': 'Identity', 'op': 'Identity'}),
-            ('bn', {'type': None, 'kind': 'op', 'op': 'batchNormInference', 'pb': bn_pb, 'model_pb': bn_bin}),
+            ('bn', {'type': None, 'kind': 'op', 'op': 'BN', 'pb': bn_pb, 'model_pb': bn_bin}),
            ('output', {'kind': 'op', 'type': 'Identity', 'op': 'Identity'}),
        ]
        edges = [
model-optimizer/extensions/front/caffe/concat_ext.py (new file, 32 lines)
@@ -0,0 +1,32 @@
"""
Copyright (C) 2018-2021 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from mo.front.extractor import FrontExtractorOp
from mo.ops.concat import Concat


class ConcatFrontExtractor(FrontExtractorOp):
    op = 'concat'
    enabled = True

    @classmethod
    def extract(cls, node):
        pb = node.pb
        mapping_rule = {
            'axis': pb.concat_param.axis,
        }
        Concat.update_node_stat(node, mapping_rule)
        return cls.enabled
@@ -1,5 +1,5 @@
"""
-Copyright (C) 2018-2020 Intel Corporation
+Copyright (C) 2018-2021 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
"""
-Copyright (C) 2018-2020 Intel Corporation
+Copyright (C) 2018-2021 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -17,7 +17,7 @@
import unittest
from unittest.mock import patch

-from mo.front.caffe.extractors.crop import CropFrontExtractor
+from extensions.front.caffe.crop_ext import CropFrontExtractor
from mo.front.common.partial_infer.crop import crop_infer
from mo.ops.crop import Crop
from mo.ops.op import Op
@@ -1,5 +1,5 @@
"""
-Copyright (C) 2018-2020 Intel Corporation
+Copyright (C) 2018-2021 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -14,15 +14,16 @@
limitations under the License.
"""

-from mo.front.common.partial_infer.roipooling import roipooling_infer
+from extensions.ops.identity import Identity
+from mo.front.extractor import FrontExtractorOp
+from mo.graph.graph import Node


-def roipooling_ext(proto_layer, model_layer):
-    param = proto_layer.roi_pooling_param
-    return {
-        'type': 'ROIPooling',
-        'pooled_h': param.pooled_h,
-        'pooled_w': param.pooled_w,
-        'spatial_scale': param.spatial_scale,
-        'infer': roipooling_infer
-    }
+class DropoutFrontExtractor(FrontExtractorOp):
+    op = 'dropout'
+    enabled = True
+
+    @classmethod
+    def extract(cls, node: Node):
+        Identity.update_node_stat(node, {})
+        return cls.enabled
model-optimizer/extensions/front/caffe/roipooling_ext.py (new file, 35 lines)
@@ -0,0 +1,35 @@
"""
Copyright (C) 2018-2021 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from mo.front.extractor import FrontExtractorOp
from mo.ops.roipooling import ROIPooling


class ROIPoolingFrontExtractor(FrontExtractorOp):
    op = 'roipooling'
    enabled = True

    @classmethod
    def extract(cls, node):
        param = node.pb.roi_pooling_param
        attrs = {
            'pooled_h': param.pooled_h,
            'pooled_w': param.pooled_w,
            'spatial_scale': param.spatial_scale,
        }

        ROIPooling.update_node_stat(node, attrs)
        return cls.enabled
model-optimizer/extensions/front/caffe/scale_ext.py (new file, 55 lines)
@@ -0,0 +1,55 @@
"""
Copyright (C) 2018-2021 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import numpy as np

from mo.front.caffe.extractors.utils import embed_input, weights_biases
from mo.front.common.partial_infer.elemental import copy_shape_infer
from mo.front.extractor import FrontExtractorOp
from mo.ops.scale_shift import ScaleShiftOp
from mo.utils.utils import NamedAttrsClass


class ScaleFrontExtractor(FrontExtractorOp):
    op = 'scale'
    enabled = True

    @classmethod
    def extract(cls, node):
        pb = node.pb
        model = node.model_pb
        param = pb.scale_param
        attrs = {
            'axis': param.axis,
        }

        if model is None and len(pb.bottom) == 1:
            # default weights and biases for scale layer if the caffemodel file doesn't contain them
            model = NamedAttrsClass({'blobs': np.array([NamedAttrsClass({'data': np.array([1])}),
                                                        NamedAttrsClass({'data': np.array([0])})])})
        # scale with 1 input and 1 or 2 blobs
        if model and len(model.blobs) != 0 and len(pb.bottom) == 1:
            attrs.update(weights_biases(param.bias_term, model))
        # 2 inputs + bias
        elif len(pb.bottom) == 2 and param.bias_term:
            if model is None or len(model.blobs) == 0:
                # default bias for scale layer with 2 inputs if the caffemodel file doesn't contain them
                model = NamedAttrsClass({'blobs': np.array([NamedAttrsClass({'data': np.array([0])})])})

            embed_input(attrs, 1, 'biases', model.blobs[0].data)
        ScaleShiftOp.update_node_stat(node, attrs)
        return cls.enabled
model-optimizer/extensions/ops/BN.py (new file, 35 lines)
@@ -0,0 +1,35 @@
"""
Copyright (C) 2018-2021 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from mo.graph.graph import Graph
from mo.ops.op import Op


class BN(Op):
    """
    BN operation comes from caffe and will be replaced by BNToScaleShift FrontReplacer.
    """
    op = 'BN'
    enabled = False

    def __init__(self, graph: Graph, attrs: dict):
        super().__init__(graph, {
            'type': None,
            'op': self.op,
            'in_ports_count': 5,
            'out_ports_count': 1,
            'infer': None
        }, attrs)
@@ -1,5 +1,5 @@
"""
-Copyright (C) 2018-2020 Intel Corporation
+Copyright (C) 2018-2021 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -14,11 +14,7 @@
limitations under the License.
"""

-from mo.front.caffe.extractors.batchnorm import batch_norm_ext
-from mo.front.caffe.extractors.concat import concat_ext
from mo.front.caffe.extractors.native_caffe import native_caffe_node_extractor
-from mo.front.caffe.extractors.roipooling import roipooling_ext
-from mo.front.caffe.extractors.scale import scale_ext
from mo.front.common.partial_infer.elemental import copy_shape_infer
from mo.front.common.register_custom_ops import extension_op_extractor
from mo.front.extractor import CaffePythonFrontExtractorOp
@@ -36,22 +32,8 @@ def node_pb_arg(pb_extractor):
Keys are names that appear as layer names in .prototxt.
Full list is available here: http://caffe.berkeleyvision.org/tutorial/layers.html
"""
-caffe_type_extractors = {
-    # Common Layers
-    'dropout': node_pb_arg(lambda _, __: dict(op='Dropout', infer=copy_shape_infer)),
-
-    # Normalization Layers
-    'batchnorm': node_pb_arg(batch_norm_ext),
-
-    # Activation Layers
-    'scale': node_pb_arg(scale_ext),
-
-    # Utility Layers
-    'concat': node_pb_arg(concat_ext),
-
-    # Custom, implemented in IE, Fast-RCNN-specific
-    'roipooling': node_pb_arg(roipooling_ext),
-}
+caffe_type_extractors = {}


def common_caffe_fields(node: Node) -> dict:
@@ -62,6 +44,7 @@ def common_caffe_fields(node: Node) -> dict:
    if isinstance(layer_type, int):
        layer_type = pb.LayerType.DESCRIPTOR.values_by_number[layer_type].name
    layer_type = str(layer_type)

    return {
        'kind': 'op',
        'name': pb.name,
@@ -1,63 +0,0 @@
"""
Copyright (C) 2018-2020 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import numpy as np

from mo.front.caffe.extractors.utils import embed_input
from mo.front.common.partial_infer.elemental import copy_shape_infer


def batch_norm_ext(pb_layer, pb_model):
    """
    Extracts properties of the BatchNorm layer.
    In case of scale, scale is merged into mean and variance
    Args:
        pb_layer: proto layer, contains own properties of the layer, i.e epsilon
        pb_model: caffemodel layer, contains blobs with 0: mean, 1: variance, (opt)2: scale

    Returns:
        attrs object with type, partial inference function and mean/variance properties.
    """
    assert pb_layer, 'Protobuf layer can not be empty'
    param = pb_layer.batch_norm_param
    attrs = {
        'op': 'BatchNormalization',
        'type': 'BatchNormalization',
        'eps': param.eps,
        'infer': copy_shape_infer
    }

    if not pb_model:
        return attrs

    blobs = pb_model.blobs
    assert len(blobs) >= 2, 'BatchNorm accepts not less then two input blobs'
    mean = np.array(blobs[0].data)
    variance = np.array(blobs[1].data)

    if len(blobs) == 3:
        scale = blobs[2].data[0]
        if scale != 0:
            scale = 1.0 / scale
        mean *= scale
        variance *= scale

    embed_input(attrs, 1, 'gamma', np.ones(mean.shape), 'gamma')
    embed_input(attrs, 2, 'beta', np.zeros(variance.shape), 'beta')
    embed_input(attrs, 3, 'mean', mean, 'biases')
    embed_input(attrs, 4, 'variance', variance, 'weights')

    return attrs
@@ -1,147 +0,0 @@
"""
Copyright (C) 2018-2020 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import unittest

import numpy as np

from mo.front.caffe.extractors.batchnorm import batch_norm_ext
from mo.front.common.partial_infer.elemental import copy_shape_infer
from mo.utils.unittest.extractors import FakeParam, FakeModelLayer


class FakeBNProtoLayer:
    def __init__(self, eps):
        self.batch_norm_param = FakeParam('eps', eps)


class TestShapesParsing(unittest.TestCase):
    def test_bn_ext_no_ml_no_pb(self):
        self.assertRaises(AssertionError, batch_norm_ext, None, None)

    def test_bn_ext_no_ml(self):
        res = batch_norm_ext(FakeBNProtoLayer(10), None)
        exp_res = {
            'op': 'BatchNormalization',
            'type': 'BatchNormalization',
            'eps': 10,
            'infer': copy_shape_infer
        }
        self.assertEqual(res, exp_res)

    def test_bn_ext_ml_one_blob(self):
        self.assertRaises(AssertionError, batch_norm_ext, FakeBNProtoLayer(10), FakeModelLayer([np.array([1, 2])]))

    def test_bn_ext_ml_two_blobs(self):
        mean_blob = np.array([1., 2.])
        variance_blob = np.array([3., 4.])
        blobs = [mean_blob, variance_blob]
        res = batch_norm_ext(FakeBNProtoLayer(10),
                             FakeModelLayer(blobs))
        exp_res = {
            'type': 'BatchNormalization',
            'eps': 10,
            'infer': copy_shape_infer,
            'mean': mean_blob,
            'variance': variance_blob,
            'embedded_inputs': [
                (1, 'gamma', {
                    'bin': 'gamma'
                }),
                (2, 'beta', {
                    'bin': 'beta'
                }),
                (3, 'mean', {
                    'bin': 'biases'
                }),
                (4, 'variance', {
                    'bin': 'weights'
                })
            ]
        }
        for i in exp_res:
            if i in ('mean', 'variance'):
                np.testing.assert_array_equal(res[i], exp_res[i])
            else:
                self.assertEqual(res[i], exp_res[i])

    def test_bn_ext_ml_three_blobs(self):
        mean_blob = np.array([1., 2.])
        variance_blob = np.array([3., 4.])
        scale_blob = np.array([5., ])
        blobs = [mean_blob, variance_blob, scale_blob]
        res = batch_norm_ext(FakeBNProtoLayer(10),
                             FakeModelLayer(blobs))
        exp_res = {
            'type': 'BatchNormalization',
            'eps': 10,
            'infer': copy_shape_infer,
            'mean': mean_blob * 0.2,
            'variance': variance_blob * 0.2,
            'embedded_inputs': [
                (1, 'gamma', {
                    'bin': 'gamma'
                }),
                (2, 'beta', {
                    'bin': 'beta'
                }),
                (3, 'mean', {
                    'bin': 'biases'
                }),
                (4, 'variance', {
                    'bin': 'weights'
                })
            ]
        }
        for i in exp_res:
            if i in ('mean', 'variance'):
                np.testing.assert_array_equal(res[i], exp_res[i])
            else:
                self.assertEqual(res[i], exp_res[i])

    def test_bn_ext_ml_three_blobs_zero_scale(self):
        mean_blob = np.array([1., 2.])
        variance_blob = np.array([3., 4.])
        scale_blob = np.array([0., ])
        blobs = [mean_blob, variance_blob, scale_blob]
        res = batch_norm_ext(FakeBNProtoLayer(10),
                             FakeModelLayer(blobs))
        exp_res = {
            'type': 'BatchNormalization',
            'eps': 10,
            'infer': copy_shape_infer,
            'mean': mean_blob * 0.,
            'variance': variance_blob * 0.,
            'embedded_inputs': [
                (1, 'gamma', {
                    'bin': 'gamma'
                }),
                (2, 'beta', {
                    'bin': 'beta'
                }),
                (3, 'mean', {
                    'bin': 'biases'
                }),
                (4, 'variance', {
                    'bin': 'weights'
                })
            ]
        }
        for i in exp_res:
            if i in ('mean', 'variance'):
                np.testing.assert_array_equal(res[i], exp_res[i])
            else:
                self.assertEqual(res[i], exp_res[i])
@@ -1,37 +0,0 @@
"""
Copyright (C) 2018-2020 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import unittest

from mo.front.caffe.extractors.concat import concat_ext
from mo.front.common.partial_infer.concat import concat_infer
from mo.utils.unittest.extractors import FakeParam


class FakeProtoLayer:
    def __init__(self, axis):
        self.concat_param = FakeParam('axis', axis)


class TestConcat(unittest.TestCase):
    def test_concat(self):
        res = concat_ext(FakeProtoLayer(10), None)
        exp_res = {
            'axis': 10,
            'infer': concat_infer,
            'type': 'Concat'
        }
        self.assertEqual(res, exp_res)
@@ -1,47 +0,0 @@
"""
Copyright (C) 2018-2020 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import numpy as np

from mo.front.caffe.extractors.utils import embed_input, weights_biases
from mo.front.common.partial_infer.elemental import copy_shape_infer
from mo.utils.utils import NamedAttrsClass


def scale_ext(pl, ml):
    param = pl.scale_param
    attrs = {
        'op': 'ScaleShift',
        'type': 'ScaleShift',
        'axis': param.axis,
        'infer': copy_shape_infer
    }
    if ml is None and len(pl.bottom) == 1:
        # default weights and biases for scale layer if the caffemodel file doesn't contain them
        ml = NamedAttrsClass({'blobs': np.array([NamedAttrsClass({'data': np.array([1])}),
                                                 NamedAttrsClass({'data': np.array([0])})])})
    # scale with 1 input and 1 or 2 blobs
    if ml and len(ml.blobs) != 0 and len(pl.bottom) == 1:
        attrs.update(weights_biases(param.bias_term, ml))
    # 2 inputs + bias
    elif len(pl.bottom) == 2 and param.bias_term:
        if ml is None or len(ml.blobs) == 0:
            # default bias for scale layer with 2 inputs if the caffemodel file doesn't contain them
            ml = NamedAttrsClass({'blobs': np.array([NamedAttrsClass({'data': np.array([0])})])})

        embed_input(attrs, 1, 'biases', ml.blobs[0].data)

    return attrs
@@ -1,144 +0,0 @@
"""
Copyright (C) 2018-2020 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import unittest

import numpy as np

from mo.front.caffe.extractors.scale import scale_ext
from mo.front.common.partial_infer.elemental import copy_shape_infer
from mo.utils.unittest.extractors import FakeMultiParam, FakeModelLayer


class FakeProtoLayer:
    def __init__(self, val, bottom2=False):
        self.scale_param = val
        if bottom2:
            self.bottom = {"bottom1", "bottom2"}
        else:
            self.bottom = {"bottom1"}


class TestScale(unittest.TestCase):
    def test_scale_ext(self):
        mean_blob = np.array([1., 2.])
        variance_blob = np.array([3., 4.])
        blobs = [mean_blob, variance_blob]
        params = {
            'type': 'Scale',
            'axis': 0,
            'bias_term': True
        }

        res = scale_ext(FakeProtoLayer(FakeMultiParam(params)), FakeModelLayer(blobs))
        exp_res = {
            'op': 'ScaleShift',
            'type': 'ScaleShift',
            'axis': 0,
            'infer': copy_shape_infer,
            'weights': mean_blob,
            'biases': variance_blob,
            'embedded_inputs': [
                (1, 'weights', {
                    'bin': 'weights'
                }),
                (2, 'biases', {
                    'bin': 'biases'
                })
            ]
        }
        for i in exp_res:
            if i in ('weights', 'biases'):
                np.testing.assert_array_equal(res[i], exp_res[i])
            else:
                self.assertEqual(res[i], exp_res[i])

    def test_scale_2inputs_ext(self):
        params = {
            'type': 'Scale',
            'axis': 0,
            'bias_term': False
        }

        res = scale_ext(FakeProtoLayer(FakeMultiParam(params), True), None)
        exp_res = {
            'op': 'ScaleShift',
            'type': 'ScaleShift',
            'axis': 0,
            'infer': copy_shape_infer,
        }
        for i in exp_res:
            self.assertEqual(res[i], exp_res[i])

    def test_scale_2inputs_bias_ext(self):
        variance_blob = np.array([3., 4.])
        blobs = [variance_blob]

        params = {
            'type': 'Scale',
            'axis': 0,
            'bias_term': True
        }

        res = scale_ext(FakeProtoLayer(FakeMultiParam(params), True), FakeModelLayer(blobs))
        exp_res = {
            'op': 'ScaleShift',
            'type': 'ScaleShift',
            'axis': 0,
            'infer': copy_shape_infer,
            'biases': variance_blob,
            'embedded_inputs': [
                (1, 'biases', {
                    'bin': 'biases'
                })]
        }
        for i in exp_res:
            if i in ('biases'):
                np.testing.assert_array_equal(res[i], exp_res[i])
            else:
                self.assertEqual(res[i], exp_res[i])

    def test_create_default_weights(self):
        """
        There are situations when scale layer doesn't have weights and biases. This test checks that if they are not
        available in the caffemodel file then default values [1] and [0] are generated.
        """
        scale_blob = np.array([1])
        bias_blob = np.array([0])
        params = {
            'type': 'Scale',
            'axis': 0,
            'bias_term': True
        }

        res = scale_ext(FakeProtoLayer(FakeMultiParam(params)), None)
        exp_res = {
            'op': 'ScaleShift',
            'type': 'ScaleShift',
            'axis': 0,
            'infer': copy_shape_infer,
            'weights': scale_blob,
            'biases': bias_blob,
            'embedded_inputs': [
                (1, 'weights', {
                    'bin': 'weights'
                }),
                (2, 'biases', {
                    'bin': 'biases'
                })
            ]
        }
        self.assertDictEqual(exp_res, res)