Co-authored-by: Andrei Kochin <andrei.kochin@intel.com>
This commit is contained in:
@@ -716,7 +716,6 @@ openvino/tools/mo/main_tf.py
|
||||
openvino/tools/mo/middle/__init__.py
|
||||
openvino/tools/mo/middle/AddFakeQuantizeFuse.py
|
||||
openvino/tools/mo/middle/AddIsCyclicAttribute.py
|
||||
openvino/tools/mo/middle/AddMeanScaleValues.py
|
||||
openvino/tools/mo/middle/ApplyNHWCtoNCHWpermutation.py
|
||||
openvino/tools/mo/middle/ApplyPermutations.py
|
||||
openvino/tools/mo/middle/ArgOpsToTopK.py
|
||||
@@ -811,7 +810,6 @@ openvino/tools/mo/middle/reverse_tensor_iterator.py
|
||||
openvino/tools/mo/middle/ReverseTransposeNormalization.py
|
||||
openvino/tools/mo/middle/ReverseV2ToReverseSequence.py
|
||||
openvino/tools/mo/middle/RNNSequenceNormalizeToIE.py
|
||||
openvino/tools/mo/middle/ScaleInput.py
|
||||
openvino/tools/mo/middle/SharedWeightsDuplication.py
|
||||
openvino/tools/mo/middle/SliceConverter.py
|
||||
openvino/tools/mo/middle/SliceLikeToStridedSlice.py
|
||||
|
||||
@@ -46,72 +46,6 @@ class ReverseChannels(Op):
|
||||
PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
|
||||
|
||||
|
||||
class InsertReverseChannels(BackReplacementPattern):
|
||||
"""
|
||||
Searches for all suitable nodes with type=Parameter and inserts internal ReverseChannels op right after them
|
||||
TODO: we should provide user an ability to explicitly specify nodes for input channel reversing
|
||||
"""
|
||||
enabled = False
|
||||
|
||||
@staticmethod
|
||||
def get_suitable_channel_index(node: Node, shape):
|
||||
if len(shape) != 4:
|
||||
return None
|
||||
|
||||
guessed_layout = 'NCHW'
|
||||
if node.has_valid('rt_info'):
|
||||
rt_info = node.rt_info
|
||||
if rt_info.contains('old_api_map_order'):
|
||||
old_api_map_version = rt_info.get_attribute_version('old_api_map_order')
|
||||
old_api_map = rt_info.info['old_api_map_order', old_api_map_version]
|
||||
if 'inverse_order' in old_api_map.info:
|
||||
order = old_api_map.info['inverse_order']
|
||||
assert len(order) == len(guessed_layout)
|
||||
guessed_layout = np.array(list(guessed_layout))[order]
|
||||
guessed_layout = ''.join(guessed_layout)
|
||||
idx, has_layout = get_dim_from_layout(node, 'C')
|
||||
if not has_layout:
|
||||
idx = get_features_dim(guessed_layout, len(node.shape))
|
||||
if compatible_dims(shape[idx], 3):
|
||||
return idx
|
||||
else:
|
||||
return None
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
all_params = [(p.soft_get('name', p.id), p, list(p.out_port(0).data.get_shape()))
|
||||
for p in graph.get_op_nodes(type='Parameter')]
|
||||
suitable_params = []
|
||||
for name, p, shape in all_params:
|
||||
idx = self.get_suitable_channel_index(p, shape)
|
||||
if idx is not None:
|
||||
suitable_params.append((name, p, shape, idx))
|
||||
|
||||
log.debug('All network inputs: {}'.format({name: shape for name, _, shape in all_params}))
|
||||
log.debug('Will reverse input channels for: {}'.format({name: shape for name, _, shape, _ in suitable_params}))
|
||||
if not len(suitable_params):
|
||||
raise Error('Network has {} inputs overall, but none of them are suitable for input channels reversing.\n'
|
||||
'Suitable for input channel reversing inputs are 4-dimensional with 3 channels (in case of '
|
||||
'dynamic dimensions C channel must be provided in a layout for this input)\n'
|
||||
'All inputs: {}'.format(len(all_params), all_params))
|
||||
elif len(suitable_params) < len(all_params):
|
||||
log.error('Network has {} inputs overall, but only {} of them are suitable for input channels reversing.\n'
|
||||
'Suitable for input channel reversing inputs are 4-dimensional with 3 channels\nAll inputs: {}\n'
|
||||
'Suitable inputs {}'.format(len(all_params), len(suitable_params),
|
||||
{name: shape for name, _, shape in all_params},
|
||||
{name: shape for name, _, shape, _ in suitable_params}),
|
||||
extra={'is_warning': True})
|
||||
|
||||
for name, parameter, _, idx in suitable_params:
|
||||
reverse_index = int64_array(idx)
|
||||
|
||||
if parameter.out_port(0).disconnected():
|
||||
continue
|
||||
|
||||
reverse_channels = ReverseChannels(graph, {'name': name + '/reverse_input_channels',
|
||||
'axis': reverse_index}).create_node()
|
||||
parameter.out_port(0).get_connection().insert_node(reverse_channels, attributes_save_mode='source')
|
||||
|
||||
|
||||
class ReverseChannelsPropagationDown(BackReplacementPattern):
|
||||
"""
|
||||
Propagates ReverseChannels operations down through nodes that we have rules for
|
||||
@@ -544,7 +478,7 @@ class DecomposeReverseChannels(BackReplacementPattern):
|
||||
|
||||
class ApplyReverseChannels(BackReplacementPattern):
|
||||
"""
|
||||
Reverses input channels for suitable Parameter operation if requested by user
|
||||
Reverses input channels for suitable Parameter operation
|
||||
Optimizes channel reversing by fusion to Convolution weights if applicable
|
||||
"""
|
||||
enabled = True
|
||||
@@ -556,8 +490,6 @@ class ApplyReverseChannels(BackReplacementPattern):
|
||||
"""
|
||||
Following transformations should run in strict order, that is why we disabled them all and run here
|
||||
"""
|
||||
if graph.graph['cmd_params'].reverse_input_channels:
|
||||
InsertReverseChannels().find_and_replace_pattern(graph)
|
||||
ReverseChannelsPropagationDown().find_and_replace_pattern(graph)
|
||||
ReverseChannelsPropagationUp().find_and_replace_pattern(graph)
|
||||
DecomposeReverseChannels().find_and_replace_pattern(graph)
|
||||
|
||||
@@ -71,30 +71,9 @@ def apply_offline_transformations(input_model: str, argv: argparse.Namespace):
|
||||
|
||||
func = read_model(input_model + "_tmp.xml")
|
||||
|
||||
# TODO: use ngraph preprocessing (Mean/Scale/ReverseInputChannels) for legacy frontends
|
||||
reverse_input_channels = False
|
||||
if 'reverse_input_channels' in argv:
|
||||
reverse_input_channels = argv.reverse_input_channels
|
||||
argv.reverse_input_channels = False
|
||||
mean_scale_values = {}
|
||||
if 'mean_scale_values' in argv:
|
||||
mean_scale_values = argv.mean_scale_values
|
||||
argv.mean_scale_values = {}
|
||||
scale = None
|
||||
if 'scale' in argv:
|
||||
scale = argv.scale
|
||||
argv.scale = None
|
||||
|
||||
# Apply preprocessing for layouts only
|
||||
# Apply preprocessing (mean/scale/reverse_channels/convert_layout/etc)
|
||||
apply_preprocessing(ov_function=func, argv=argv)
|
||||
|
||||
if 'reverse_input_channels' in argv:
|
||||
argv.reverse_input_channels = reverse_input_channels
|
||||
if 'mean_scale_values' in argv:
|
||||
argv.mean_scale_values = mean_scale_values
|
||||
if 'scale' in argv:
|
||||
argv.scale = scale
|
||||
|
||||
apply_moc_transformations(func)
|
||||
|
||||
params_with_custom_types = create_params_with_custom_types(argv.packed_user_shapes)
|
||||
|
||||
@@ -1,151 +0,0 @@
|
||||
# Copyright (C) 2018-2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import logging as log
|
||||
|
||||
import numpy as np
|
||||
|
||||
from openvino.tools.mo.front.common.layout import get_dim_from_layout, get_features_dim
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import compatible_dims
|
||||
from openvino.tools.mo.front.extractor import get_node_id_with_ports
|
||||
from openvino.tools.mo.front.tf.graph_utils import create_op_with_const_inputs
|
||||
from openvino.tools.mo.graph.graph import Graph, Node
|
||||
from openvino.tools.mo.middle.replacement import MiddleReplacementPattern
|
||||
from openvino.tools.mo.ops.elementwise import Add, Mul
|
||||
from openvino.tools.mo.utils.cli_parser import get_node_name_with_port_from_input_value
|
||||
from openvino.tools.mo.utils.error import Error
|
||||
from openvino.tools.mo.utils.utils import refer_to_faq_msg
|
||||
|
||||
|
||||
class AddMeanScaleValues(MiddleReplacementPattern):
|
||||
enabled = True
|
||||
run_not_recursively = True
|
||||
|
||||
def run_after(self):
|
||||
return []
|
||||
|
||||
def run_before(self):
|
||||
from openvino.tools.mo.middle.pass_separator import MiddleStart
|
||||
return [MiddleStart]
|
||||
|
||||
@staticmethod
|
||||
def insert_pre_processing(graph: Graph, input_node: Node, node_mean_scale_values: np.array,
|
||||
preprocessing_name: str):
|
||||
assert preprocessing_name in ['scale', 'mean']
|
||||
if node_mean_scale_values.get(preprocessing_name) is None:
|
||||
return
|
||||
user_value = node_mean_scale_values[preprocessing_name]
|
||||
value = 1 / user_value if preprocessing_name == 'scale' else user_value * (-1)
|
||||
optimize_value = int(preprocessing_name == 'scale')
|
||||
op = Mul if preprocessing_name == 'scale' else Add
|
||||
|
||||
if all([x == optimize_value for x in value]):
|
||||
return
|
||||
assert input_node.has_valid('shape')
|
||||
in_name = input_node.soft_get('name', input_node.id)
|
||||
features_dim_idx, has_layout = get_dim_from_layout(input_node, 'C')
|
||||
if features_dim_idx is None:
|
||||
if has_layout:
|
||||
log.warning('Layout for input {} doesn\'t have channel ("C") dimension to apply {} preprocessing. '
|
||||
'Skipping this input.'.format(in_name, preprocessing_name))
|
||||
features_dim_idx = get_features_dim(graph.graph['layout'], len(input_node.shape))
|
||||
assert compatible_dims(value.size, input_node.shape[features_dim_idx]) or value.size == 1, \
|
||||
"Incompatible layout, please specify correct layout for the node"
|
||||
|
||||
shape = np.ones(len(input_node.shape), dtype=np.int64)
|
||||
shape[features_dim_idx] = value.size
|
||||
value = value.reshape(shape)
|
||||
|
||||
if input_node.op == 'Parameter' and input_node.has_and_set('data_type'):
|
||||
dtype = input_node.data_type
|
||||
if np.issubdtype(dtype, np.floating):
|
||||
value = value.astype(dtype)
|
||||
|
||||
name = in_name + '/' + preprocessing_name
|
||||
preprocessing = create_op_with_const_inputs(graph, op=op, port_value_dict={1: value}, op_attrs={'name': name})
|
||||
|
||||
if input_node.is_out_port_connected(0) and len(input_node.out_port(0).get_destinations()) == 1:
|
||||
# There are models with pattern Parameter(uint8) -> Convert(float).
|
||||
# Adding mean/scale leads to the following:
|
||||
# Parameter(uint8) -> Mean/Scale -> Convert(float) which is incorrect.
|
||||
# To fix this mean and scale preprocessing node is inserted after Convert(float) node.
|
||||
out_node = input_node.out_port(0).get_destination().node
|
||||
convert_type = out_node.soft_get('dst_type')
|
||||
if out_node.soft_get('type') == "Convert" and (convert_type in [np.float32, np.float16]):
|
||||
input_node = out_node
|
||||
if convert_type != value.dtype:
|
||||
new_value = value.astype(convert_type)
|
||||
const_node = preprocessing.in_port(1).get_connection().get_source().node
|
||||
const_node['value'] = new_value
|
||||
|
||||
for dst in input_node.out_port(0).get_destinations():
|
||||
if dst.node.soft_get('type') != 'ShapeOf':
|
||||
# After the insertion of additional operations model optimizer
|
||||
# should keep the link to the input layer. Parameter node in framework
|
||||
# should map to parameter node in IR.
|
||||
# For this reason 'fw_tensor_debug_info' should be kept in data node.
|
||||
dst.get_connection().set_source(preprocessing.out_port(0), "source")
|
||||
|
||||
input_node.out_port(0).connect(preprocessing.in_port(0))
|
||||
|
||||
@staticmethod
|
||||
def apply_scale(graph: Graph, input_node: Node, node_mean_scale_values: dict):
|
||||
AddMeanScaleValues.insert_pre_processing(graph, input_node, node_mean_scale_values, preprocessing_name='scale')
|
||||
|
||||
@staticmethod
|
||||
def apply_mean_value(graph: Graph, input_node: Node, node_mean_scale_values: dict):
|
||||
AddMeanScaleValues.insert_pre_processing(graph, input_node, node_mean_scale_values, preprocessing_name='mean')
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
values = graph.graph['cmd_params'].mean_scale_values
|
||||
input_nodes = graph.get_op_nodes(op='Parameter')
|
||||
|
||||
if not isinstance(values, dict):
|
||||
# The case when input names to apply mean/scales weren't specified
|
||||
if len(values) != len(input_nodes):
|
||||
raise Error('Numbers of inputs and mean/scale values do not match. ' + refer_to_faq_msg(61))
|
||||
|
||||
data = np.copy(values)
|
||||
values = {}
|
||||
for idx, node in enumerate(input_nodes):
|
||||
values.update(
|
||||
{
|
||||
node.soft_get('name', node.id): {
|
||||
'mean': data[idx][0],
|
||||
'scale': data[idx][1]
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
for node_name, node_mean_scale_values in values.items():
|
||||
node_id = None
|
||||
node_name = get_node_name_with_port_from_input_value(node_name)
|
||||
try:
|
||||
node_id, direction, port = get_node_id_with_ports(graph, node_name, skip_if_no_port=False)
|
||||
assert direction != 'out', 'Only input port can be specified for mean/scale application'
|
||||
except Error as e:
|
||||
log.warning('node_name {} is not found in graph'.format(node_name))
|
||||
if Node(graph, node_id) not in input_nodes:
|
||||
# if the user cutted-off input of the network then input node name specified in the --scale_values
|
||||
# or --mean_values doesn't correspond to a real input node generated by Model Optimizer. But
|
||||
# the information about initial input node name is stored in Placeholder's attribute 'initial_node_name'
|
||||
new_node_id = None
|
||||
for placeholder in input_nodes:
|
||||
try:
|
||||
placeholder_port = int(placeholder.id.split("_")[-1])
|
||||
except Exception as ex:
|
||||
log.debug('Can not get the port number from the node {}'.format(placeholder.id))
|
||||
log.debug('Port will be defined as None')
|
||||
port = None
|
||||
if placeholder.has('initial_node_name') and placeholder.initial_node_name == node_id and (
|
||||
port is None or placeholder_port == port):
|
||||
new_node_id = placeholder.id
|
||||
break
|
||||
if new_node_id is None:
|
||||
raise Error('Input with name {} wasn\'t found!'.format(node_name) +
|
||||
refer_to_faq_msg(83))
|
||||
node_id = new_node_id
|
||||
|
||||
input_node = Node(graph, node_id)
|
||||
AddMeanScaleValues.apply_scale(graph, input_node, node_mean_scale_values)
|
||||
AddMeanScaleValues.apply_mean_value(graph, input_node, node_mean_scale_values)
|
||||
@@ -16,8 +16,8 @@ class MiddleInputCut(MiddleReplacementPattern):
|
||||
return [PreMiddleStart]
|
||||
|
||||
def run_before(self):
|
||||
from openvino.tools.mo.middle.ScaleInput import ScaleInput
|
||||
return [ScaleInput]
|
||||
from openvino.tools.mo.middle.pass_separator import MiddleStart
|
||||
return [MiddleStart]
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
add_input_ops(graph, graph.graph['user_shapes'], False)
|
||||
|
||||
@@ -10,8 +10,8 @@ class RemoveIdentity(MiddleReplacementPattern):
|
||||
enabled = True
|
||||
|
||||
def run_after(self):
|
||||
from openvino.tools.mo.middle.AddMeanScaleValues import AddMeanScaleValues
|
||||
return [AddMeanScaleValues]
|
||||
from openvino.tools.mo.middle.InputCut import MiddleInputCut
|
||||
return [MiddleInputCut]
|
||||
|
||||
def run_before(self):
|
||||
from openvino.tools.mo.middle.pass_separator import MiddleStart
|
||||
@@ -31,8 +31,8 @@ class RemoveDropout(MiddleReplacementPattern):
|
||||
enabled = True
|
||||
|
||||
def run_after(self):
|
||||
from openvino.tools.mo.middle.AddMeanScaleValues import AddMeanScaleValues
|
||||
return [AddMeanScaleValues]
|
||||
from openvino.tools.mo.middle.InputCut import MiddleInputCut
|
||||
return [MiddleInputCut]
|
||||
|
||||
def run_before(self):
|
||||
from openvino.tools.mo.middle.pass_separator import MiddleStart
|
||||
@@ -53,8 +53,8 @@ class RemoveNodesWithZeroPhase(MiddleReplacementPattern):
|
||||
force_clean_up = True
|
||||
|
||||
def run_after(self):
|
||||
from openvino.tools.mo.middle.AddMeanScaleValues import AddMeanScaleValues
|
||||
return [AddMeanScaleValues]
|
||||
from openvino.tools.mo.middle.InputCut import MiddleInputCut
|
||||
return [MiddleInputCut]
|
||||
|
||||
def run_before(self):
|
||||
from openvino.tools.mo.middle.pass_separator import MiddleStart
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
# Copyright (C) 2018-2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from openvino.tools.mo.middle.AddMeanScaleValues import AddMeanScaleValues
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import mo_array
|
||||
from openvino.tools.mo.graph.graph import Graph
|
||||
from openvino.tools.mo.middle.replacement import MiddleReplacementPattern
|
||||
|
||||
|
||||
class ScaleInput(MiddleReplacementPattern):
|
||||
enabled = True
|
||||
|
||||
def run_after(self):
|
||||
from openvino.tools.mo.middle.pass_separator import PreMiddleStart
|
||||
return [PreMiddleStart]
|
||||
|
||||
def run_before(self):
|
||||
from openvino.tools.mo.middle.AddMeanScaleValues import AddMeanScaleValues
|
||||
return [AddMeanScaleValues]
|
||||
|
||||
def pattern(self):
|
||||
return dict(
|
||||
nodes=[
|
||||
('placeholder', dict(kind='op', op='Parameter')),
|
||||
('data', dict(kind='data'))],
|
||||
edges=[
|
||||
('placeholder', 'data'),
|
||||
],
|
||||
)
|
||||
|
||||
def replace_pattern(self, graph: Graph, match: dict):
|
||||
scale = graph.graph['cmd_params'].scale
|
||||
if scale is None or scale == 1:
|
||||
return
|
||||
assert (len(match['placeholder'].out_nodes()))
|
||||
|
||||
AddMeanScaleValues.apply_scale(graph, match['placeholder'], {'scale': mo_array([scale])})
|
||||
@@ -6,8 +6,7 @@ from argparse import Namespace
|
||||
|
||||
import numpy as np
|
||||
|
||||
from openvino.tools.mo.back.ReverseInputChannels import ReverseChannelsPropagationUp, ReverseChannelsPropagationDown, \
|
||||
InsertReverseChannels
|
||||
from openvino.tools.mo.back.ReverseInputChannels import ReverseChannelsPropagationUp, ReverseChannelsPropagationDown
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import int64_array, float32_array
|
||||
from openvino.tools.mo.graph.graph import Node, Graph
|
||||
from openvino.tools.mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
@@ -258,54 +257,3 @@ class ReverseInputChannelsTest(unittest.TestCase):
|
||||
reverse_channels = Node(graph, 'reverse_channels_down')
|
||||
self.assertTrue(reverse_channels.axis == 1)
|
||||
self.assertTrue(type(reverse_channels.axis) == np.ndarray)
|
||||
|
||||
def test_insert(self):
|
||||
graph = build_graph(get_nodes([1, 3, 10, 10]),
|
||||
[*connect('placeholder1', '0:mul'), *connect('placeholder2', '1:mul'),
|
||||
*connect('mul', 'result')], nodes_with_edges_only=True,
|
||||
cli=Namespace(reverse_input_channels=True))
|
||||
|
||||
InsertReverseChannels().find_and_replace_pattern(graph)
|
||||
graph_ref = build_graph(get_nodes([1, 3, 10, 10]),
|
||||
[*connect('placeholder1', 'reverse_channels'), *connect('reverse_channels', '0:mul'),
|
||||
*connect('placeholder2', '1:mul'), *connect('mul', 'result')])
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_insert_old_api_map(self):
|
||||
graph = build_graph(get_nodes([1, 10, 10, 3]),
|
||||
[*connect('placeholder1', '0:mul'), *connect('placeholder2', '1:mul'),
|
||||
*connect('mul', 'result')], nodes_with_edges_only=True,
|
||||
cli=Namespace(reverse_input_channels=True))
|
||||
|
||||
node = Node(graph, 'placeholder1')
|
||||
old_api_map = OldAPIMapOrder(version=0)
|
||||
node.rt_info.info[('old_api_map_order', old_api_map.get_version())] = old_api_map
|
||||
node.rt_info.info[('old_api_map_order', old_api_map.get_version())].old_api_transpose_parameter([0, 2, 3, 1])
|
||||
|
||||
InsertReverseChannels().find_and_replace_pattern(graph)
|
||||
graph_ref = build_graph(get_nodes([1, 10, 10, 3], 3),
|
||||
[*connect('placeholder1', 'reverse_channels'), *connect('reverse_channels', '0:mul'),
|
||||
*connect('placeholder2', '1:mul'), *connect('mul', 'result')])
|
||||
|
||||
node2 = Node(graph_ref, 'placeholder1')
|
||||
node2.rt_info = node.rt_info
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_insert_layout(self):
|
||||
graph = build_graph(get_nodes([1, 10, 10, 3]),
|
||||
[*connect('placeholder1', '0:mul'), *connect('placeholder2', '1:mul'),
|
||||
*connect('mul', 'result')], nodes_with_edges_only=True,
|
||||
cli=Namespace(reverse_input_channels=True,
|
||||
layout_values={
|
||||
'placeholder1': {'source_layout': 'nhwc', 'target_layout': None}}))
|
||||
|
||||
InsertReverseChannels().find_and_replace_pattern(graph)
|
||||
graph_ref = build_graph(get_nodes([1, 10, 10, 3], 3),
|
||||
[*connect('placeholder1', 'reverse_channels'), *connect('reverse_channels', '0:mul'),
|
||||
*connect('placeholder2', '1:mul'), *connect('mul', 'result')])
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
@@ -1,476 +0,0 @@
|
||||
# Copyright (C) 2018-2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from argparse import Namespace
|
||||
|
||||
import numpy as np
|
||||
|
||||
from openvino.tools.mo.middle.AddMeanScaleValues import AddMeanScaleValues
|
||||
from openvino.tools.mo.middle.ScaleInput import ScaleInput
|
||||
from openvino.tools.mo.graph.graph import Graph, Node
|
||||
from openvino.tools.mo.utils.cli_parser import get_mean_scale_dictionary, parse_tuple_pairs
|
||||
from openvino.tools.mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from unit_tests.mo.unit_test_with_mocked_telemetry import UnitTestWithMockedTelemetry
|
||||
from unit_tests.utils.graph import build_graph, regular_op_with_shaped_data, result, connect, connect_data, \
|
||||
valued_const_with_data
|
||||
|
||||
nodes = {
|
||||
**regular_op_with_shaped_data('parameter', [1, 3, 227, 227],
|
||||
{'type': 'Parameter',
|
||||
'op': 'Parameter',
|
||||
'shape': [1, 3, 227, 227],
|
||||
'data_type': np.float32}),
|
||||
**regular_op_with_shaped_data('parameter_2', [1, 3, 227, 227],
|
||||
{'type': 'Parameter',
|
||||
'op': 'Parameter',
|
||||
'shape': [1, 3, 227, 227],
|
||||
'data_type': np.float32}),
|
||||
|
||||
**regular_op_with_shaped_data('mul_scale', [1, 3, 227, 227], {'type': 'Multiply', 'op': 'Mul'}),
|
||||
**regular_op_with_shaped_data('add_mean', [1, 3, 227, 227], {'type': 'Add', 'op': 'Add'}),
|
||||
**regular_op_with_shaped_data('convert', [1, 3, 227, 227], {'type': 'Convert',
|
||||
'op': 'Convert',
|
||||
'dst_type': np.float32}),
|
||||
|
||||
**valued_const_with_data('scale', np.array([1. / 1., 1. / 2., 1. / 3.]).reshape((1, 3, 1, 1))),
|
||||
**valued_const_with_data('mean', np.array([-1., -2., -3.]).reshape((1, 3, 1, 1))),
|
||||
|
||||
**regular_op_with_shaped_data('shape_of', [4], {'type': 'ShapeOf', 'op': 'ShapeOf'}),
|
||||
**regular_op_with_shaped_data('op', [1, 3, 227, 227], {}),
|
||||
**result('result'),
|
||||
**result('result_2'),
|
||||
}
|
||||
|
||||
|
||||
class AddMeanScaleValuesTest(UnitTestWithMockedTelemetry):
|
||||
def check_graph_attrs(self, graph: Graph, graph_ref: Graph, parameter_node_names: list):
|
||||
for node in graph.get_op_nodes():
|
||||
if node.soft_get('name') in parameter_node_names:
|
||||
self.assertTrue(node.soft_get('type') == 'Parameter')
|
||||
out_node = node.out_node(0)
|
||||
out_node_ref = Node(graph_ref, node.id).out_node(0)
|
||||
self.assertTrue(out_node['fw_tensor_debug_info'] == out_node_ref['fw_tensor_debug_info'])
|
||||
else:
|
||||
if node.soft_get('type') == 'Const':
|
||||
value = node.out_port(0).data.get_value()
|
||||
self.assertTrue(value.dtype == np.float32)
|
||||
if 0 in node.out_nodes():
|
||||
out_node = node.out_node(0)
|
||||
self.assertFalse('fw_tensor_debug_info' in out_node)
|
||||
|
||||
def set_graph_attrs(self, graph: Graph, parameter_node_names: list):
|
||||
for node in graph.get_op_nodes():
|
||||
if node.soft_get('name') in parameter_node_names:
|
||||
self.assertTrue(node.soft_get('type') == 'Parameter')
|
||||
out_node = node.out_node(0)
|
||||
out_node['fw_tensor_debug_info'] = ['fw_name', 0]
|
||||
|
||||
def test_mean_values_with_data_name(self):
|
||||
graph_ref = build_graph(nodes, [
|
||||
*connect('parameter', '0:add_mean'),
|
||||
*connect('mean', '1:add_mean'),
|
||||
*connect('add_mean', 'result'),
|
||||
])
|
||||
|
||||
mean_values = parse_tuple_pairs('(1,2,3)')
|
||||
scale_values = parse_tuple_pairs('')
|
||||
mean_scale = get_mean_scale_dictionary(mean_values, scale_values, None)
|
||||
argv = Namespace(mean_scale_values=mean_scale)
|
||||
|
||||
graph = build_graph(nodes, [*connect('parameter', 'result')], nodes_with_edges_only=True, cli=argv)
|
||||
self.set_graph_attrs(graph, ['parameter'])
|
||||
self.set_graph_attrs(graph_ref, ['parameter'])
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
|
||||
AddMeanScaleValues().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
self.check_graph_attrs(graph, graph_ref, ['parameter'])
|
||||
|
||||
def test_mean_values_without_data_name(self):
|
||||
graph_ref = build_graph(nodes, [
|
||||
*connect('parameter', '0:add_mean'),
|
||||
*connect('mean', '1:add_mean'),
|
||||
*connect('add_mean', 'result'),
|
||||
], {'parameter': {'name': 'None'}})
|
||||
|
||||
mean_values = parse_tuple_pairs('(1,2,3)')
|
||||
scale_values = parse_tuple_pairs('')
|
||||
mean_scale = get_mean_scale_dictionary(mean_values, scale_values, None)
|
||||
argv = Namespace(mean_scale_values=mean_scale)
|
||||
|
||||
graph = build_graph(nodes, [*connect('parameter', 'result')], {'parameter': {'name': 'None'}},
|
||||
nodes_with_edges_only=True, cli=argv)
|
||||
self.set_graph_attrs(graph, ['None'])
|
||||
self.set_graph_attrs(graph_ref, ['None'])
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
|
||||
AddMeanScaleValues().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
self.check_graph_attrs(graph, graph_ref, ['None'])
|
||||
|
||||
def test_mean_values_explicit_and_optimized(self):
|
||||
graph_ref = build_graph(nodes, [
|
||||
*connect('parameter', '0:add_mean'),
|
||||
*connect('mean', '1:add_mean'),
|
||||
*connect('add_mean', 'result'),
|
||||
*connect('parameter_2', 'result_2'),
|
||||
])
|
||||
|
||||
argv = Namespace(mean_scale_values={'parameter': {'mean': np.array([1., 2., 3.])},
|
||||
'parameter_2': {'mean': np.array([0., 0., 0.])}})
|
||||
graph = build_graph(nodes, [*connect('parameter', 'result'), *connect('parameter_2', 'result_2')],
|
||||
nodes_with_edges_only=True, cli=argv)
|
||||
self.set_graph_attrs(graph, ['parameter', 'parameter_2'])
|
||||
self.set_graph_attrs(graph_ref, ['parameter', 'parameter_2'])
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
|
||||
AddMeanScaleValues().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result_2', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
self.check_graph_attrs(graph, graph_ref, ['parameter', 'parameter_2'])
|
||||
|
||||
def test_mean_values_explicit_and_scale_values_optimized(self):
|
||||
graph_ref = build_graph(nodes, [
|
||||
*connect('parameter', '0:add_mean'),
|
||||
*connect('mean', '1:add_mean'),
|
||||
*connect('add_mean', 'result'),
|
||||
])
|
||||
|
||||
argv = Namespace(mean_scale_values={'parameter': {'scale': np.array([1.]), 'mean': np.array([1., 2., 3.])}})
|
||||
graph = build_graph(nodes, [*connect('parameter', 'result')], nodes_with_edges_only=True, cli=argv)
|
||||
self.set_graph_attrs(graph, ['parameter'])
|
||||
self.set_graph_attrs(graph_ref, ['parameter'])
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
|
||||
AddMeanScaleValues().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
self.check_graph_attrs(graph, graph_ref, ['parameter'])
|
||||
|
||||
def test_mean_values_optimized_and_scale_values_explicit(self):
|
||||
graph_ref = build_graph(nodes, [
|
||||
*connect('parameter', '0:mul_scale'),
|
||||
*connect('scale', '1:mul_scale'),
|
||||
*connect('mul_scale', 'result'),
|
||||
])
|
||||
|
||||
argv = Namespace(
|
||||
mean_scale_values={'parameter': {'scale': np.array([1., 2., 3.]), 'mean': np.array([0., 0., 0.])}})
|
||||
graph = build_graph(nodes, [*connect('parameter', 'result')], nodes_with_edges_only=True, cli=argv)
|
||||
self.set_graph_attrs(graph, ['parameter'])
|
||||
self.set_graph_attrs(graph_ref, ['parameter'])
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
|
||||
AddMeanScaleValues().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
self.check_graph_attrs(graph, graph_ref, ['parameter'])
|
||||
|
||||
def test_mean_values_explicit_and_scale_values_explicit(self):
|
||||
graph_ref = build_graph(nodes, [
|
||||
*connect('parameter', '0:add_mean'),
|
||||
*connect('mean', '1:add_mean'),
|
||||
*connect('add_mean', '0:mul_scale'),
|
||||
*connect('scale', '1:mul_scale'),
|
||||
*connect('mul_scale', 'result'),
|
||||
])
|
||||
|
||||
argv = Namespace(mean_scale_values=[[np.array([1., 2., 3.]), np.array([1., 2., 3.])]])
|
||||
graph = build_graph(nodes, [*connect('parameter', 'result')],
|
||||
nodes_with_edges_only=True, cli=argv)
|
||||
self.set_graph_attrs(graph, ['parameter'])
|
||||
self.set_graph_attrs(graph_ref, ['parameter'])
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
|
||||
AddMeanScaleValues().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
self.check_graph_attrs(graph, graph_ref, ['parameter'])
|
||||
|
||||
def test_mean_values_explicit_and_scale_values_explicit_on_cutted_graph(self):
|
||||
"""
|
||||
Test case when user cutted start of the network and specified mean/scale value to the new input node 'node_3'.
|
||||
"""
|
||||
graph_ref = build_graph(nodes, [
|
||||
*connect('parameter', '0:add_mean'),
|
||||
*connect('mean', '1:add_mean'),
|
||||
*connect('add_mean', 'result'),
|
||||
|
||||
*connect('parameter_2', '0:mul_scale'),
|
||||
*connect('scale', '1:mul_scale'),
|
||||
*connect('mul_scale', 'op'),
|
||||
*connect('op', 'result_2'),
|
||||
])
|
||||
|
||||
argv = Namespace(
|
||||
mean_scale_values={'parameter': {'mean': np.array([1, 2, 3])}, 'op': {'scale': np.array([1, 2, 3])}})
|
||||
graph = build_graph(
|
||||
nodes, [*connect('parameter', 'result'), *connect('parameter_2', 'op'), *connect('op', 'result_2')],
|
||||
{'parameter_2': {'initial_node_name': 'op'}}, nodes_with_edges_only=True, cli=argv)
|
||||
self.set_graph_attrs(graph, ['parameter', 'parameter_2'])
|
||||
self.set_graph_attrs(graph_ref, ['parameter', 'parameter_2'])
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
AddMeanScaleValues().find_and_replace_pattern(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result_2', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
self.check_graph_attrs(graph, graph_ref, ['parameter', 'parameter_2'])
|
||||
|
||||
def test_mean_values_explicit_and_scale_values_explicit_with_shape_of(self):
|
||||
graph_ref = build_graph(nodes,
|
||||
[
|
||||
*connect('parameter', '0:add_mean'),
|
||||
*connect('mean', '1:add_mean'),
|
||||
*connect('add_mean', '0:mul_scale'),
|
||||
*connect('scale', '1:mul_scale'),
|
||||
*connect('mul_scale', 'result'),
|
||||
*connect_data('parameter', 'shape_of'),
|
||||
*connect('shape_of', 'result_2'),
|
||||
],
|
||||
nodes_with_edges_only=True)
|
||||
|
||||
argv = Namespace(
|
||||
mean_scale_values={'parameter': {'mean': np.array([1, 2, 3]), 'scale': np.array([1, 2, 3])}})
|
||||
graph = build_graph(nodes,
|
||||
[
|
||||
*connect('parameter', 'result'),
|
||||
*connect_data('parameter', 'shape_of'),
|
||||
*connect('shape_of', 'result_2'),
|
||||
],
|
||||
nodes_with_edges_only=True, cli=argv)
|
||||
self.set_graph_attrs(graph, ['parameter'])
|
||||
self.set_graph_attrs(graph_ref, ['parameter'])
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
|
||||
AddMeanScaleValues().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result_2', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
self.check_graph_attrs(graph, graph_ref, ['parameter'])
|
||||
|
||||
def test_mean_values_with_colon_in_node_name(self):
|
||||
graph_ref = build_graph(nodes, [
|
||||
*connect('parameter', '0:add_mean'),
|
||||
*connect('mean', '1:add_mean'),
|
||||
*connect('add_mean', 'result'),
|
||||
])
|
||||
|
||||
argv = Namespace(mean_scale_values={'param:0': {'scale': np.array([1.]), 'mean': np.array([1., 2., 3.])}})
|
||||
graph = build_graph(nodes, [*connect('parameter', 'result')], {'parameter': {'name': 'param:0'}},
|
||||
nodes_with_edges_only=True, cli=argv)
|
||||
self.set_graph_attrs(graph, ['parameter'])
|
||||
self.set_graph_attrs(graph_ref, ['parameter'])
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
|
||||
AddMeanScaleValues().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_mean_values_with_colon_in_node_name_and_port(self):
|
||||
graph_ref = build_graph(nodes, [
|
||||
*connect('parameter', '0:add_mean'),
|
||||
*connect('mean', '1:add_mean'),
|
||||
*connect('add_mean', 'result'),
|
||||
])
|
||||
|
||||
argv = Namespace(mean_scale_values={'0:param:0': {'scale': np.array([1.]), 'mean': np.array([1., 2., 3.])}})
|
||||
graph = build_graph(nodes, [*connect('parameter', 'result')],
|
||||
{'parameter': {'name': 'param:0', 'id': 'param:0/placeholder_0',
|
||||
'initial_node_name': 'param:0'}},
|
||||
nodes_with_edges_only=True, cli=argv)
|
||||
self.set_graph_attrs(graph, ['parameter'])
|
||||
self.set_graph_attrs(graph_ref, ['parameter'])
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
|
||||
AddMeanScaleValues().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_scale_input(self):
|
||||
graph_ref = build_graph(nodes, [
|
||||
*connect('parameter', '0:mul_scale'),
|
||||
*connect('scale', '1:mul_scale'),
|
||||
*connect('mul_scale', 'result'),
|
||||
], {'scale': {'shape': [1, 1, 1, 1], 'value': np.array(1 / 255)},
|
||||
'scale_d': {'shape': [1, 1, 1, 1], 'value': np.array(1 / 255)}})
|
||||
|
||||
graph = build_graph(nodes, connect('parameter', 'result'), nodes_with_edges_only=True, cli=Namespace(scale=255))
|
||||
self.set_graph_attrs(graph, ['parameter'])
|
||||
self.set_graph_attrs(graph_ref, ['parameter'])
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
|
||||
ScaleInput().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
|
||||
self.assertTrue(flag, resp)
|
||||
self.check_graph_attrs(graph, graph_ref, ['parameter'])
|
||||
|
||||
def test_scale_input_2(self):
|
||||
graph_ref = build_graph(nodes, connect('parameter', 'result'), nodes_with_edges_only=True)
|
||||
graph = build_graph(nodes, connect('parameter', 'result'), nodes_with_edges_only=True, cli=Namespace(scale=1))
|
||||
self.set_graph_attrs(graph, ['parameter'])
|
||||
self.set_graph_attrs(graph_ref, ['parameter'])
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
|
||||
ScaleInput().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
|
||||
self.assertTrue(flag, resp)
|
||||
self.check_graph_attrs(graph, graph_ref, ['parameter'])
|
||||
|
||||
def test_debug_info_absence(self):
|
||||
graph_ref = build_graph(nodes, [
|
||||
*connect('parameter', '0:add_mean'),
|
||||
*connect('mean', '1:add_mean'),
|
||||
*connect('add_mean', '0:mul_scale'),
|
||||
*connect('scale', '1:mul_scale'),
|
||||
*connect('mul_scale', 'result'),
|
||||
])
|
||||
|
||||
argv = Namespace(mean_scale_values=[[np.array([1., 2., 3.]), np.array([1., 2., 3.])]])
|
||||
graph = build_graph(nodes, [*connect('parameter', 'result')],
|
||||
nodes_with_edges_only=True, cli=argv)
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
|
||||
AddMeanScaleValues().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
self.check_graph_attrs(graph, graph_ref, [])
|
||||
|
||||
def test_insert_add_mean_scale_after_convert(self):
|
||||
graph_ref = build_graph(nodes, [
|
||||
*connect('parameter', 'convert'),
|
||||
*connect('convert', '0:add_mean'),
|
||||
*connect('mean', '1:add_mean'),
|
||||
*connect('add_mean', '0:mul_scale'),
|
||||
*connect('scale', '1:mul_scale'),
|
||||
*connect('mul_scale', 'result'),
|
||||
])
|
||||
|
||||
argv = Namespace(mean_scale_values=[[np.array([1., 2., 3.]), np.array([1., 2., 3.])]])
|
||||
graph = build_graph(nodes, [*connect('parameter', 'convert'), *connect('convert', 'result')],
|
||||
nodes_with_edges_only=True, cli=argv)
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
|
||||
AddMeanScaleValues().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
self.check_graph_attrs(graph, graph_ref, [])
|
||||
|
||||
def test_insert_add_mean_scale_after_convert_different_type(self):
|
||||
graph_ref = build_graph(nodes, [
|
||||
*connect('parameter', 'convert'),
|
||||
*connect('convert', '0:add_mean'),
|
||||
*connect('mean', '1:add_mean'),
|
||||
*connect('add_mean', '0:mul_scale'),
|
||||
*connect('scale', '1:mul_scale'),
|
||||
*connect('mul_scale', 'result'),
|
||||
])
|
||||
|
||||
argv = Namespace(mean_scale_values=[[np.array([1., 2., 3.]),
|
||||
np.array([1., 2., 3.])]])
|
||||
graph = build_graph(nodes, [*connect('parameter', 'convert'), *connect('convert', 'result')],
|
||||
nodes_with_edges_only=True, cli=argv)
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
|
||||
AddMeanScaleValues().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
self.check_graph_attrs(graph, graph_ref, [])
|
||||
add_node = graph.get_op_nodes(type="Add")[0]
|
||||
self.assertTrue(add_node.in_port(1).get_connection().get_source().node['value'].dtype == np.float32)
|
||||
|
||||
def test_mean_values_explicit_and_optimized_layout(self):
|
||||
graph_ref = build_graph(nodes, [
|
||||
*connect('parameter', '0:add_mean'),
|
||||
*connect('mean', '1:add_mean'),
|
||||
*connect('add_mean', 'result'),
|
||||
*connect('parameter_2', 'result_2'),
|
||||
])
|
||||
|
||||
argv = Namespace(mean_scale_values={'parameter': {'mean': np.array([1., 2., 3.])},
|
||||
'parameter_2': {'mean': np.array([0., 0., 0.])}},
|
||||
layout_values={'parameter': {'source_layout': 'nchw', 'target_layout': None},
|
||||
'parameter_2': {'source_layout': 'nchw', 'target_layout': None}}
|
||||
)
|
||||
graph = build_graph(nodes, [*connect('parameter', 'result'), *connect('parameter_2', 'result_2')],
|
||||
nodes_with_edges_only=True, cli=argv)
|
||||
self.set_graph_attrs(graph, ['parameter', 'parameter_2'])
|
||||
self.set_graph_attrs(graph_ref, ['parameter', 'parameter_2'])
|
||||
graph.graph['layout'] = 'NHWC'
|
||||
|
||||
AddMeanScaleValues().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result_2', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
self.check_graph_attrs(graph, graph_ref, ['parameter', 'parameter_2'])
|
||||
|
||||
def test_mean_values_explicit_and_scale_values_optimized_layout(self):
|
||||
graph_ref = build_graph(nodes, [
|
||||
*connect('parameter', '0:add_mean'),
|
||||
*connect('mean', '1:add_mean'),
|
||||
*connect('add_mean', 'result'),
|
||||
])
|
||||
|
||||
argv = Namespace(mean_scale_values={'parameter': {'scale': np.array([1.]), 'mean': np.array([1., 2., 3.])}},
|
||||
layout_values={'': {'source_layout': 'nchw', 'target_layout': None}}
|
||||
)
|
||||
graph = build_graph(nodes, [*connect('parameter', 'result')], nodes_with_edges_only=True, cli=argv)
|
||||
self.set_graph_attrs(graph, ['parameter'])
|
||||
self.set_graph_attrs(graph_ref, ['parameter'])
|
||||
graph.graph['layout'] = 'NHWC'
|
||||
|
||||
AddMeanScaleValues().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
self.check_graph_attrs(graph, graph_ref, ['parameter'])
|
||||
|
||||
def test_mean_values_optimized_and_scale_values_explicit_layout(self):
|
||||
graph_ref = build_graph(nodes, [
|
||||
*connect('parameter', '0:mul_scale'),
|
||||
*connect('scale', '1:mul_scale'),
|
||||
*connect('mul_scale', 'result'),
|
||||
])
|
||||
|
||||
argv = Namespace(
|
||||
mean_scale_values={'parameter': {'scale': np.array([1., 2., 3.]), 'mean': np.array([0., 0., 0.])}},
|
||||
layout_values={'': {'source_layout': 'nchw', 'target_layout': None}}
|
||||
)
|
||||
graph = build_graph(nodes, [*connect('parameter', 'result')], nodes_with_edges_only=True, cli=argv)
|
||||
self.set_graph_attrs(graph, ['parameter'])
|
||||
self.set_graph_attrs(graph_ref, ['parameter'])
|
||||
graph.graph['layout'] = 'NHWC'
|
||||
|
||||
AddMeanScaleValues().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
self.check_graph_attrs(graph, graph_ref, ['parameter'])
|
||||
|
||||
def test_mean_values_explicit_and_scale_values_explicit_layout(self):
|
||||
graph_ref = build_graph(nodes, [
|
||||
*connect('parameter', '0:add_mean'),
|
||||
*connect('mean', '1:add_mean'),
|
||||
*connect('add_mean', '0:mul_scale'),
|
||||
*connect('scale', '1:mul_scale'),
|
||||
*connect('mul_scale', 'result'),
|
||||
])
|
||||
|
||||
argv = Namespace(mean_scale_values=[[np.array([1., 2., 3.]), np.array([1., 2., 3.])]],
|
||||
layout_values={'': {'source_layout': 'nchw', 'target_layout': None}}
|
||||
)
|
||||
graph = build_graph(nodes, [*connect('parameter', 'result')],
|
||||
nodes_with_edges_only=True, cli=argv)
|
||||
self.set_graph_attrs(graph, ['parameter'])
|
||||
self.set_graph_attrs(graph_ref, ['parameter'])
|
||||
graph.graph['layout'] = 'NHWC'
|
||||
|
||||
AddMeanScaleValues().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
self.check_graph_attrs(graph, graph_ref, ['parameter'])
|
||||
Reference in New Issue
Block a user