Re-implement caffe old-style extractors with extractor extensions (#3675)

* move crop extractor

* Add concat_ext.py

* Add roipooling_ext.py

* Add roipooling_ext

* Add scale extractor

* Add scale extractor

* Add bn_ext.py and dropout_ext.py

* Add bn_ext.py and dropout_ext.py

* Add bn_ext.py and dropout_ext.py

* Fix bn.ext.py

* Sort fix

* Fix bn_test.py

* rename to batchnorm_ext

* Add bn_ext

* Fix batchnorm_ext.py

* small fix

* Small fix
This commit is contained in:
Eugeny Volosenkov 2021-02-01 13:17:17 +03:00 committed by GitHub
parent a6a5635a59
commit 1a787cb3ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 253 additions and 491 deletions

View File

@ -71,15 +71,20 @@ extensions/front/caffe/accum_ext.py
extensions/front/caffe/argmax_ext.py
extensions/front/caffe/ArgMaxFlatten.py
extensions/front/caffe/axpy.py
extensions/front/caffe/batchnorm_ext.py
extensions/front/caffe/binarization.py
extensions/front/caffe/binary_conv_ext.py
extensions/front/caffe/bn.py
extensions/front/caffe/bn_ext.py
extensions/front/caffe/concat_ext.py
extensions/front/caffe/conv_ext.py
extensions/front/caffe/correlation_ext.py
extensions/front/caffe/crop_ext.py
extensions/front/caffe/ctcgreedydecoder_ext.py
extensions/front/caffe/CustomLayersMapping.xml.example
extensions/front/caffe/data_augmentation_ext.py
extensions/front/caffe/detection_output.py
extensions/front/caffe/dropout_ext.py
extensions/front/caffe/elementwise_ext.py
extensions/front/caffe/eltwise_add_normalize.py
extensions/front/caffe/elu.py
@ -106,6 +111,8 @@ extensions/front/caffe/relu_ext.py
extensions/front/caffe/reorgyolo_ext.py
extensions/front/caffe/resample_ext.py
extensions/front/caffe/reshape.py
extensions/front/caffe/roipooling_ext.py
extensions/front/caffe/scale_ext.py
extensions/front/caffe/shufflechannel_ext.py
extensions/front/caffe/sigmoid.py
extensions/front/caffe/simplernms_ext.py
@ -618,6 +625,7 @@ extensions/ops/axpy.py
extensions/ops/BatchNormInference.py
extensions/ops/binarization.py
extensions/ops/BlockLSTM.py
extensions/ops/BN.py
extensions/ops/box_nms.py
extensions/ops/bucketize.py
extensions/ops/Cast.py
@ -758,12 +766,7 @@ mo/front/caffe/collect_attributes.py
mo/front/caffe/custom_layers_mapping.py
mo/front/caffe/extractor.py
mo/front/caffe/extractors/__init__.py
mo/front/caffe/extractors/batchnorm.py
mo/front/caffe/extractors/concat.py
mo/front/caffe/extractors/crop.py
mo/front/caffe/extractors/native_caffe.py
mo/front/caffe/extractors/roipooling.py
mo/front/caffe/extractors/scale.py
mo/front/caffe/extractors/tile.py
mo/front/caffe/extractors/utils.py
mo/front/caffe/loader.py

View File

@ -0,0 +1,53 @@
"""
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
from extensions.ops.BatchNormInference import BatchNormInference
from mo.front.caffe.extractors.utils import embed_input
from mo.front.extractor import FrontExtractorOp
class BatchNormalizationExtractor(FrontExtractorOp):
op = 'batchnorm'
enabled = True
@classmethod
def extract(cls, node):
eps = node.pb.batch_norm_param.eps
attrs = {
'eps': eps
}
pb_model = None if not node.soft_get('model_pb', None) else node.model_pb
if pb_model:
blobs = pb_model.blobs
assert len(blobs) >= 2, 'BatchNorm accepts not less then two input blobs'
mean = np.array(blobs[0].data)
variance = np.array(blobs[1].data)
if len(blobs) == 3:
scale = blobs[2].data[0]
if scale != 0:
scale = 1.0 / scale
mean *= scale
variance *= scale
embed_input(attrs, 1, 'gamma', np.ones(mean.shape), 'gamma')
embed_input(attrs, 2, 'beta', np.zeros(variance.shape), 'beta')
embed_input(attrs, 3, 'mean', mean, 'biases')
embed_input(attrs, 4, 'variance', variance, 'weights')
BatchNormInference.update_node_stat(node, attrs)
return cls.enabled

View File

@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -27,7 +27,7 @@ class BNToScaleShift(FrontReplacementOp):
"""
Replaces BN layer with ScaleShift.
"""
op = "batchNormInference"
op = "BN"
enabled = True
def replace_op(self, graph: Graph, node: Node):
@ -35,6 +35,7 @@ class BNToScaleShift(FrontReplacementOp):
param = graph.node[node.id]['pb'].bn_param
pb_model = graph.node[node.id]['model_pb']
blobs = pb_model.blobs
if len(blobs) != 4:

View File

@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -14,12 +14,15 @@
limitations under the License.
"""
from mo.front.common.partial_infer.concat import concat_infer
from extensions.ops.BN import BN
from mo.front.extractor import FrontExtractorOp
def concat_ext(pb_layer, pb_model):
return {
'type': "Concat",
'axis': pb_layer.concat_param.axis,
'infer': concat_infer
}
class BNExtractor(FrontExtractorOp):
op = 'BN'
enabled = True
@classmethod
def extract(cls, node):
BN.update_node_stat(node, {})
return cls.enabled

View File

@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -13,9 +13,8 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
import unittest
from extensions.front.caffe.bn import BNToScaleShift
from mo.graph.graph import Node
@ -47,7 +46,7 @@ class TestBNReplacer(unittest.TestCase):
FakeParam('data', shift)])
nodes = [
('input', {'kind': 'op', 'type': 'Identity', 'op': 'Identity'}),
('bn', {'type': None, 'kind': 'op', 'op': 'batchNormInference', 'pb': bn_pb, 'model_pb': bn_bin}),
('bn', {'type': None, 'kind': 'op', 'op': 'BN', 'pb': bn_pb, 'model_pb': bn_bin}),
('output', {'kind': 'op', 'type': 'Identity', 'op': 'Identity'}),
]
edges = [

View File

@ -0,0 +1,32 @@
"""
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from mo.front.extractor import FrontExtractorOp
from mo.ops.concat import Concat
class ConcatFrontExtractor(FrontExtractorOp):
op = 'concat'
enabled = True
@classmethod
def extract(cls, node):
pb = node.pb
mapping_rule = {
'axis': pb.concat_param.axis,
}
Concat.update_node_stat(node, mapping_rule)
return cls.enabled

View File

@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -17,7 +17,7 @@
import unittest
from unittest.mock import patch
from mo.front.caffe.extractors.crop import CropFrontExtractor
from extensions.front.caffe.crop_ext import CropFrontExtractor
from mo.front.common.partial_infer.crop import crop_infer
from mo.ops.crop import Crop
from mo.ops.op import Op

View File

@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -14,15 +14,16 @@
limitations under the License.
"""
from mo.front.common.partial_infer.roipooling import roipooling_infer
from extensions.ops.identity import Identity
from mo.front.extractor import FrontExtractorOp
from mo.graph.graph import Node
def roipooling_ext(proto_layer, model_layer):
param = proto_layer.roi_pooling_param
return {
'type': 'ROIPooling',
'pooled_h': param.pooled_h,
'pooled_w': param.pooled_w,
'spatial_scale': param.spatial_scale,
'infer': roipooling_infer
}
class DropoutFrontExtractor(FrontExtractorOp):
op = 'dropout'
enabled = True
@classmethod
def extract(cls, node: Node):
Identity.update_node_stat(node, {})
return cls.enabled

View File

@ -0,0 +1,35 @@
"""
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from mo.front.extractor import FrontExtractorOp
from mo.ops.roipooling import ROIPooling
class ROIPoolingFrontExtractor(FrontExtractorOp):
op = 'roipooling'
enabled = True
@classmethod
def extract(cls, node):
param = node.pb.roi_pooling_param
attrs = {
'pooled_h': param.pooled_h,
'pooled_w': param.pooled_w,
'spatial_scale': param.spatial_scale,
}
ROIPooling.update_node_stat(node, attrs)
return cls.enabled

View File

@ -0,0 +1,55 @@
"""
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
from mo.front.caffe.extractors.utils import embed_input, weights_biases
from mo.front.common.partial_infer.elemental import copy_shape_infer
from mo.front.extractor import FrontExtractorOp
from mo.ops.scale_shift import ScaleShiftOp
from mo.utils.utils import NamedAttrsClass
class ScaleFrontExtractor(FrontExtractorOp):
op = 'scale'
enabled = True
@classmethod
def extract(cls, node):
pb = node.pb
model = node.model_pb
param = pb.scale_param
attrs = {
'axis': param.axis,
}
if model is None and len(pb.bottom) == 1:
# default weights and biases for scale layer if the caffemodel file doesn't contain them
model = NamedAttrsClass({'blobs': np.array([NamedAttrsClass({'data': np.array([1])}),
NamedAttrsClass({'data': np.array([0])})])})
# scale with 1 input and 1 or 2 blobs
if model and len(model.blobs) != 0 and len(pb.bottom) == 1:
attrs.update(weights_biases(param.bias_term, model))
# 2 inputs + bias
elif len(pb.bottom) == 2 and param.bias_term:
if model is None or len(model.blobs) == 0:
# default bias for scale layer with 2 inputs if the caffemodel file doesn't contain them
model = NamedAttrsClass({'blobs': np.array([NamedAttrsClass({'data': np.array([0])})])})
embed_input(attrs, 1, 'biases', model.blobs[0].data)
ScaleShiftOp.update_node_stat(node, attrs)
return cls.enabled

View File

@ -0,0 +1,35 @@
"""
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from mo.graph.graph import Graph
from mo.ops.op import Op
class BN(Op):
"""
BN operation comes from caffe and will be replaced by BNToScaleShift FrontReplacer.
"""
op = 'BN'
enabled = False
def __init__(self, graph: Graph, attrs: dict):
super().__init__(graph, {
'type': None,
'op': self.op,
'in_ports_count': 5,
'out_ports_count': 1,
'infer': None
}, attrs)

View File

@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -14,11 +14,7 @@
limitations under the License.
"""
from mo.front.caffe.extractors.batchnorm import batch_norm_ext
from mo.front.caffe.extractors.concat import concat_ext
from mo.front.caffe.extractors.native_caffe import native_caffe_node_extractor
from mo.front.caffe.extractors.roipooling import roipooling_ext
from mo.front.caffe.extractors.scale import scale_ext
from mo.front.common.partial_infer.elemental import copy_shape_infer
from mo.front.common.register_custom_ops import extension_op_extractor
from mo.front.extractor import CaffePythonFrontExtractorOp
@ -36,22 +32,8 @@ def node_pb_arg(pb_extractor):
Keys are names that appear as layer names in .prototxt.
Full list is available here: http://caffe.berkeleyvision.org/tutorial/layers.html
"""
caffe_type_extractors = {
# Common Layers
'dropout': node_pb_arg(lambda _, __: dict(op='Dropout', infer=copy_shape_infer)),
# Normalization Layers
'batchnorm': node_pb_arg(batch_norm_ext),
# Activation Layers
'scale': node_pb_arg(scale_ext),
# Utility Layers
'concat': node_pb_arg(concat_ext),
# Custom, implemented in IE, Fast-RCNN-specific
'roipooling': node_pb_arg(roipooling_ext),
}
caffe_type_extractors = {}
def common_caffe_fields(node: Node) -> dict:
@ -62,6 +44,7 @@ def common_caffe_fields(node: Node) -> dict:
if isinstance(layer_type, int):
layer_type = pb.LayerType.DESCRIPTOR.values_by_number[layer_type].name
layer_type = str(layer_type)
return {
'kind': 'op',
'name': pb.name,

View File

@ -1,63 +0,0 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
from mo.front.caffe.extractors.utils import embed_input
from mo.front.common.partial_infer.elemental import copy_shape_infer
def batch_norm_ext(pb_layer, pb_model):
"""
Extracts properties of the BatchNorm layer.
In case of scale, scale is merged into mean and variance
Args:
pb_layer: proto layer, contains own properties of the layer, i.e epsilon
pb_model: caffemodel layer, contains blobs with 0: mean, 1: variance, (opt)2: scale
Returns:
attrs object with type, partial inference function and mean/variance properties.
"""
assert pb_layer, 'Protobuf layer can not be empty'
param = pb_layer.batch_norm_param
attrs = {
'op': 'BatchNormalization',
'type': 'BatchNormalization',
'eps': param.eps,
'infer': copy_shape_infer
}
if not pb_model:
return attrs
blobs = pb_model.blobs
assert len(blobs) >= 2, 'BatchNorm accepts not less then two input blobs'
mean = np.array(blobs[0].data)
variance = np.array(blobs[1].data)
if len(blobs) == 3:
scale = blobs[2].data[0]
if scale != 0:
scale = 1.0 / scale
mean *= scale
variance *= scale
embed_input(attrs, 1, 'gamma', np.ones(mean.shape), 'gamma')
embed_input(attrs, 2, 'beta', np.zeros(variance.shape), 'beta')
embed_input(attrs, 3, 'mean', mean, 'biases')
embed_input(attrs, 4, 'variance', variance, 'weights')
return attrs

View File

@ -1,147 +0,0 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
from mo.front.caffe.extractors.batchnorm import batch_norm_ext
from mo.front.common.partial_infer.elemental import copy_shape_infer
from mo.utils.unittest.extractors import FakeParam, FakeModelLayer
class FakeBNProtoLayer:
def __init__(self, eps):
self.batch_norm_param = FakeParam('eps', eps)
class TestShapesParsing(unittest.TestCase):
def test_bn_ext_no_ml_no_pb(self):
self.assertRaises(AssertionError, batch_norm_ext, None, None)
def test_bn_ext_no_ml(self):
res = batch_norm_ext(FakeBNProtoLayer(10), None)
exp_res = {
'op': 'BatchNormalization',
'type': 'BatchNormalization',
'eps': 10,
'infer': copy_shape_infer
}
self.assertEqual(res, exp_res)
def test_bn_ext_ml_one_blob(self):
self.assertRaises(AssertionError, batch_norm_ext, FakeBNProtoLayer(10), FakeModelLayer([np.array([1, 2])]))
def test_bn_ext_ml_two_blobs(self):
mean_blob = np.array([1., 2.])
variance_blob = np.array([3., 4.])
blobs = [mean_blob, variance_blob]
res = batch_norm_ext(FakeBNProtoLayer(10),
FakeModelLayer(blobs))
exp_res = {
'type': 'BatchNormalization',
'eps': 10,
'infer': copy_shape_infer,
'mean': mean_blob,
'variance': variance_blob,
'embedded_inputs': [
(1, 'gamma', {
'bin': 'gamma'
}),
(2, 'beta', {
'bin': 'beta'
}),
(3, 'mean', {
'bin': 'biases'
}),
(4, 'variance', {
'bin': 'weights'
})
]
}
for i in exp_res:
if i in ('mean', 'variance'):
np.testing.assert_array_equal(res[i], exp_res[i])
else:
self.assertEqual(res[i], exp_res[i])
def test_bn_ext_ml_three_blobs(self):
mean_blob = np.array([1., 2.])
variance_blob = np.array([3., 4.])
scale_blob = np.array([5., ])
blobs = [mean_blob, variance_blob, scale_blob]
res = batch_norm_ext(FakeBNProtoLayer(10),
FakeModelLayer(blobs))
exp_res = {
'type': 'BatchNormalization',
'eps': 10,
'infer': copy_shape_infer,
'mean': mean_blob * 0.2,
'variance': variance_blob * 0.2,
'embedded_inputs': [
(1, 'gamma', {
'bin': 'gamma'
}),
(2, 'beta', {
'bin': 'beta'
}),
(3, 'mean', {
'bin': 'biases'
}),
(4, 'variance', {
'bin': 'weights'
})
]
}
for i in exp_res:
if i in ('mean', 'variance'):
np.testing.assert_array_equal(res[i], exp_res[i])
else:
self.assertEqual(res[i], exp_res[i])
def test_bn_ext_ml_three_blobs_zero_scale(self):
mean_blob = np.array([1., 2.])
variance_blob = np.array([3., 4.])
scale_blob = np.array([0., ])
blobs = [mean_blob, variance_blob, scale_blob]
res = batch_norm_ext(FakeBNProtoLayer(10),
FakeModelLayer(blobs))
exp_res = {
'type': 'BatchNormalization',
'eps': 10,
'infer': copy_shape_infer,
'mean': mean_blob * 0.,
'variance': variance_blob * 0.,
'embedded_inputs': [
(1, 'gamma', {
'bin': 'gamma'
}),
(2, 'beta', {
'bin': 'beta'
}),
(3, 'mean', {
'bin': 'biases'
}),
(4, 'variance', {
'bin': 'weights'
})
]
}
for i in exp_res:
if i in ('mean', 'variance'):
np.testing.assert_array_equal(res[i], exp_res[i])
else:
self.assertEqual(res[i], exp_res[i])

View File

@ -1,37 +0,0 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from mo.front.caffe.extractors.concat import concat_ext
from mo.front.common.partial_infer.concat import concat_infer
from mo.utils.unittest.extractors import FakeParam
class FakeProtoLayer:
def __init__(self, axis):
self.concat_param = FakeParam('axis', axis)
class TestConcat(unittest.TestCase):
def test_concat(self):
res = concat_ext(FakeProtoLayer(10), None)
exp_res = {
'axis': 10,
'infer': concat_infer,
'type': 'Concat'
}
self.assertEqual(res, exp_res)

View File

@ -1,47 +0,0 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
from mo.front.caffe.extractors.utils import embed_input, weights_biases
from mo.front.common.partial_infer.elemental import copy_shape_infer
from mo.utils.utils import NamedAttrsClass
def scale_ext(pl, ml):
param = pl.scale_param
attrs = {
'op': 'ScaleShift',
'type': 'ScaleShift',
'axis': param.axis,
'infer': copy_shape_infer
}
if ml is None and len(pl.bottom) == 1:
# default weights and biases for scale layer if the caffemodel file doesn't contain them
ml = NamedAttrsClass({'blobs': np.array([NamedAttrsClass({'data': np.array([1])}),
NamedAttrsClass({'data': np.array([0])})])})
# scale with 1 input and 1 or 2 blobs
if ml and len(ml.blobs) != 0 and len(pl.bottom) == 1:
attrs.update(weights_biases(param.bias_term, ml))
# 2 inputs + bias
elif len(pl.bottom) == 2 and param.bias_term:
if ml is None or len(ml.blobs) == 0:
# default bias for scale layer with 2 inputs if the caffemodel file doesn't contain them
ml = NamedAttrsClass({'blobs': np.array([NamedAttrsClass({'data': np.array([0])})])})
embed_input(attrs, 1, 'biases', ml.blobs[0].data)
return attrs

View File

@ -1,144 +0,0 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
from mo.front.caffe.extractors.scale import scale_ext
from mo.front.common.partial_infer.elemental import copy_shape_infer
from mo.utils.unittest.extractors import FakeMultiParam, FakeModelLayer
class FakeProtoLayer:
def __init__(self, val, bottom2=False):
self.scale_param = val
if bottom2:
self.bottom = {"bottom1", "bottom2"}
else:
self.bottom = {"bottom1"}
class TestScale(unittest.TestCase):
def test_scale_ext(self):
mean_blob = np.array([1., 2.])
variance_blob = np.array([3., 4.])
blobs = [mean_blob, variance_blob]
params = {
'type': 'Scale',
'axis': 0,
'bias_term': True
}
res = scale_ext(FakeProtoLayer(FakeMultiParam(params)), FakeModelLayer(blobs))
exp_res = {
'op': 'ScaleShift',
'type': 'ScaleShift',
'axis': 0,
'infer': copy_shape_infer,
'weights': mean_blob,
'biases': variance_blob,
'embedded_inputs': [
(1, 'weights', {
'bin': 'weights'
}),
(2, 'biases', {
'bin': 'biases'
})
]
}
for i in exp_res:
if i in ('weights', 'biases'):
np.testing.assert_array_equal(res[i], exp_res[i])
else:
self.assertEqual(res[i], exp_res[i])
def test_scale_2inputs_ext(self):
params = {
'type': 'Scale',
'axis': 0,
'bias_term': False
}
res = scale_ext(FakeProtoLayer(FakeMultiParam(params), True), None)
exp_res = {
'op': 'ScaleShift',
'type': 'ScaleShift',
'axis': 0,
'infer': copy_shape_infer,
}
for i in exp_res:
self.assertEqual(res[i], exp_res[i])
def test_scale_2inputs_bias_ext(self):
variance_blob = np.array([3., 4.])
blobs = [variance_blob]
params = {
'type': 'Scale',
'axis': 0,
'bias_term': True
}
res = scale_ext(FakeProtoLayer(FakeMultiParam(params), True), FakeModelLayer(blobs))
exp_res = {
'op': 'ScaleShift',
'type': 'ScaleShift',
'axis': 0,
'infer': copy_shape_infer,
'biases': variance_blob,
'embedded_inputs': [
(1, 'biases', {
'bin': 'biases'
})]
}
for i in exp_res:
if i in ('biases'):
np.testing.assert_array_equal(res[i], exp_res[i])
else:
self.assertEqual(res[i], exp_res[i])
def test_create_default_weights(self):
"""
There are situations when scale layer doesn't have weights and biases. This test checks that if they are not
available in the caffemodel file then default values [1] and [0] are generated.
"""
scale_blob = np.array([1])
bias_blob = np.array([0])
params = {
'type': 'Scale',
'axis': 0,
'bias_term': True
}
res = scale_ext(FakeProtoLayer(FakeMultiParam(params)), None)
exp_res = {
'op': 'ScaleShift',
'type': 'ScaleShift',
'axis': 0,
'infer': copy_shape_infer,
'weights': scale_blob,
'biases': bias_blob,
'embedded_inputs': [
(1, 'weights', {
'bin': 'weights'
}),
(2, 'biases', {
'bin': 'biases'
})
]
}
self.assertDictEqual(exp_res, res)