Add layout commands in MO (#8829)
* Add layout support in MO * Apply review feedback
This commit is contained in:
parent
2ff4ef2e4f
commit
51947eeb3d
@ -37,10 +37,57 @@ def compress_model(func: object):
|
||||
from openvino.offline_transformations_pybind import compress_model_transformation # pylint: disable=import-error,no-name-in-module
|
||||
compress_model_transformation(func)
|
||||
|
||||
def apply_offline_transformations(input_model: str, framework: str, transforms: list, compress_fp16=False):
|
||||
|
||||
def add_layouts(ov_function, argv: argparse.Namespace):
|
||||
from openvino.preprocess import PrePostProcessor # pylint: disable=no-name-in-module,import-error
|
||||
from openvino.runtime import Layout # pylint: disable=import-error,no-name-in-module
|
||||
|
||||
prep = PrePostProcessor(ov_function)
|
||||
layout_values = argv.layout_values
|
||||
if '' in layout_values:
|
||||
if len(ov_function.inputs) == 1:
|
||||
layout_values = {
|
||||
list(ov_function.input().get_tensor().get_names())[0]: {
|
||||
'source_layout': layout_values[''].get('source_layout'),
|
||||
'target_layout': layout_values[''].get('target_layout')
|
||||
}
|
||||
}
|
||||
else:
|
||||
input_names = [list(ov_input.get_tensor().get_names())[0] for ov_input in ov_function.inputs]
|
||||
raise Error('Layout without name can be specified for models with only one input, '
|
||||
'but provided model has {} inputs: \'{}\'. '
|
||||
'Please specify explicitly input/output name for --layout option'
|
||||
.format(len(input_names), input_names))
|
||||
|
||||
set_layout_names = set(layout_values.keys())
|
||||
for idx, ov_input in enumerate(ov_function.inputs):
|
||||
found = set.intersection(set(ov_input.get_tensor().get_names()), set_layout_names)
|
||||
assert len(found) <= 1, 'More then one name point to the same node'
|
||||
if len(found) == 1:
|
||||
node_name = list(found)[0]
|
||||
found_layout = layout_values[node_name]
|
||||
if found_layout['source_layout']:
|
||||
prep.input(node_name).network().set_layout(Layout(found_layout['source_layout']))
|
||||
if found_layout['target_layout']:
|
||||
prep.input(node_name).tensor().set_layout(Layout(found_layout['target_layout']))
|
||||
|
||||
for idx, ov_output in enumerate(ov_function.outputs):
|
||||
found = set.intersection(set(ov_output.get_tensor().get_names()), set_layout_names)
|
||||
assert len(found) <= 1, 'More then one name point to the same node'
|
||||
if len(found) == 1:
|
||||
node_name = list(found)[0]
|
||||
found_layout = layout_values[node_name]
|
||||
if found_layout['source_layout']:
|
||||
prep.output(node_name).network().set_layout(Layout(found_layout['source_layout']))
|
||||
if found_layout['target_layout']:
|
||||
prep.output(node_name).tensor().set_layout(Layout(found_layout['target_layout']))
|
||||
prep.build()
|
||||
|
||||
|
||||
def apply_offline_transformations(input_model: str, argv: argparse.Namespace):
|
||||
# This variable is only needed by GenerateMappingFile transformation
|
||||
# to produce correct mapping
|
||||
extract_names = framework in ['tf', 'mxnet', 'kaldi']
|
||||
extract_names = argv.framework in ['tf', 'mxnet', 'kaldi']
|
||||
|
||||
from openvino.offline_transformations_pybind import generate_mapping_file, serialize # pylint: disable=import-error,no-name-in-module
|
||||
from openvino.frontend import FrontEndManager, FrontEnd # pylint: disable=no-name-in-module,import-error
|
||||
@ -57,24 +104,14 @@ def apply_offline_transformations(input_model: str, framework: str, transforms:
|
||||
|
||||
func = read_model(input_model + "_tmp.xml")
|
||||
|
||||
apply_user_transformations(func, transforms)
|
||||
add_layouts(func, argv) # TODO: replace with preprocessing
|
||||
|
||||
apply_user_transformations(func, parse_transform(argv.transform))
|
||||
apply_moc_transformations(func)
|
||||
|
||||
if compress_fp16:
|
||||
if "compress_fp16" in argv and argv.compress_fp16:
|
||||
compress_model(func)
|
||||
|
||||
serialize(func, str(input_model + ".xml").encode('utf-8'), (input_model + ".bin").encode('utf-8'))
|
||||
path_to_mapping = input_model + ".mapping"
|
||||
generate_mapping_file(func, path_to_mapping.encode('utf-8'), extract_names)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--input_model")
|
||||
parser.add_argument("--framework")
|
||||
parser.add_argument("--transform")
|
||||
parser.add_argument("--compress_fp16", action='store_true')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
apply_offline_transformations(args.input_model, args.framework, parse_transform(args.transform), args.compress_fp16)
|
||||
|
@ -6,7 +6,6 @@ import datetime
|
||||
import logging as log
|
||||
import os
|
||||
import platform
|
||||
import subprocess
|
||||
import sys
|
||||
import traceback
|
||||
from collections import OrderedDict
|
||||
@ -28,10 +27,10 @@ from openvino.tools.mo.middle.pattern_match import for_graph_and_each_sub_graph_
|
||||
from openvino.tools.mo.pipeline.common import prepare_emit_ir, get_ir_version
|
||||
from openvino.tools.mo.pipeline.unified import unified_pipeline
|
||||
from openvino.tools.mo.utils import import_extensions
|
||||
from openvino.tools.mo.utils.cli_parser import get_placeholder_shapes, get_tuple_values, get_model_name, \
|
||||
get_common_cli_options, get_caffe_cli_options, get_tf_cli_options, get_mxnet_cli_options, get_kaldi_cli_options, \
|
||||
get_onnx_cli_options, get_mean_scale_dictionary, parse_tuple_pairs, get_freeze_placeholder_values, get_meta_info, \
|
||||
parse_transform, check_available_transforms
|
||||
from openvino.tools.mo.utils.cli_parser import check_available_transforms, get_caffe_cli_options, \
|
||||
get_common_cli_options, get_freeze_placeholder_values, get_kaldi_cli_options, get_layout_values, \
|
||||
get_mean_scale_dictionary, get_meta_info, get_model_name, get_mxnet_cli_options, get_onnx_cli_options, \
|
||||
get_placeholder_shapes, get_tf_cli_options, get_tuple_values, parse_transform, parse_tuple_pairs
|
||||
from openvino.tools.mo.utils.error import Error, FrameworkError
|
||||
from openvino.tools.mo.utils.find_ie_version import find_ie_version
|
||||
from openvino.tools.mo.utils.get_ov_update_message import get_ov_update_message
|
||||
@ -268,6 +267,7 @@ def arguments_post_parsing(argv: argparse.Namespace):
|
||||
scale_values = parse_tuple_pairs(argv.scale_values)
|
||||
mean_scale = get_mean_scale_dictionary(mean_values, scale_values, argv.input)
|
||||
argv.mean_scale_values = mean_scale
|
||||
argv.layout_values = get_layout_values(argv.layout, argv.source_layout, argv.target_layout)
|
||||
|
||||
if not os.path.exists(argv.output_dir):
|
||||
try:
|
||||
@ -360,22 +360,14 @@ def emit_ir(graph: Graph, argv: argparse.Namespace):
|
||||
orig_model_name = os.path.normpath(os.path.join(output_dir, argv.model_name))
|
||||
|
||||
return_code = "not executed"
|
||||
# This try-except is additional reinsurance that the IE
|
||||
# dependency search does not break the MO pipeline
|
||||
try:
|
||||
if not argv.legacy_ir_generation:
|
||||
path_to_offline_transformations = os.path.join(os.path.realpath(os.path.dirname(__file__)), 'back',
|
||||
'offline_transformations.py')
|
||||
cmd = [sys.executable, path_to_offline_transformations,
|
||||
"--input_model", orig_model_name,
|
||||
"--framework", argv.framework,
|
||||
"--transform", argv.transform]
|
||||
from openvino.tools.mo.back.offline_transformations import apply_offline_transformations
|
||||
apply_offline_transformations(orig_model_name, argv)
|
||||
if "compress_fp16" in argv and argv.compress_fp16:
|
||||
cmd += ["--compress_fp16"]
|
||||
# restore data_type cmd parameter
|
||||
argv.data_type = 'FP16'
|
||||
status = subprocess.run(cmd, env=os.environ)
|
||||
return_code = status.returncode
|
||||
return_code = 0
|
||||
except Exception as e:
|
||||
return_code = "failed"
|
||||
log.error(e)
|
||||
|
@ -309,6 +309,28 @@ def get_common_cli_parser(parser: argparse.ArgumentParser = None):
|
||||
'The exact meaning and order ' +
|
||||
'of channels depend on how the original model was trained.',
|
||||
default=())
|
||||
common_group.add_argument('--source_layout',
|
||||
help='Layout of the input or output of the model in the framework. Layout can'
|
||||
' be specified in the short form, e.g. nhwc, or in complex form, e.g. [n,h,w,c].'
|
||||
' Example for many names: '
|
||||
'in_name1([n,h,w,c]),in_name2(nc),out_name1(n),out_name2(nc). Layout can be '
|
||||
'partially defined, "?" can be used to specify undefined layout for one dimension, '
|
||||
'"..." can be used to specify undefined layout for multiple dimensions, for example '
|
||||
'?c??, nc..., n...c, etc.',
|
||||
default=())
|
||||
common_group.add_argument('--target_layout',
|
||||
help='Same as --source_layout, but specifies target layout that will be in the model '
|
||||
'after processing by ModelOptimizer.',
|
||||
default=())
|
||||
common_group.add_argument('--layout',
|
||||
help='Combination of --source_layout and --target_layout. Can\'t be used with either of '
|
||||
'them. If model has one input it is sufficient to specify layout of this input, for'
|
||||
' example --layout nhwc. To specify layouts of many tensors, names must be provided,'
|
||||
' for example: --layout name1(nchw),name2(nc). It is possible to instruct '
|
||||
'ModelOptimizer to change layout, for example: '
|
||||
'--layout name1(nhwc->nchw),name2(cn->nc). Also "*" in long layout form can be used'
|
||||
' to fuse dimensions, for example [n,c,...]->[n*c,…].',
|
||||
default=())
|
||||
# TODO: isn't it a weights precision type
|
||||
common_group.add_argument('--data_type',
|
||||
help='Data type for all intermediate tensors and weights. ' +
|
||||
@ -417,6 +439,9 @@ def get_common_cli_options(model_name):
|
||||
d['input'] = ['- Input layers', lambda x: x if x else 'Not specified, inherited from the model']
|
||||
d['output'] = ['- Output layers', lambda x: x if x else 'Not specified, inherited from the model']
|
||||
d['input_shape'] = ['- Input shapes', lambda x: x if x else 'Not specified, inherited from the model']
|
||||
d['source_layout'] = ['- Source layout', lambda x: x if x else 'Not specified']
|
||||
d['target_layout'] = ['- Target layout', lambda x: x if x else 'Not specified']
|
||||
d['layout'] = ['- Layout', lambda x: x if x else 'Not specified']
|
||||
d['mean_values'] = ['- Mean values', lambda x: x if x else 'Not specified']
|
||||
d['scale_values'] = ['- Scale values', lambda x: x if x else 'Not specified']
|
||||
d['scale'] = ['- Scale factor', lambda x: x if x else 'Not specified']
|
||||
@ -835,6 +860,137 @@ def parse_input_value(input_value: str):
|
||||
return node_name, shape, value, data_type
|
||||
|
||||
|
||||
def split_str_avoiding_square_brackets(s: str) -> list:
|
||||
"""
|
||||
Splits a string by comma, but skips commas inside square brackets.
|
||||
:param s: string to split
|
||||
:return: list of strings split by comma
|
||||
"""
|
||||
res = list()
|
||||
skipping = 0
|
||||
last_idx = 0
|
||||
for i, c in enumerate(s):
|
||||
if c == '[':
|
||||
skipping += 1
|
||||
elif c == ']':
|
||||
skipping -= 1
|
||||
elif c == ',' and skipping == 0:
|
||||
res.append(s[last_idx:i])
|
||||
last_idx = i + 1
|
||||
res.append(s[last_idx:])
|
||||
return res
|
||||
|
||||
|
||||
def split_layouts_by_arrow(s: str) -> tuple:
|
||||
"""
|
||||
Splits a layout string by first arrow (->).
|
||||
:param s: string to split
|
||||
:return: tuple containing source and target layouts
|
||||
"""
|
||||
arrow = s.find('->')
|
||||
if arrow != -1:
|
||||
source_layout = s[:arrow]
|
||||
target_layout = s[arrow + 2:]
|
||||
if source_layout == '':
|
||||
source_layout = None
|
||||
if target_layout == '':
|
||||
target_layout = None
|
||||
return source_layout, target_layout
|
||||
else:
|
||||
return s, None
|
||||
|
||||
|
||||
def validate_layout(layout: str):
|
||||
"""
|
||||
Checks if layout is of valid format.
|
||||
:param layout: string containing layout
|
||||
:raises: if layout is incorrect
|
||||
"""
|
||||
valid_layout_re = re.compile(r'\[?[^\[\]\(\)\s]*\]?')
|
||||
if layout and not valid_layout_re.fullmatch(layout):
|
||||
raise Error('Invalid layout parsed: {}'.format(layout))
|
||||
|
||||
|
||||
def write_found_layout(name: str, found_layout: str, parsed: dict, dest: str = None):
|
||||
"""
|
||||
Writes found layout data to the 'parsed' dict.
|
||||
:param name: name of the node to add layout
|
||||
:param found_layout: string containing layout for the node
|
||||
:param parsed: dict where result will be stored
|
||||
:param dest: type of the command line:
|
||||
* 'source' is --source_layout
|
||||
* 'target' is --target_layout
|
||||
* None is --layout
|
||||
"""
|
||||
s_layout = None
|
||||
t_layout = None
|
||||
if name in parsed:
|
||||
s_layout = parsed[name]['source_layout']
|
||||
t_layout = parsed[name]['target_layout']
|
||||
if dest == 'source':
|
||||
s_layout = found_layout
|
||||
elif dest == 'target':
|
||||
t_layout = found_layout
|
||||
else:
|
||||
s_layout, t_layout = split_layouts_by_arrow(found_layout)
|
||||
validate_layout(s_layout)
|
||||
validate_layout(t_layout)
|
||||
parsed[name] = {'source_layout': s_layout, 'target_layout': t_layout}
|
||||
|
||||
|
||||
def parse_layouts_by_destination(s: str, parsed: dict, dest: str = None) -> None:
|
||||
"""
|
||||
Parses layout command line to get all names and layouts from it. Adds all found data in the 'parsed' dict.
|
||||
:param s: string to parse
|
||||
:param parsed: dict where result will be stored
|
||||
:param dest: type of the command line:
|
||||
* 'source' is --source_layout
|
||||
* 'target' is --target_layout
|
||||
* None is --layout
|
||||
"""
|
||||
list_s = split_str_avoiding_square_brackets(s)
|
||||
if len(list_s) == 1 and (list_s[0][-1] not in ')]' or (list_s[0][0] == '[' and list_s[0][-1] == ']')):
|
||||
# single layout case
|
||||
write_found_layout('', list_s[0], parsed, dest)
|
||||
else:
|
||||
for layout_str in list_s:
|
||||
# case for: "name1(nhwc->[n,c,h,w])"
|
||||
p1 = re.compile(r'(\S+)\((\S+)\)')
|
||||
m1 = p1.match(layout_str)
|
||||
# case for: "name1[n,h,w,c]->[n,c,h,w]"
|
||||
p2 = re.compile(r'(\S+)(\[\S*\])')
|
||||
m2 = p2.match(layout_str)
|
||||
if m1:
|
||||
found_g = m1.groups()
|
||||
elif m2:
|
||||
found_g = m2.groups()
|
||||
else:
|
||||
raise Error("More then one layout provided for --{}layout without providing name.".format(
|
||||
dest + '_' if dest else ''))
|
||||
write_found_layout(found_g[0], found_g[1], parsed, dest)
|
||||
|
||||
|
||||
def get_layout_values(argv_layout: str = '', argv_source_layout: str = '', argv_target_layout: str = ''):
|
||||
"""
|
||||
Parses layout string.
|
||||
:param argv_layout: string with a list of layouts passed as a --layout.
|
||||
:param argv_source_layout: string with a list of layouts passed as a --source_layout.
|
||||
:param argv_target_layout: string with a list of layouts passed as a --target_layout.
|
||||
:return: dict with names and layouts associated
|
||||
"""
|
||||
if argv_layout and (argv_source_layout or argv_target_layout):
|
||||
raise Error("--layout is used as well as --source_layout and/or --target_layout which is not allowed, please "
|
||||
"use one of them.")
|
||||
res = {}
|
||||
if argv_layout:
|
||||
parse_layouts_by_destination(argv_layout, res)
|
||||
if argv_source_layout:
|
||||
parse_layouts_by_destination(argv_source_layout, res, 'source')
|
||||
if argv_target_layout:
|
||||
parse_layouts_by_destination(argv_target_layout, res, 'target')
|
||||
return res
|
||||
|
||||
|
||||
def get_freeze_placeholder_values(argv_input: str, argv_freeze_placeholder_with_value: str):
|
||||
"""
|
||||
Parses values for placeholder freezing and input node names
|
||||
|
@ -56,6 +56,9 @@ def replaceArgsHelper(log_level='DEBUG',
|
||||
batch=None,
|
||||
mean_values=None,
|
||||
scale_values=None,
|
||||
layout=None,
|
||||
source_layout=None,
|
||||
target_layout=None,
|
||||
output_dir='.',
|
||||
freeze_placeholder_with_value=None):
|
||||
return argparse.Namespace(
|
||||
@ -72,6 +75,9 @@ def replaceArgsHelper(log_level='DEBUG',
|
||||
batch=batch,
|
||||
mean_values=mean_values,
|
||||
scale_values=scale_values,
|
||||
layout=layout,
|
||||
source_layout=source_layout,
|
||||
target_layout=target_layout,
|
||||
output_dir=output_dir,
|
||||
freeze_placeholder_with_value=freeze_placeholder_with_value,
|
||||
use_legacy_frontend=None,
|
||||
|
@ -12,9 +12,10 @@ from unittest.mock import patch
|
||||
import numpy as np
|
||||
import numpy.testing as npt
|
||||
|
||||
from openvino.tools.mo.utils.cli_parser import get_placeholder_shapes, get_tuple_values, get_mean_scale_dictionary, get_model_name, \
|
||||
from openvino.tools.mo.utils.cli_parser import get_placeholder_shapes, get_tuple_values, get_mean_scale_dictionary, \
|
||||
get_model_name, \
|
||||
parse_tuple_pairs, check_positive, writable_dir, readable_dirs, \
|
||||
readable_file, get_freeze_placeholder_values, parse_transform, check_available_transforms
|
||||
readable_file, get_freeze_placeholder_values, parse_transform, check_available_transforms, get_layout_values
|
||||
from openvino.tools.mo.utils.error import Error
|
||||
|
||||
|
||||
@ -380,6 +381,7 @@ class TestingMeanScaleGetter(unittest.TestCase):
|
||||
def test_input_without_values(self):
|
||||
self.assertRaises(Error, parse_tuple_pairs, "input1,input2")
|
||||
|
||||
|
||||
class TestSingleTupleParsing(unittest.TestCase):
|
||||
def test_get_values_ideal(self):
|
||||
values = "(1.11, 22.22, 333.333)"
|
||||
@ -476,7 +478,8 @@ class TestShapesParsing(unittest.TestCase):
|
||||
for i in exp_res.keys():
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None)
|
||||
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'])}
|
||||
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']),
|
||||
'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'])}
|
||||
input_node_names_ref = "inp1,inp2,inp3"
|
||||
self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
|
||||
for i in placeholder_values_ref.keys():
|
||||
@ -492,7 +495,8 @@ class TestShapesParsing(unittest.TestCase):
|
||||
for i in exp_res.keys():
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None)
|
||||
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'])}
|
||||
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']),
|
||||
'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'])}
|
||||
input_node_names_ref = "inp1,inp2,inp3"
|
||||
self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
|
||||
for i in placeholder_values_ref.keys():
|
||||
@ -510,8 +514,10 @@ class TestShapesParsing(unittest.TestCase):
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, argv_freeze_placeholder_with_value)
|
||||
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'],),
|
||||
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input,
|
||||
argv_freeze_placeholder_with_value)
|
||||
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']),
|
||||
'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'], ),
|
||||
'inp2': np.array(['5.0', '7.0', '3.0']), 'inp4': np.array(['100.0', '200.0'])}
|
||||
input_node_names_ref = "inp1,inp2,inp3"
|
||||
self.assertEqual(sorted(list(placeholder_values_res.keys())), sorted(list(placeholder_values_ref.keys())))
|
||||
@ -772,6 +778,7 @@ class TestShapesParsing(unittest.TestCase):
|
||||
input_shapes = "(12,4,1),(4,-6,8)"
|
||||
self.assertRaises(Error, get_placeholder_shapes, argv_input, input_shapes)
|
||||
|
||||
|
||||
class TestModelNameParsing(unittest.TestCase):
|
||||
def test_model_name_ideal(self):
|
||||
model_name = '/home/models/mymodel.caffemodel'
|
||||
@ -925,7 +932,7 @@ class TransformChecker(unittest.TestCase):
|
||||
"DummyPass2[types=ReLU,PReLU;values=1,2,3]"),
|
||||
[("LowLatency2", {"use_const_initializer": [True, False]}),
|
||||
("DummyPass1", {}),
|
||||
("DummyPass2", {"types": ["ReLU", "PReLU"], "values": [1,2,3]})])
|
||||
("DummyPass2", {"types": ["ReLU", "PReLU"], "values": [1, 2, 3]})])
|
||||
|
||||
def test_multiple_passes_no_args(self):
|
||||
self.assertEqual(parse_transform("DummyPass,LowLatency22"),
|
||||
@ -967,3 +974,297 @@ class TransformChecker(unittest.TestCase):
|
||||
def test_check_dummy_pass_is_available(self, available_transformations):
|
||||
available_transformations.return_value = {"LowLatency2": None}
|
||||
self.assertRaises(Error, check_available_transforms, [("DummyPass", "")])
|
||||
|
||||
|
||||
class TestLayoutParsing(unittest.TestCase):
|
||||
def test_get_layout_1(self):
|
||||
argv_layout = "name1([n,h,w,c]),name2([n,h,w,c]->[n,c,h,w])"
|
||||
result = get_layout_values(argv_layout)
|
||||
exp_res = {'name1': {'source_layout': '[n,h,w,c]', 'target_layout': None},
|
||||
'name2': {'source_layout': '[n,h,w,c]', 'target_layout': '[n,c,h,w]'}}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_get_layout_2(self):
|
||||
argv_layout = "name1(nhwc),name2(nhwc->nchw)"
|
||||
result = get_layout_values(argv_layout)
|
||||
exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': None},
|
||||
'name2': {'source_layout': 'nhwc', 'target_layout': 'nchw'}}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_get_layout_3(self):
|
||||
argv_layout = "name1(n...c),name2(n...c->nc...)"
|
||||
result = get_layout_values(argv_layout)
|
||||
exp_res = {'name1': {'source_layout': 'n...c', 'target_layout': None},
|
||||
'name2': {'source_layout': 'n...c', 'target_layout': 'nc...'}}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_get_layout_4(self):
|
||||
argv_layout = "nhwc"
|
||||
result = get_layout_values(argv_layout)
|
||||
exp_res = {'': {'source_layout': 'nhwc', 'target_layout': None}}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_get_layout_5(self):
|
||||
argv_layout = "[n,h,w,c]"
|
||||
result = get_layout_values(argv_layout)
|
||||
exp_res = {'': {'source_layout': '[n,h,w,c]', 'target_layout': None}}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_get_layout_6(self):
|
||||
argv_layout = "nhwc->nchw"
|
||||
result = get_layout_values(argv_layout)
|
||||
exp_res = {'': {'source_layout': 'nhwc', 'target_layout': 'nchw'}}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_get_layout_7(self):
|
||||
argv_layout = "[n,h,w,c]->[n,c,h,w]"
|
||||
result = get_layout_values(argv_layout)
|
||||
exp_res = {'': {'source_layout': '[n,h,w,c]', 'target_layout': '[n,c,h,w]'}}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_get_layout_scalar(self):
|
||||
argv_layout = "name1(nhwc),name2([])"
|
||||
result = get_layout_values(argv_layout)
|
||||
exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': None},
|
||||
'name2': {'source_layout': '[]', 'target_layout': None}}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_get_layout_source_layout_1(self):
|
||||
argv_source_layout = "[n,h,w,c]"
|
||||
result = get_layout_values(argv_source_layout=argv_source_layout)
|
||||
exp_res = {'': {'source_layout': '[n,h,w,c]', 'target_layout': None}}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_get_layout_source_layout_2(self):
|
||||
argv_source_layout = "nhwc"
|
||||
result = get_layout_values(argv_source_layout=argv_source_layout)
|
||||
exp_res = {'': {'source_layout': 'nhwc', 'target_layout': None}}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_get_layout_source_layout_3(self):
|
||||
argv_source_layout = "name1(nhwc),name2(nchw)"
|
||||
result = get_layout_values(argv_source_layout=argv_source_layout)
|
||||
exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': None},
|
||||
'name2': {'source_layout': 'nchw', 'target_layout': None}}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_get_layout_source_layout_4(self):
|
||||
argv_source_layout = "name1([n,h,w,c]),name2([n,c,h,w])"
|
||||
result = get_layout_values(argv_source_layout=argv_source_layout)
|
||||
exp_res = {'name1': {'source_layout': '[n,h,w,c]', 'target_layout': None},
|
||||
'name2': {'source_layout': '[n,c,h,w]', 'target_layout': None}}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_get_layout_source_layout_5(self):
|
||||
argv_source_layout = "name1(nhwc),name2([n,c,h,w])"
|
||||
result = get_layout_values(argv_source_layout=argv_source_layout)
|
||||
exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': None},
|
||||
'name2': {'source_layout': '[n,c,h,w]', 'target_layout': None}}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_get_layout_source_layout_6(self):
|
||||
argv_source_layout = "name1(nhwc),name2[n,c,h,w]"
|
||||
result = get_layout_values(argv_source_layout=argv_source_layout)
|
||||
exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': None},
|
||||
'name2': {'source_layout': '[n,c,h,w]', 'target_layout': None}}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_get_layout_source_layout_scalar(self):
|
||||
argv_source_layout = "name1(nhwc),name2([])"
|
||||
result = get_layout_values(argv_source_layout=argv_source_layout)
|
||||
exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': None},
|
||||
'name2': {'source_layout': '[]', 'target_layout': None}}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_get_layout_target_layout_1(self):
|
||||
argv_target_layout = "[n,h,w,c]"
|
||||
result = get_layout_values(argv_target_layout=argv_target_layout)
|
||||
exp_res = {'': {'source_layout': None, 'target_layout': '[n,h,w,c]'}}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_get_layout_target_layout_2(self):
|
||||
argv_target_layout = "nhwc"
|
||||
result = get_layout_values(argv_target_layout=argv_target_layout)
|
||||
exp_res = {'': {'source_layout': None, 'target_layout': 'nhwc'}}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_get_layout_target_layout_3(self):
|
||||
argv_target_layout = "name1(nhwc),name2(nchw)"
|
||||
result = get_layout_values(argv_target_layout=argv_target_layout)
|
||||
exp_res = {'name1': {'source_layout': None, 'target_layout': 'nhwc'},
|
||||
'name2': {'source_layout': None, 'target_layout': 'nchw'}}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_get_layout_target_layout_4(self):
|
||||
argv_target_layout = "name1([n,h,w,c]),name2([n,c,h,w])"
|
||||
result = get_layout_values(argv_target_layout=argv_target_layout)
|
||||
exp_res = {'name1': {'source_layout': None, 'target_layout': '[n,h,w,c]'},
|
||||
'name2': {'source_layout': None, 'target_layout': '[n,c,h,w]'}}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_get_layout_target_layout_5(self):
|
||||
argv_target_layout = "name1(nhwc),name2([n,c,h,w])"
|
||||
result = get_layout_values(argv_target_layout=argv_target_layout)
|
||||
exp_res = {'name1': {'source_layout': None, 'target_layout': 'nhwc'},
|
||||
'name2': {'source_layout': None, 'target_layout': '[n,c,h,w]'}}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_get_layout_target_layout_6(self):
|
||||
argv_target_layout = "name1(nhwc),name2[n,c,h,w]"
|
||||
result = get_layout_values(argv_target_layout=argv_target_layout)
|
||||
exp_res = {'name1': {'source_layout': None, 'target_layout': 'nhwc'},
|
||||
'name2': {'source_layout': None, 'target_layout': '[n,c,h,w]'}}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_get_layout_target_layout_scalar(self):
|
||||
argv_target_layout = "name1(nhwc),name2[]"
|
||||
result = get_layout_values(argv_target_layout=argv_target_layout)
|
||||
exp_res = {'name1': {'source_layout': None, 'target_layout': 'nhwc'},
|
||||
'name2': {'source_layout': None, 'target_layout': '[]'}}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_get_layout_source_target_layout_1(self):
|
||||
argv_source_layout = "[n,h,w,c]"
|
||||
argv_target_layout = "[n,c,h,w]"
|
||||
result = get_layout_values(argv_source_layout=argv_source_layout, argv_target_layout=argv_target_layout)
|
||||
exp_res = {'': {'source_layout': '[n,h,w,c]', 'target_layout': '[n,c,h,w]'}}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_get_layout_source_target_layout_2(self):
|
||||
argv_source_layout = "nhwc"
|
||||
argv_target_layout = "nchw"
|
||||
result = get_layout_values(argv_source_layout=argv_source_layout, argv_target_layout=argv_target_layout)
|
||||
exp_res = {'': {'source_layout': 'nhwc', 'target_layout': 'nchw'}}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_get_layout_source_target_layout_3(self):
|
||||
argv_source_layout = "name1(nhwc),name2(nhwc)"
|
||||
argv_target_layout = "name1(nchw),name2(nchw)"
|
||||
result = get_layout_values(argv_source_layout=argv_source_layout, argv_target_layout=argv_target_layout)
|
||||
exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': 'nchw'},
|
||||
'name2': {'source_layout': 'nhwc', 'target_layout': 'nchw'}}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_get_layout_source_target_layout_4(self):
|
||||
argv_source_layout = "name1([n,h,w,c]),name2([n,h,w,c])"
|
||||
argv_target_layout = "name1([n,c,h,w]),name2([n,c,h,w])"
|
||||
result = get_layout_values(argv_source_layout=argv_source_layout, argv_target_layout=argv_target_layout)
|
||||
exp_res = {'name1': {'source_layout': '[n,h,w,c]', 'target_layout': '[n,c,h,w]'},
|
||||
'name2': {'source_layout': '[n,h,w,c]', 'target_layout': '[n,c,h,w]'}}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_get_layout_source_target_layout_5(self):
|
||||
argv_source_layout = "name1(nhwc),name2[n,h,w,c]"
|
||||
argv_target_layout = "name1(nchw),name2[n,c,h,w]"
|
||||
result = get_layout_values(argv_source_layout=argv_source_layout, argv_target_layout=argv_target_layout)
|
||||
exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': 'nchw'},
|
||||
'name2': {'source_layout': '[n,h,w,c]', 'target_layout': '[n,c,h,w]'}}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_get_layout_source_target_layout_scalar(self):
|
||||
argv_source_layout = "name1(nhwc),name2[]"
|
||||
argv_target_layout = "name1(nchw),name2[]"
|
||||
result = get_layout_values(argv_source_layout=argv_source_layout, argv_target_layout=argv_target_layout)
|
||||
exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': 'nchw'},
|
||||
'name2': {'source_layout': '[]', 'target_layout': '[]'}}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_get_layout_raises_if_layout_and_source_layout_provided(self):
|
||||
argv_layout = "nhwc"
|
||||
argv_source_layout = "nhwc"
|
||||
with self.assertRaises(Error):
|
||||
get_layout_values(argv_layout=argv_layout, argv_source_layout=argv_source_layout)
|
||||
|
||||
def test_get_layout_raises_if_layout_and_target_layout_provided(self):
|
||||
argv_layout = "nhwc->nchw"
|
||||
argv_target_layout = "nchw"
|
||||
with self.assertRaises(Error):
|
||||
get_layout_values(argv_layout=argv_layout, argv_target_layout=argv_target_layout)
|
||||
|
||||
def test_get_layout_raises_if_layout_with_source_and_target_layout_provided(self):
|
||||
argv_layout = "nhwc->nchw"
|
||||
argv_source_layout = "nhwc"
|
||||
argv_target_layout = "nchw"
|
||||
with self.assertRaises(Error):
|
||||
get_layout_values(argv_layout=argv_layout, argv_source_layout=argv_source_layout,
|
||||
argv_target_layout=argv_target_layout)
|
||||
|
||||
def test_get_layout_raises_incorrect_format(self):
|
||||
argv_layout = "name[n,h,w,c]->nchw"
|
||||
with self.assertRaises(Error):
|
||||
res = get_layout_values(argv_layout=argv_layout)
|
||||
print(res)
|
||||
|
||||
def test_get_layout_raises_multiple_layouts_without_names(self):
|
||||
argv_layout = "nhwc->nchw,nhwc->nchw"
|
||||
with self.assertRaises(Error):
|
||||
res = get_layout_values(argv_layout=argv_layout)
|
||||
print(res)
|
||||
|
||||
def test_get_layout_raises_multiple_layouts_without_names_source_layout(self):
|
||||
argv_source_layout = "nhwc,nhwc"
|
||||
with self.assertRaises(Error):
|
||||
res = get_layout_values(argv_source_layout=argv_source_layout)
|
||||
print(res)
|
||||
|
||||
def test_get_layout_raises_multiple_layouts_without_names_target_layout(self):
|
||||
argv_target_layout = "nchw,nchw"
|
||||
with self.assertRaises(Error):
|
||||
res = get_layout_values(argv_target_layout=argv_target_layout)
|
||||
print(res)
|
||||
|
Loading…
Reference in New Issue
Block a user