Refactored legacy code for mean_scale_values transformations (#1936)
* Remove move_to_preproc. Not actual. * Updated documentation
This commit is contained in:
parent
b683b5501d
commit
9939253fed
@ -109,7 +109,6 @@ Framework-agnostic parameters:
|
|||||||
--disable_gfusing Turn off fusing of grouped convolutions
|
--disable_gfusing Turn off fusing of grouped convolutions
|
||||||
--enable_concat_optimization
|
--enable_concat_optimization
|
||||||
Turn on Concat optimization.
|
Turn on Concat optimization.
|
||||||
--move_to_preprocess Move mean values to IR preprocess section
|
|
||||||
--extensions EXTENSIONS
|
--extensions EXTENSIONS
|
||||||
Directory or a comma separated list of directories
|
Directory or a comma separated list of directories
|
||||||
with extensions. To disable all extensions including
|
with extensions. To disable all extensions including
|
||||||
|
@ -872,7 +872,6 @@ mo/middle/passes/fusing/mark_unfused_nodes.py
|
|||||||
mo/middle/passes/fusing/resnet_optimization.py
|
mo/middle/passes/fusing/resnet_optimization.py
|
||||||
mo/middle/passes/infer.py
|
mo/middle/passes/infer.py
|
||||||
mo/middle/passes/leaky_relu.py
|
mo/middle/passes/leaky_relu.py
|
||||||
mo/middle/passes/mean_scale_values.py
|
|
||||||
mo/middle/passes/tensor_names.py
|
mo/middle/passes/tensor_names.py
|
||||||
mo/middle/pattern_match.py
|
mo/middle/pattern_match.py
|
||||||
mo/middle/replacement.py
|
mo/middle/replacement.py
|
||||||
|
@ -16,36 +16,19 @@
|
|||||||
from extensions.middle.LeakyReluPattern import LeakyReLU
|
from extensions.middle.LeakyReluPattern import LeakyReLU
|
||||||
from extensions.middle.pass_separator import PostMiddleStart
|
from extensions.middle.pass_separator import PostMiddleStart
|
||||||
from mo.graph.graph import Graph
|
from mo.graph.graph import Graph
|
||||||
from mo.middle.passes.mean_scale_values import move_scaleshift_to_preprocess
|
|
||||||
from mo.middle.replacement import MiddleReplacementPattern
|
from mo.middle.replacement import MiddleReplacementPattern
|
||||||
from mo.utils.error import Error
|
from mo.utils.error import Error
|
||||||
from mo.utils.find_inputs import find_inputs
|
from mo.utils.find_inputs import find_inputs
|
||||||
from mo.utils.utils import refer_to_faq_msg
|
from mo.utils.utils import refer_to_faq_msg
|
||||||
|
|
||||||
|
|
||||||
class Preprocessing(MiddleReplacementPattern):
|
|
||||||
enabled = True
|
|
||||||
force_clean_up = True
|
|
||||||
|
|
||||||
def run_after(self):
|
|
||||||
return [LeakyReLU]
|
|
||||||
|
|
||||||
def run_before(self):
|
|
||||||
return [PostMiddleStart]
|
|
||||||
|
|
||||||
def find_and_replace_pattern(self, graph: Graph):
|
|
||||||
argv = graph.graph['cmd_params']
|
|
||||||
if argv.move_to_preprocess:
|
|
||||||
move_scaleshift_to_preprocess(graph)
|
|
||||||
|
|
||||||
|
|
||||||
class CaffeMeanFileProcessing(MiddleReplacementPattern):
|
class CaffeMeanFileProcessing(MiddleReplacementPattern):
|
||||||
enabled = True
|
enabled = True
|
||||||
force_clean_up = True
|
force_clean_up = True
|
||||||
graph_condition = [lambda graph: graph.graph['fw'] == 'caffe']
|
graph_condition = [lambda graph: graph.graph['fw'] == 'caffe']
|
||||||
|
|
||||||
def run_after(self):
|
def run_after(self):
|
||||||
return [Preprocessing]
|
return [LeakyReLU]
|
||||||
|
|
||||||
def run_before(self):
|
def run_before(self):
|
||||||
return [PostMiddleStart]
|
return [PostMiddleStart]
|
||||||
|
@ -1,81 +0,0 @@
|
|||||||
"""
|
|
||||||
Copyright (C) 2018-2020 Intel Corporation
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from mo.graph.graph import Graph
|
|
||||||
from mo.middle.pattern_match import apply_pattern
|
|
||||||
|
|
||||||
|
|
||||||
def move_scaleshift_to_preprocess_action(graph, match):
|
|
||||||
mean_values = {}
|
|
||||||
input_op = match['input_op']
|
|
||||||
scale_shift = match['scale_shift']
|
|
||||||
weights = np.squeeze(match['weights'].value)
|
|
||||||
biases = np.squeeze(match['biases'].value)
|
|
||||||
|
|
||||||
if graph.graph['cmd_params'].reverse_input_channels:
|
|
||||||
biases = np.flip(biases)
|
|
||||||
|
|
||||||
if any([x != 1 for x in weights]):
|
|
||||||
return
|
|
||||||
|
|
||||||
# Keep biases (mean values) for current input as graph attr and remove ScaleShift layer
|
|
||||||
# Input->data->ScaleShift->scsh_data => Input->scsh_data
|
|
||||||
graph.remove_edge(input_op.id, input_op.out_node().id)
|
|
||||||
graph.add_edge(input_op.id, scale_shift.out_node().id, out=0)
|
|
||||||
graph.remove_edge(scale_shift.id, scale_shift.out_node().id)
|
|
||||||
|
|
||||||
# If bias contains zeros we just remove it
|
|
||||||
if all([x == 0 for x in biases]):
|
|
||||||
return
|
|
||||||
|
|
||||||
# In pre-process section, mean_values are subtracted
|
|
||||||
biases *= -1
|
|
||||||
|
|
||||||
mean_values.update({input_op.name: np.array(biases)})
|
|
||||||
|
|
||||||
# Add graph attribute 'mean_values' that stores mean_values per input if exists
|
|
||||||
if graph.graph.get('mean_values', None):
|
|
||||||
graph.graph['mean_values'].update(mean_values)
|
|
||||||
else:
|
|
||||||
graph.graph['mean_values'] = mean_values
|
|
||||||
|
|
||||||
|
|
||||||
def move_scaleshift_to_preprocess(graph: Graph):
|
|
||||||
"""
|
|
||||||
This function finds scaleshift layer after input layer and if it has weights with ones, it deletes scaleshift layer
|
|
||||||
and creates graph dict attribute : {'input':np.array(...), 'input2': ... }
|
|
||||||
"""
|
|
||||||
apply_pattern(
|
|
||||||
graph,
|
|
||||||
nodes=[
|
|
||||||
('weights', dict(kind='data')),
|
|
||||||
('biases', dict(kind='data')),
|
|
||||||
('input_output', dict(kind='data')),
|
|
||||||
('scsh_output', dict(kind='data')),
|
|
||||||
('input_op', dict(kind='op', type='Parameter')),
|
|
||||||
('scale_shift', dict(kind='op', type='ScaleShift')),
|
|
||||||
],
|
|
||||||
edges=[
|
|
||||||
('input_op', 'input_output'),
|
|
||||||
('scale_shift', 'scsh_output'),
|
|
||||||
('input_output', 'scale_shift', {'in': 0}),
|
|
||||||
('weights', 'scale_shift', {'in': 1}),
|
|
||||||
('biases', 'scale_shift', {'in': 2}),
|
|
||||||
],
|
|
||||||
action=move_scaleshift_to_preprocess_action
|
|
||||||
)
|
|
@ -1,174 +0,0 @@
|
|||||||
"""
|
|
||||||
Copyright (C) 2018-2020 Intel Corporation
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import unittest
|
|
||||||
from argparse import Namespace
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from mo.middle.passes.mean_scale_values import move_scaleshift_to_preprocess
|
|
||||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
|
||||||
from mo.utils.unittest.graph import build_graph
|
|
||||||
|
|
||||||
nodes_attributes = {'node_1': {'type': 'Identity', 'value': None, 'kind': 'op'},
|
|
||||||
'node_2': {'type': 'Identity', 'value': None, 'kind': 'op'},
|
|
||||||
'concat': {'type': 'Concat', 'value': None, 'kind': 'op'},
|
|
||||||
'node_3': {'type': 'Identity', 'value': None, 'kind': 'op'},
|
|
||||||
# Placeholders
|
|
||||||
'placeholder_1': {'value': None, 'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
|
||||||
'placeholder_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
|
|
||||||
'placeholder_2': {'value': None, 'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
|
||||||
'placeholder_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
|
|
||||||
# ScaleShift layer
|
|
||||||
'scaleshift_1': {'type': 'ScaleShift', 'value': None, 'kind': 'op', 'op': 'ScaleShift'},
|
|
||||||
'scaleshift_1_w': {'value': None, 'shape': None, 'kind': 'data'},
|
|
||||||
'scaleshift_1_b': {'value': None, 'shape': None, 'kind': 'data'},
|
|
||||||
'scaleshift_1_data': {'value': None, 'shape': None, 'kind': 'data'},
|
|
||||||
'op_output': { 'kind': 'op', 'op': 'Result'},
|
|
||||||
'op_output_1': { 'kind': 'op', 'op': 'Result'}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class TestScaleShift_To_Preprocess(unittest.TestCase):
|
|
||||||
def test_move_scaleshift_to_preprocess_1(self):
|
|
||||||
graph = build_graph(nodes_attributes,
|
|
||||||
[('placeholder_1', 'placeholder_1_data'),
|
|
||||||
('placeholder_1_data', 'scaleshift_1'),
|
|
||||||
('scaleshift_1', 'scaleshift_1_data'),
|
|
||||||
('scaleshift_1_w', 'scaleshift_1'),
|
|
||||||
('scaleshift_1_b', 'scaleshift_1'),
|
|
||||||
('scaleshift_1_data', 'op_output')
|
|
||||||
],
|
|
||||||
{'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
|
|
||||||
'scaleshift_1_w': {'shape': np.array([3]), 'value': np.ones(3)},
|
|
||||||
'scaleshift_1_b': {'shape': np.array([3]), 'value': np.array([-1, -2, -3])},
|
|
||||||
})
|
|
||||||
graph.graph['cmd_params'] = Namespace(reverse_input_channels=False)
|
|
||||||
del graph['placeholder_1']['placeholder_1_data'][0]['in']
|
|
||||||
del graph['scaleshift_1']['scaleshift_1_data'][0]['in']
|
|
||||||
|
|
||||||
graph_ref = build_graph(nodes_attributes,
|
|
||||||
[('placeholder_1', 'scaleshift_1_data'),
|
|
||||||
('scaleshift_1_data', 'op_output')
|
|
||||||
])
|
|
||||||
|
|
||||||
move_scaleshift_to_preprocess(graph)
|
|
||||||
self.assertTrue(graph.graph['mean_values'] is not None)
|
|
||||||
self.assertTrue(np.array_equal(graph.graph['mean_values']['placeholder_1'], np.array([1, 2, 3])))
|
|
||||||
|
|
||||||
(flag, resp) = compare_graphs(graph, graph_ref, 'scaleshift_1_data')
|
|
||||||
self.assertTrue(flag, resp)
|
|
||||||
|
|
||||||
def test_move_scaleshift_to_preprocess_2(self):
|
|
||||||
graph = build_graph(nodes_attributes,
|
|
||||||
[('placeholder_1', 'placeholder_1_data'),
|
|
||||||
('placeholder_1_data', 'scaleshift_1'),
|
|
||||||
('scaleshift_1', 'scaleshift_1_data'),
|
|
||||||
('scaleshift_1_w', 'scaleshift_1'),
|
|
||||||
('scaleshift_1_b', 'scaleshift_1'),
|
|
||||||
('scaleshift_1_data', 'op_output'),
|
|
||||||
('placeholder_1_data', 'op_output_1')
|
|
||||||
],
|
|
||||||
{'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
|
|
||||||
'scaleshift_1_w': {'shape': np.array([3]), 'value': np.array((1, 2, 3))},
|
|
||||||
'scaleshift_1_b': {'shape': np.array([3]), 'value': np.array([-1, -2, -3])},
|
|
||||||
})
|
|
||||||
graph.graph['cmd_params'] = Namespace(reverse_input_channels=False)
|
|
||||||
del graph['placeholder_1']['placeholder_1_data'][0]['in']
|
|
||||||
del graph['scaleshift_1']['scaleshift_1_data'][0]['in']
|
|
||||||
|
|
||||||
graph_ref = build_graph(nodes_attributes,
|
|
||||||
[('placeholder_1', 'placeholder_1_data'),
|
|
||||||
('placeholder_1_data', 'scaleshift_1'),
|
|
||||||
('scaleshift_1', 'scaleshift_1_data'),
|
|
||||||
('scaleshift_1_w', 'scaleshift_1'),
|
|
||||||
('scaleshift_1_b', 'scaleshift_1'),
|
|
||||||
('placeholder_1_data', 'op_output_1'),
|
|
||||||
('scaleshift_1_data', 'op_output')
|
|
||||||
],
|
|
||||||
{'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
|
|
||||||
'scaleshift_1_w': {'shape': np.array([3]), 'value': np.array((1, 2, 3))},
|
|
||||||
'scaleshift_1_b': {'shape': np.array([3]), 'value': np.array([-1, -2, -3])},
|
|
||||||
})
|
|
||||||
|
|
||||||
move_scaleshift_to_preprocess(graph)
|
|
||||||
self.assertTrue(graph.graph.get('mean_values', None) is None)
|
|
||||||
|
|
||||||
(flag, resp) = compare_graphs(graph, graph_ref, 'scaleshift_1_data')
|
|
||||||
self.assertTrue(flag, resp)
|
|
||||||
|
|
||||||
def test_move_scaleshift_to_preprocess_3(self):
|
|
||||||
graph = build_graph(nodes_attributes,
|
|
||||||
[('placeholder_1', 'placeholder_1_data'),
|
|
||||||
('placeholder_1_data', 'scaleshift_1'),
|
|
||||||
('scaleshift_1', 'scaleshift_1_data'),
|
|
||||||
('scaleshift_1_w', 'scaleshift_1'),
|
|
||||||
('scaleshift_1_data', 'op_output'),
|
|
||||||
('placeholder_1_data', 'op_output_1')
|
|
||||||
],
|
|
||||||
{'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
|
|
||||||
'scaleshift_1_w': {'shape': np.array([3]), 'value': np.array((1, 2, 3))},
|
|
||||||
})
|
|
||||||
graph.graph['cmd_params'] = Namespace(reverse_input_channels=False)
|
|
||||||
del graph['placeholder_1']['placeholder_1_data'][0]['in']
|
|
||||||
del graph['scaleshift_1']['scaleshift_1_data'][0]['in']
|
|
||||||
|
|
||||||
graph_ref = build_graph(nodes_attributes,
|
|
||||||
[('placeholder_1', 'placeholder_1_data'),
|
|
||||||
('placeholder_1_data', 'scaleshift_1'),
|
|
||||||
('scaleshift_1', 'scaleshift_1_data'),
|
|
||||||
('scaleshift_1_w', 'scaleshift_1'),
|
|
||||||
('scaleshift_1_data', 'op_output'),
|
|
||||||
('placeholder_1_data', 'op_output_1')
|
|
||||||
],
|
|
||||||
{'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
|
|
||||||
'scaleshift_1_w': {'shape': np.array([3]), 'value': np.array((1, 2, 3))},
|
|
||||||
})
|
|
||||||
|
|
||||||
move_scaleshift_to_preprocess(graph)
|
|
||||||
self.assertTrue(graph.graph.get('mean_values', None) == None)
|
|
||||||
|
|
||||||
(flag, resp) = compare_graphs(graph, graph_ref, 'scaleshift_1_data')
|
|
||||||
self.assertTrue(flag, resp)
|
|
||||||
|
|
||||||
def test_move_scaleshift_to_preprocess_4(self):
|
|
||||||
graph = build_graph(nodes_attributes,
|
|
||||||
[('placeholder_1', 'placeholder_1_data'),
|
|
||||||
('placeholder_1_data', 'scaleshift_1'),
|
|
||||||
('scaleshift_1', 'scaleshift_1_data'),
|
|
||||||
('scaleshift_1_w', 'scaleshift_1'),
|
|
||||||
('scaleshift_1_b', 'scaleshift_1'),
|
|
||||||
('scaleshift_1_data', 'op_output')
|
|
||||||
],
|
|
||||||
{'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
|
|
||||||
'scaleshift_1_w': {'shape': np.array([3]), 'value': np.ones(3)},
|
|
||||||
'scaleshift_1_b': {'shape': np.array([3]), 'value': np.zeros(3)},
|
|
||||||
})
|
|
||||||
graph.graph['cmd_params'] = Namespace(reverse_input_channels=False)
|
|
||||||
del graph['placeholder_1']['placeholder_1_data'][0]['in']
|
|
||||||
del graph['scaleshift_1']['scaleshift_1_data'][0]['in']
|
|
||||||
|
|
||||||
graph_ref = build_graph(nodes_attributes,
|
|
||||||
[('placeholder_1', 'scaleshift_1_data'),
|
|
||||||
('scaleshift_1_data', 'op_output')
|
|
||||||
])
|
|
||||||
|
|
||||||
move_scaleshift_to_preprocess(graph)
|
|
||||||
self.assertTrue(graph.graph.get('mean_values', None) is None)
|
|
||||||
|
|
||||||
(flag, resp) = compare_graphs(graph, graph_ref, 'scaleshift_1_data')
|
|
||||||
self.assertTrue(flag, resp)
|
|
@ -283,7 +283,7 @@ def get_common_cli_parser(parser: argparse.ArgumentParser = None):
|
|||||||
action='store_true')
|
action='store_true')
|
||||||
common_group.add_argument('--move_to_preprocess',
|
common_group.add_argument('--move_to_preprocess',
|
||||||
help='Move mean values to IR preprocess section',
|
help='Move mean values to IR preprocess section',
|
||||||
action='store_true')
|
action=DeprecatedStoreTrue)
|
||||||
# we use CanonicalizeDirCheckExistenceAction instead of readable_dirs to handle empty strings
|
# we use CanonicalizeDirCheckExistenceAction instead of readable_dirs to handle empty strings
|
||||||
common_group.add_argument("--extensions",
|
common_group.add_argument("--extensions",
|
||||||
help="Directory or a comma separated list of directories with extensions. To disable all "
|
help="Directory or a comma separated list of directories with extensions. To disable all "
|
||||||
|
Loading…
Reference in New Issue
Block a user