Add layout commands in MO (#8829)

* Add layout support in MO

* Apply review feedback
This commit is contained in:
Maxim Vafin 2021-12-13 13:57:19 +03:00 committed by GitHub
parent 2ff4ef2e4f
commit 51947eeb3d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 533 additions and 41 deletions

View File

@ -37,10 +37,57 @@ def compress_model(func: object):
from openvino.offline_transformations_pybind import compress_model_transformation # pylint: disable=import-error,no-name-in-module
compress_model_transformation(func)
def apply_offline_transformations(input_model: str, framework: str, transforms: list, compress_fp16=False):
def add_layouts(ov_function, argv: argparse.Namespace):
from openvino.preprocess import PrePostProcessor # pylint: disable=no-name-in-module,import-error
from openvino.runtime import Layout # pylint: disable=import-error,no-name-in-module
prep = PrePostProcessor(ov_function)
layout_values = argv.layout_values
if '' in layout_values:
if len(ov_function.inputs) == 1:
layout_values = {
list(ov_function.input().get_tensor().get_names())[0]: {
'source_layout': layout_values[''].get('source_layout'),
'target_layout': layout_values[''].get('target_layout')
}
}
else:
input_names = [list(ov_input.get_tensor().get_names())[0] for ov_input in ov_function.inputs]
raise Error('Layout without name can be specified for models with only one input, '
'but provided model has {} inputs: \'{}\'. '
'Please specify explicitly input/output name for --layout option'
.format(len(input_names), input_names))
set_layout_names = set(layout_values.keys())
for idx, ov_input in enumerate(ov_function.inputs):
found = set.intersection(set(ov_input.get_tensor().get_names()), set_layout_names)
assert len(found) <= 1, 'More then one name point to the same node'
if len(found) == 1:
node_name = list(found)[0]
found_layout = layout_values[node_name]
if found_layout['source_layout']:
prep.input(node_name).network().set_layout(Layout(found_layout['source_layout']))
if found_layout['target_layout']:
prep.input(node_name).tensor().set_layout(Layout(found_layout['target_layout']))
for idx, ov_output in enumerate(ov_function.outputs):
found = set.intersection(set(ov_output.get_tensor().get_names()), set_layout_names)
assert len(found) <= 1, 'More then one name point to the same node'
if len(found) == 1:
node_name = list(found)[0]
found_layout = layout_values[node_name]
if found_layout['source_layout']:
prep.output(node_name).network().set_layout(Layout(found_layout['source_layout']))
if found_layout['target_layout']:
prep.output(node_name).tensor().set_layout(Layout(found_layout['target_layout']))
prep.build()
def apply_offline_transformations(input_model: str, argv: argparse.Namespace):
# This variable is only needed by GenerateMappingFile transformation
# to produce correct mapping
extract_names = framework in ['tf', 'mxnet', 'kaldi']
extract_names = argv.framework in ['tf', 'mxnet', 'kaldi']
from openvino.offline_transformations_pybind import generate_mapping_file, serialize # pylint: disable=import-error,no-name-in-module
from openvino.frontend import FrontEndManager, FrontEnd # pylint: disable=no-name-in-module,import-error
@ -57,24 +104,14 @@ def apply_offline_transformations(input_model: str, framework: str, transforms:
func = read_model(input_model + "_tmp.xml")
apply_user_transformations(func, transforms)
add_layouts(func, argv) # TODO: replace with preprocessing
apply_user_transformations(func, parse_transform(argv.transform))
apply_moc_transformations(func)
if compress_fp16:
if "compress_fp16" in argv and argv.compress_fp16:
compress_model(func)
serialize(func, str(input_model + ".xml").encode('utf-8'), (input_model + ".bin").encode('utf-8'))
path_to_mapping = input_model + ".mapping"
generate_mapping_file(func, path_to_mapping.encode('utf-8'), extract_names)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--input_model")
parser.add_argument("--framework")
parser.add_argument("--transform")
parser.add_argument("--compress_fp16", action='store_true')
args = parser.parse_args()
apply_offline_transformations(args.input_model, args.framework, parse_transform(args.transform), args.compress_fp16)

View File

@ -6,7 +6,6 @@ import datetime
import logging as log
import os
import platform
import subprocess
import sys
import traceback
from collections import OrderedDict
@ -28,10 +27,10 @@ from openvino.tools.mo.middle.pattern_match import for_graph_and_each_sub_graph_
from openvino.tools.mo.pipeline.common import prepare_emit_ir, get_ir_version
from openvino.tools.mo.pipeline.unified import unified_pipeline
from openvino.tools.mo.utils import import_extensions
from openvino.tools.mo.utils.cli_parser import get_placeholder_shapes, get_tuple_values, get_model_name, \
get_common_cli_options, get_caffe_cli_options, get_tf_cli_options, get_mxnet_cli_options, get_kaldi_cli_options, \
get_onnx_cli_options, get_mean_scale_dictionary, parse_tuple_pairs, get_freeze_placeholder_values, get_meta_info, \
parse_transform, check_available_transforms
from openvino.tools.mo.utils.cli_parser import check_available_transforms, get_caffe_cli_options, \
get_common_cli_options, get_freeze_placeholder_values, get_kaldi_cli_options, get_layout_values, \
get_mean_scale_dictionary, get_meta_info, get_model_name, get_mxnet_cli_options, get_onnx_cli_options, \
get_placeholder_shapes, get_tf_cli_options, get_tuple_values, parse_transform, parse_tuple_pairs
from openvino.tools.mo.utils.error import Error, FrameworkError
from openvino.tools.mo.utils.find_ie_version import find_ie_version
from openvino.tools.mo.utils.get_ov_update_message import get_ov_update_message
@ -268,6 +267,7 @@ def arguments_post_parsing(argv: argparse.Namespace):
scale_values = parse_tuple_pairs(argv.scale_values)
mean_scale = get_mean_scale_dictionary(mean_values, scale_values, argv.input)
argv.mean_scale_values = mean_scale
argv.layout_values = get_layout_values(argv.layout, argv.source_layout, argv.target_layout)
if not os.path.exists(argv.output_dir):
try:
@ -360,22 +360,14 @@ def emit_ir(graph: Graph, argv: argparse.Namespace):
orig_model_name = os.path.normpath(os.path.join(output_dir, argv.model_name))
return_code = "not executed"
# This try-except is additional reinsurance that the IE
# dependency search does not break the MO pipeline
try:
if not argv.legacy_ir_generation:
path_to_offline_transformations = os.path.join(os.path.realpath(os.path.dirname(__file__)), 'back',
'offline_transformations.py')
cmd = [sys.executable, path_to_offline_transformations,
"--input_model", orig_model_name,
"--framework", argv.framework,
"--transform", argv.transform]
from openvino.tools.mo.back.offline_transformations import apply_offline_transformations
apply_offline_transformations(orig_model_name, argv)
if "compress_fp16" in argv and argv.compress_fp16:
cmd += ["--compress_fp16"]
# restore data_type cmd parameter
argv.data_type = 'FP16'
status = subprocess.run(cmd, env=os.environ)
return_code = status.returncode
return_code = 0
except Exception as e:
return_code = "failed"
log.error(e)

View File

@ -309,6 +309,28 @@ def get_common_cli_parser(parser: argparse.ArgumentParser = None):
'The exact meaning and order ' +
'of channels depend on how the original model was trained.',
default=())
common_group.add_argument('--source_layout',
help='Layout of the input or output of the model in the framework. Layout can'
' be specified in the short form, e.g. nhwc, or in complex form, e.g. [n,h,w,c].'
' Example for many names: '
'in_name1([n,h,w,c]),in_name2(nc),out_name1(n),out_name2(nc). Layout can be '
'partially defined, "?" can be used to specify undefined layout for one dimension, '
'"..." can be used to specify undefined layout for multiple dimensions, for example '
'?c??, nc..., n...c, etc.',
default=())
common_group.add_argument('--target_layout',
help='Same as --source_layout, but specifies target layout that will be in the model '
'after processing by ModelOptimizer.',
default=())
common_group.add_argument('--layout',
help='Combination of --source_layout and --target_layout. Can\'t be used with either of '
'them. If model has one input it is sufficient to specify layout of this input, for'
' example --layout nhwc. To specify layouts of many tensors, names must be provided,'
' for example: --layout name1(nchw),name2(nc). It is possible to instruct '
'ModelOptimizer to change layout, for example: '
'--layout name1(nhwc->nchw),name2(cn->nc). Also "*" in long layout form can be used'
' to fuse dimensions, for example [n,c,...]->[n*c,…].',
default=())
# TODO: isn't it a weights precision type
common_group.add_argument('--data_type',
help='Data type for all intermediate tensors and weights. ' +
@ -417,6 +439,9 @@ def get_common_cli_options(model_name):
d['input'] = ['- Input layers', lambda x: x if x else 'Not specified, inherited from the model']
d['output'] = ['- Output layers', lambda x: x if x else 'Not specified, inherited from the model']
d['input_shape'] = ['- Input shapes', lambda x: x if x else 'Not specified, inherited from the model']
d['source_layout'] = ['- Source layout', lambda x: x if x else 'Not specified']
d['target_layout'] = ['- Target layout', lambda x: x if x else 'Not specified']
d['layout'] = ['- Layout', lambda x: x if x else 'Not specified']
d['mean_values'] = ['- Mean values', lambda x: x if x else 'Not specified']
d['scale_values'] = ['- Scale values', lambda x: x if x else 'Not specified']
d['scale'] = ['- Scale factor', lambda x: x if x else 'Not specified']
@ -835,6 +860,137 @@ def parse_input_value(input_value: str):
return node_name, shape, value, data_type
def split_str_avoiding_square_brackets(s: str) -> list:
"""
Splits a string by comma, but skips commas inside square brackets.
:param s: string to split
:return: list of strings split by comma
"""
res = list()
skipping = 0
last_idx = 0
for i, c in enumerate(s):
if c == '[':
skipping += 1
elif c == ']':
skipping -= 1
elif c == ',' and skipping == 0:
res.append(s[last_idx:i])
last_idx = i + 1
res.append(s[last_idx:])
return res
def split_layouts_by_arrow(s: str) -> tuple:
"""
Splits a layout string by first arrow (->).
:param s: string to split
:return: tuple containing source and target layouts
"""
arrow = s.find('->')
if arrow != -1:
source_layout = s[:arrow]
target_layout = s[arrow + 2:]
if source_layout == '':
source_layout = None
if target_layout == '':
target_layout = None
return source_layout, target_layout
else:
return s, None
def validate_layout(layout: str):
"""
Checks if layout is of valid format.
:param layout: string containing layout
:raises: if layout is incorrect
"""
valid_layout_re = re.compile(r'\[?[^\[\]\(\)\s]*\]?')
if layout and not valid_layout_re.fullmatch(layout):
raise Error('Invalid layout parsed: {}'.format(layout))
def write_found_layout(name: str, found_layout: str, parsed: dict, dest: str = None):
"""
Writes found layout data to the 'parsed' dict.
:param name: name of the node to add layout
:param found_layout: string containing layout for the node
:param parsed: dict where result will be stored
:param dest: type of the command line:
* 'source' is --source_layout
* 'target' is --target_layout
* None is --layout
"""
s_layout = None
t_layout = None
if name in parsed:
s_layout = parsed[name]['source_layout']
t_layout = parsed[name]['target_layout']
if dest == 'source':
s_layout = found_layout
elif dest == 'target':
t_layout = found_layout
else:
s_layout, t_layout = split_layouts_by_arrow(found_layout)
validate_layout(s_layout)
validate_layout(t_layout)
parsed[name] = {'source_layout': s_layout, 'target_layout': t_layout}
def parse_layouts_by_destination(s: str, parsed: dict, dest: str = None) -> None:
"""
Parses layout command line to get all names and layouts from it. Adds all found data in the 'parsed' dict.
:param s: string to parse
:param parsed: dict where result will be stored
:param dest: type of the command line:
* 'source' is --source_layout
* 'target' is --target_layout
* None is --layout
"""
list_s = split_str_avoiding_square_brackets(s)
if len(list_s) == 1 and (list_s[0][-1] not in ')]' or (list_s[0][0] == '[' and list_s[0][-1] == ']')):
# single layout case
write_found_layout('', list_s[0], parsed, dest)
else:
for layout_str in list_s:
# case for: "name1(nhwc->[n,c,h,w])"
p1 = re.compile(r'(\S+)\((\S+)\)')
m1 = p1.match(layout_str)
# case for: "name1[n,h,w,c]->[n,c,h,w]"
p2 = re.compile(r'(\S+)(\[\S*\])')
m2 = p2.match(layout_str)
if m1:
found_g = m1.groups()
elif m2:
found_g = m2.groups()
else:
raise Error("More then one layout provided for --{}layout without providing name.".format(
dest + '_' if dest else ''))
write_found_layout(found_g[0], found_g[1], parsed, dest)
def get_layout_values(argv_layout: str = '', argv_source_layout: str = '', argv_target_layout: str = ''):
"""
Parses layout string.
:param argv_layout: string with a list of layouts passed as a --layout.
:param argv_source_layout: string with a list of layouts passed as a --source_layout.
:param argv_target_layout: string with a list of layouts passed as a --target_layout.
:return: dict with names and layouts associated
"""
if argv_layout and (argv_source_layout or argv_target_layout):
raise Error("--layout is used as well as --source_layout and/or --target_layout which is not allowed, please "
"use one of them.")
res = {}
if argv_layout:
parse_layouts_by_destination(argv_layout, res)
if argv_source_layout:
parse_layouts_by_destination(argv_source_layout, res, 'source')
if argv_target_layout:
parse_layouts_by_destination(argv_target_layout, res, 'target')
return res
def get_freeze_placeholder_values(argv_input: str, argv_freeze_placeholder_with_value: str):
"""
Parses values for placeholder freezing and input node names

View File

@ -56,6 +56,9 @@ def replaceArgsHelper(log_level='DEBUG',
batch=None,
mean_values=None,
scale_values=None,
layout=None,
source_layout=None,
target_layout=None,
output_dir='.',
freeze_placeholder_with_value=None):
return argparse.Namespace(
@ -72,6 +75,9 @@ def replaceArgsHelper(log_level='DEBUG',
batch=batch,
mean_values=mean_values,
scale_values=scale_values,
layout=layout,
source_layout=source_layout,
target_layout=target_layout,
output_dir=output_dir,
freeze_placeholder_with_value=freeze_placeholder_with_value,
use_legacy_frontend=None,

View File

@ -12,9 +12,10 @@ from unittest.mock import patch
import numpy as np
import numpy.testing as npt
from openvino.tools.mo.utils.cli_parser import get_placeholder_shapes, get_tuple_values, get_mean_scale_dictionary, get_model_name, \
from openvino.tools.mo.utils.cli_parser import get_placeholder_shapes, get_tuple_values, get_mean_scale_dictionary, \
get_model_name, \
parse_tuple_pairs, check_positive, writable_dir, readable_dirs, \
readable_file, get_freeze_placeholder_values, parse_transform, check_available_transforms
readable_file, get_freeze_placeholder_values, parse_transform, check_available_transforms, get_layout_values
from openvino.tools.mo.utils.error import Error
@ -380,6 +381,7 @@ class TestingMeanScaleGetter(unittest.TestCase):
def test_input_without_values(self):
self.assertRaises(Error, parse_tuple_pairs, "input1,input2")
class TestSingleTupleParsing(unittest.TestCase):
def test_get_values_ideal(self):
values = "(1.11, 22.22, 333.333)"
@ -476,7 +478,8 @@ class TestShapesParsing(unittest.TestCase):
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None)
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'])}
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']),
'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'])}
input_node_names_ref = "inp1,inp2,inp3"
self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
for i in placeholder_values_ref.keys():
@ -492,7 +495,8 @@ class TestShapesParsing(unittest.TestCase):
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None)
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'])}
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']),
'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'])}
input_node_names_ref = "inp1,inp2,inp3"
self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
for i in placeholder_values_ref.keys():
@ -510,8 +514,10 @@ class TestShapesParsing(unittest.TestCase):
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, argv_freeze_placeholder_with_value)
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'],),
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input,
argv_freeze_placeholder_with_value)
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']),
'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'], ),
'inp2': np.array(['5.0', '7.0', '3.0']), 'inp4': np.array(['100.0', '200.0'])}
input_node_names_ref = "inp1,inp2,inp3"
self.assertEqual(sorted(list(placeholder_values_res.keys())), sorted(list(placeholder_values_ref.keys())))
@ -772,6 +778,7 @@ class TestShapesParsing(unittest.TestCase):
input_shapes = "(12,4,1),(4,-6,8)"
self.assertRaises(Error, get_placeholder_shapes, argv_input, input_shapes)
class TestModelNameParsing(unittest.TestCase):
def test_model_name_ideal(self):
model_name = '/home/models/mymodel.caffemodel'
@ -923,9 +930,9 @@ class TransformChecker(unittest.TestCase):
def test_multiple_passes_with_args2(self):
self.assertEqual(parse_transform("LowLatency2[use_const_initializer=True,False],DummyPass1,"
"DummyPass2[types=ReLU,PReLU;values=1,2,3]"),
[("LowLatency2", {"use_const_initializer": [True, False]}),
("DummyPass1", {}),
("DummyPass2", {"types": ["ReLU", "PReLU"], "values": [1,2,3]})])
[("LowLatency2", {"use_const_initializer": [True, False]}),
("DummyPass1", {}),
("DummyPass2", {"types": ["ReLU", "PReLU"], "values": [1, 2, 3]})])
def test_multiple_passes_no_args(self):
self.assertEqual(parse_transform("DummyPass,LowLatency22"),
@ -967,3 +974,297 @@ class TransformChecker(unittest.TestCase):
def test_check_dummy_pass_is_available(self, available_transformations):
available_transformations.return_value = {"LowLatency2": None}
self.assertRaises(Error, check_available_transforms, [("DummyPass", "")])
class TestLayoutParsing(unittest.TestCase):
def test_get_layout_1(self):
argv_layout = "name1([n,h,w,c]),name2([n,h,w,c]->[n,c,h,w])"
result = get_layout_values(argv_layout)
exp_res = {'name1': {'source_layout': '[n,h,w,c]', 'target_layout': None},
'name2': {'source_layout': '[n,h,w,c]', 'target_layout': '[n,c,h,w]'}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_2(self):
argv_layout = "name1(nhwc),name2(nhwc->nchw)"
result = get_layout_values(argv_layout)
exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': None},
'name2': {'source_layout': 'nhwc', 'target_layout': 'nchw'}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_3(self):
argv_layout = "name1(n...c),name2(n...c->nc...)"
result = get_layout_values(argv_layout)
exp_res = {'name1': {'source_layout': 'n...c', 'target_layout': None},
'name2': {'source_layout': 'n...c', 'target_layout': 'nc...'}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_4(self):
argv_layout = "nhwc"
result = get_layout_values(argv_layout)
exp_res = {'': {'source_layout': 'nhwc', 'target_layout': None}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_5(self):
argv_layout = "[n,h,w,c]"
result = get_layout_values(argv_layout)
exp_res = {'': {'source_layout': '[n,h,w,c]', 'target_layout': None}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_6(self):
argv_layout = "nhwc->nchw"
result = get_layout_values(argv_layout)
exp_res = {'': {'source_layout': 'nhwc', 'target_layout': 'nchw'}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_7(self):
argv_layout = "[n,h,w,c]->[n,c,h,w]"
result = get_layout_values(argv_layout)
exp_res = {'': {'source_layout': '[n,h,w,c]', 'target_layout': '[n,c,h,w]'}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_scalar(self):
argv_layout = "name1(nhwc),name2([])"
result = get_layout_values(argv_layout)
exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': None},
'name2': {'source_layout': '[]', 'target_layout': None}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_source_layout_1(self):
argv_source_layout = "[n,h,w,c]"
result = get_layout_values(argv_source_layout=argv_source_layout)
exp_res = {'': {'source_layout': '[n,h,w,c]', 'target_layout': None}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_source_layout_2(self):
argv_source_layout = "nhwc"
result = get_layout_values(argv_source_layout=argv_source_layout)
exp_res = {'': {'source_layout': 'nhwc', 'target_layout': None}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_source_layout_3(self):
argv_source_layout = "name1(nhwc),name2(nchw)"
result = get_layout_values(argv_source_layout=argv_source_layout)
exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': None},
'name2': {'source_layout': 'nchw', 'target_layout': None}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_source_layout_4(self):
argv_source_layout = "name1([n,h,w,c]),name2([n,c,h,w])"
result = get_layout_values(argv_source_layout=argv_source_layout)
exp_res = {'name1': {'source_layout': '[n,h,w,c]', 'target_layout': None},
'name2': {'source_layout': '[n,c,h,w]', 'target_layout': None}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_source_layout_5(self):
argv_source_layout = "name1(nhwc),name2([n,c,h,w])"
result = get_layout_values(argv_source_layout=argv_source_layout)
exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': None},
'name2': {'source_layout': '[n,c,h,w]', 'target_layout': None}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_source_layout_6(self):
argv_source_layout = "name1(nhwc),name2[n,c,h,w]"
result = get_layout_values(argv_source_layout=argv_source_layout)
exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': None},
'name2': {'source_layout': '[n,c,h,w]', 'target_layout': None}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_source_layout_scalar(self):
argv_source_layout = "name1(nhwc),name2([])"
result = get_layout_values(argv_source_layout=argv_source_layout)
exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': None},
'name2': {'source_layout': '[]', 'target_layout': None}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_target_layout_1(self):
argv_target_layout = "[n,h,w,c]"
result = get_layout_values(argv_target_layout=argv_target_layout)
exp_res = {'': {'source_layout': None, 'target_layout': '[n,h,w,c]'}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_target_layout_2(self):
argv_target_layout = "nhwc"
result = get_layout_values(argv_target_layout=argv_target_layout)
exp_res = {'': {'source_layout': None, 'target_layout': 'nhwc'}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_target_layout_3(self):
argv_target_layout = "name1(nhwc),name2(nchw)"
result = get_layout_values(argv_target_layout=argv_target_layout)
exp_res = {'name1': {'source_layout': None, 'target_layout': 'nhwc'},
'name2': {'source_layout': None, 'target_layout': 'nchw'}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_target_layout_4(self):
argv_target_layout = "name1([n,h,w,c]),name2([n,c,h,w])"
result = get_layout_values(argv_target_layout=argv_target_layout)
exp_res = {'name1': {'source_layout': None, 'target_layout': '[n,h,w,c]'},
'name2': {'source_layout': None, 'target_layout': '[n,c,h,w]'}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_target_layout_5(self):
argv_target_layout = "name1(nhwc),name2([n,c,h,w])"
result = get_layout_values(argv_target_layout=argv_target_layout)
exp_res = {'name1': {'source_layout': None, 'target_layout': 'nhwc'},
'name2': {'source_layout': None, 'target_layout': '[n,c,h,w]'}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_target_layout_6(self):
argv_target_layout = "name1(nhwc),name2[n,c,h,w]"
result = get_layout_values(argv_target_layout=argv_target_layout)
exp_res = {'name1': {'source_layout': None, 'target_layout': 'nhwc'},
'name2': {'source_layout': None, 'target_layout': '[n,c,h,w]'}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_target_layout_scalar(self):
argv_target_layout = "name1(nhwc),name2[]"
result = get_layout_values(argv_target_layout=argv_target_layout)
exp_res = {'name1': {'source_layout': None, 'target_layout': 'nhwc'},
'name2': {'source_layout': None, 'target_layout': '[]'}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_source_target_layout_1(self):
argv_source_layout = "[n,h,w,c]"
argv_target_layout = "[n,c,h,w]"
result = get_layout_values(argv_source_layout=argv_source_layout, argv_target_layout=argv_target_layout)
exp_res = {'': {'source_layout': '[n,h,w,c]', 'target_layout': '[n,c,h,w]'}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_source_target_layout_2(self):
argv_source_layout = "nhwc"
argv_target_layout = "nchw"
result = get_layout_values(argv_source_layout=argv_source_layout, argv_target_layout=argv_target_layout)
exp_res = {'': {'source_layout': 'nhwc', 'target_layout': 'nchw'}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_source_target_layout_3(self):
argv_source_layout = "name1(nhwc),name2(nhwc)"
argv_target_layout = "name1(nchw),name2(nchw)"
result = get_layout_values(argv_source_layout=argv_source_layout, argv_target_layout=argv_target_layout)
exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': 'nchw'},
'name2': {'source_layout': 'nhwc', 'target_layout': 'nchw'}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_source_target_layout_4(self):
argv_source_layout = "name1([n,h,w,c]),name2([n,h,w,c])"
argv_target_layout = "name1([n,c,h,w]),name2([n,c,h,w])"
result = get_layout_values(argv_source_layout=argv_source_layout, argv_target_layout=argv_target_layout)
exp_res = {'name1': {'source_layout': '[n,h,w,c]', 'target_layout': '[n,c,h,w]'},
'name2': {'source_layout': '[n,h,w,c]', 'target_layout': '[n,c,h,w]'}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_source_target_layout_5(self):
argv_source_layout = "name1(nhwc),name2[n,h,w,c]"
argv_target_layout = "name1(nchw),name2[n,c,h,w]"
result = get_layout_values(argv_source_layout=argv_source_layout, argv_target_layout=argv_target_layout)
exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': 'nchw'},
'name2': {'source_layout': '[n,h,w,c]', 'target_layout': '[n,c,h,w]'}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_source_target_layout_scalar(self):
argv_source_layout = "name1(nhwc),name2[]"
argv_target_layout = "name1(nchw),name2[]"
result = get_layout_values(argv_source_layout=argv_source_layout, argv_target_layout=argv_target_layout)
exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': 'nchw'},
'name2': {'source_layout': '[]', 'target_layout': '[]'}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_raises_if_layout_and_source_layout_provided(self):
argv_layout = "nhwc"
argv_source_layout = "nhwc"
with self.assertRaises(Error):
get_layout_values(argv_layout=argv_layout, argv_source_layout=argv_source_layout)
def test_get_layout_raises_if_layout_and_target_layout_provided(self):
argv_layout = "nhwc->nchw"
argv_target_layout = "nchw"
with self.assertRaises(Error):
get_layout_values(argv_layout=argv_layout, argv_target_layout=argv_target_layout)
def test_get_layout_raises_if_layout_with_source_and_target_layout_provided(self):
argv_layout = "nhwc->nchw"
argv_source_layout = "nhwc"
argv_target_layout = "nchw"
with self.assertRaises(Error):
get_layout_values(argv_layout=argv_layout, argv_source_layout=argv_source_layout,
argv_target_layout=argv_target_layout)
def test_get_layout_raises_incorrect_format(self):
argv_layout = "name[n,h,w,c]->nchw"
with self.assertRaises(Error):
res = get_layout_values(argv_layout=argv_layout)
print(res)
def test_get_layout_raises_multiple_layouts_without_names(self):
argv_layout = "nhwc->nchw,nhwc->nchw"
with self.assertRaises(Error):
res = get_layout_values(argv_layout=argv_layout)
print(res)
def test_get_layout_raises_multiple_layouts_without_names_source_layout(self):
argv_source_layout = "nhwc,nhwc"
with self.assertRaises(Error):
res = get_layout_values(argv_source_layout=argv_source_layout)
print(res)
def test_get_layout_raises_multiple_layouts_without_names_target_layout(self):
argv_target_layout = "nchw,nchw"
with self.assertRaises(Error):
res = get_layout_values(argv_target_layout=argv_target_layout)
print(res)