Add layout commands in MO (#8829)

* Add layout support in MO

* Apply review feedback
This commit is contained in:
Maxim Vafin 2021-12-13 13:57:19 +03:00 committed by GitHub
parent 2ff4ef2e4f
commit 51947eeb3d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 533 additions and 41 deletions

View File

@ -37,10 +37,57 @@ def compress_model(func: object):
from openvino.offline_transformations_pybind import compress_model_transformation # pylint: disable=import-error,no-name-in-module from openvino.offline_transformations_pybind import compress_model_transformation # pylint: disable=import-error,no-name-in-module
compress_model_transformation(func) compress_model_transformation(func)
def apply_offline_transformations(input_model: str, framework: str, transforms: list, compress_fp16=False):
def add_layouts(ov_function, argv: argparse.Namespace):
from openvino.preprocess import PrePostProcessor # pylint: disable=no-name-in-module,import-error
from openvino.runtime import Layout # pylint: disable=import-error,no-name-in-module
prep = PrePostProcessor(ov_function)
layout_values = argv.layout_values
if '' in layout_values:
if len(ov_function.inputs) == 1:
layout_values = {
list(ov_function.input().get_tensor().get_names())[0]: {
'source_layout': layout_values[''].get('source_layout'),
'target_layout': layout_values[''].get('target_layout')
}
}
else:
input_names = [list(ov_input.get_tensor().get_names())[0] for ov_input in ov_function.inputs]
raise Error('Layout without name can be specified for models with only one input, '
'but provided model has {} inputs: \'{}\'. '
'Please specify explicitly input/output name for --layout option'
.format(len(input_names), input_names))
set_layout_names = set(layout_values.keys())
for idx, ov_input in enumerate(ov_function.inputs):
found = set.intersection(set(ov_input.get_tensor().get_names()), set_layout_names)
assert len(found) <= 1, 'More then one name point to the same node'
if len(found) == 1:
node_name = list(found)[0]
found_layout = layout_values[node_name]
if found_layout['source_layout']:
prep.input(node_name).network().set_layout(Layout(found_layout['source_layout']))
if found_layout['target_layout']:
prep.input(node_name).tensor().set_layout(Layout(found_layout['target_layout']))
for idx, ov_output in enumerate(ov_function.outputs):
found = set.intersection(set(ov_output.get_tensor().get_names()), set_layout_names)
assert len(found) <= 1, 'More then one name point to the same node'
if len(found) == 1:
node_name = list(found)[0]
found_layout = layout_values[node_name]
if found_layout['source_layout']:
prep.output(node_name).network().set_layout(Layout(found_layout['source_layout']))
if found_layout['target_layout']:
prep.output(node_name).tensor().set_layout(Layout(found_layout['target_layout']))
prep.build()
def apply_offline_transformations(input_model: str, argv: argparse.Namespace):
# This variable is only needed by GenerateMappingFile transformation # This variable is only needed by GenerateMappingFile transformation
# to produce correct mapping # to produce correct mapping
extract_names = framework in ['tf', 'mxnet', 'kaldi'] extract_names = argv.framework in ['tf', 'mxnet', 'kaldi']
from openvino.offline_transformations_pybind import generate_mapping_file, serialize # pylint: disable=import-error,no-name-in-module from openvino.offline_transformations_pybind import generate_mapping_file, serialize # pylint: disable=import-error,no-name-in-module
from openvino.frontend import FrontEndManager, FrontEnd # pylint: disable=no-name-in-module,import-error from openvino.frontend import FrontEndManager, FrontEnd # pylint: disable=no-name-in-module,import-error
@ -57,24 +104,14 @@ def apply_offline_transformations(input_model: str, framework: str, transforms:
func = read_model(input_model + "_tmp.xml") func = read_model(input_model + "_tmp.xml")
apply_user_transformations(func, transforms) add_layouts(func, argv) # TODO: replace with preprocessing
apply_user_transformations(func, parse_transform(argv.transform))
apply_moc_transformations(func) apply_moc_transformations(func)
if compress_fp16: if "compress_fp16" in argv and argv.compress_fp16:
compress_model(func) compress_model(func)
serialize(func, str(input_model + ".xml").encode('utf-8'), (input_model + ".bin").encode('utf-8')) serialize(func, str(input_model + ".xml").encode('utf-8'), (input_model + ".bin").encode('utf-8'))
path_to_mapping = input_model + ".mapping" path_to_mapping = input_model + ".mapping"
generate_mapping_file(func, path_to_mapping.encode('utf-8'), extract_names) generate_mapping_file(func, path_to_mapping.encode('utf-8'), extract_names)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--input_model")
parser.add_argument("--framework")
parser.add_argument("--transform")
parser.add_argument("--compress_fp16", action='store_true')
args = parser.parse_args()
apply_offline_transformations(args.input_model, args.framework, parse_transform(args.transform), args.compress_fp16)

View File

@ -6,7 +6,6 @@ import datetime
import logging as log import logging as log
import os import os
import platform import platform
import subprocess
import sys import sys
import traceback import traceback
from collections import OrderedDict from collections import OrderedDict
@ -28,10 +27,10 @@ from openvino.tools.mo.middle.pattern_match import for_graph_and_each_sub_graph_
from openvino.tools.mo.pipeline.common import prepare_emit_ir, get_ir_version from openvino.tools.mo.pipeline.common import prepare_emit_ir, get_ir_version
from openvino.tools.mo.pipeline.unified import unified_pipeline from openvino.tools.mo.pipeline.unified import unified_pipeline
from openvino.tools.mo.utils import import_extensions from openvino.tools.mo.utils import import_extensions
from openvino.tools.mo.utils.cli_parser import get_placeholder_shapes, get_tuple_values, get_model_name, \ from openvino.tools.mo.utils.cli_parser import check_available_transforms, get_caffe_cli_options, \
get_common_cli_options, get_caffe_cli_options, get_tf_cli_options, get_mxnet_cli_options, get_kaldi_cli_options, \ get_common_cli_options, get_freeze_placeholder_values, get_kaldi_cli_options, get_layout_values, \
get_onnx_cli_options, get_mean_scale_dictionary, parse_tuple_pairs, get_freeze_placeholder_values, get_meta_info, \ get_mean_scale_dictionary, get_meta_info, get_model_name, get_mxnet_cli_options, get_onnx_cli_options, \
parse_transform, check_available_transforms get_placeholder_shapes, get_tf_cli_options, get_tuple_values, parse_transform, parse_tuple_pairs
from openvino.tools.mo.utils.error import Error, FrameworkError from openvino.tools.mo.utils.error import Error, FrameworkError
from openvino.tools.mo.utils.find_ie_version import find_ie_version from openvino.tools.mo.utils.find_ie_version import find_ie_version
from openvino.tools.mo.utils.get_ov_update_message import get_ov_update_message from openvino.tools.mo.utils.get_ov_update_message import get_ov_update_message
@ -268,6 +267,7 @@ def arguments_post_parsing(argv: argparse.Namespace):
scale_values = parse_tuple_pairs(argv.scale_values) scale_values = parse_tuple_pairs(argv.scale_values)
mean_scale = get_mean_scale_dictionary(mean_values, scale_values, argv.input) mean_scale = get_mean_scale_dictionary(mean_values, scale_values, argv.input)
argv.mean_scale_values = mean_scale argv.mean_scale_values = mean_scale
argv.layout_values = get_layout_values(argv.layout, argv.source_layout, argv.target_layout)
if not os.path.exists(argv.output_dir): if not os.path.exists(argv.output_dir):
try: try:
@ -360,22 +360,14 @@ def emit_ir(graph: Graph, argv: argparse.Namespace):
orig_model_name = os.path.normpath(os.path.join(output_dir, argv.model_name)) orig_model_name = os.path.normpath(os.path.join(output_dir, argv.model_name))
return_code = "not executed" return_code = "not executed"
# This try-except is additional reinsurance that the IE
# dependency search does not break the MO pipeline
try: try:
if not argv.legacy_ir_generation: if not argv.legacy_ir_generation:
path_to_offline_transformations = os.path.join(os.path.realpath(os.path.dirname(__file__)), 'back', from openvino.tools.mo.back.offline_transformations import apply_offline_transformations
'offline_transformations.py') apply_offline_transformations(orig_model_name, argv)
cmd = [sys.executable, path_to_offline_transformations,
"--input_model", orig_model_name,
"--framework", argv.framework,
"--transform", argv.transform]
if "compress_fp16" in argv and argv.compress_fp16: if "compress_fp16" in argv and argv.compress_fp16:
cmd += ["--compress_fp16"]
# restore data_type cmd parameter # restore data_type cmd parameter
argv.data_type = 'FP16' argv.data_type = 'FP16'
status = subprocess.run(cmd, env=os.environ) return_code = 0
return_code = status.returncode
except Exception as e: except Exception as e:
return_code = "failed" return_code = "failed"
log.error(e) log.error(e)

View File

@ -309,6 +309,28 @@ def get_common_cli_parser(parser: argparse.ArgumentParser = None):
'The exact meaning and order ' + 'The exact meaning and order ' +
'of channels depend on how the original model was trained.', 'of channels depend on how the original model was trained.',
default=()) default=())
common_group.add_argument('--source_layout',
help='Layout of the input or output of the model in the framework. Layout can'
' be specified in the short form, e.g. nhwc, or in complex form, e.g. [n,h,w,c].'
' Example for many names: '
'in_name1([n,h,w,c]),in_name2(nc),out_name1(n),out_name2(nc). Layout can be '
'partially defined, "?" can be used to specify undefined layout for one dimension, '
'"..." can be used to specify undefined layout for multiple dimensions, for example '
'?c??, nc..., n...c, etc.',
default=())
common_group.add_argument('--target_layout',
help='Same as --source_layout, but specifies target layout that will be in the model '
'after processing by ModelOptimizer.',
default=())
common_group.add_argument('--layout',
help='Combination of --source_layout and --target_layout. Can\'t be used with either of '
'them. If model has one input it is sufficient to specify layout of this input, for'
' example --layout nhwc. To specify layouts of many tensors, names must be provided,'
' for example: --layout name1(nchw),name2(nc). It is possible to instruct '
'ModelOptimizer to change layout, for example: '
'--layout name1(nhwc->nchw),name2(cn->nc). Also "*" in long layout form can be used'
' to fuse dimensions, for example [n,c,...]->[n*c,…].',
default=())
# TODO: isn't it a weights precision type # TODO: isn't it a weights precision type
common_group.add_argument('--data_type', common_group.add_argument('--data_type',
help='Data type for all intermediate tensors and weights. ' + help='Data type for all intermediate tensors and weights. ' +
@ -417,6 +439,9 @@ def get_common_cli_options(model_name):
d['input'] = ['- Input layers', lambda x: x if x else 'Not specified, inherited from the model'] d['input'] = ['- Input layers', lambda x: x if x else 'Not specified, inherited from the model']
d['output'] = ['- Output layers', lambda x: x if x else 'Not specified, inherited from the model'] d['output'] = ['- Output layers', lambda x: x if x else 'Not specified, inherited from the model']
d['input_shape'] = ['- Input shapes', lambda x: x if x else 'Not specified, inherited from the model'] d['input_shape'] = ['- Input shapes', lambda x: x if x else 'Not specified, inherited from the model']
d['source_layout'] = ['- Source layout', lambda x: x if x else 'Not specified']
d['target_layout'] = ['- Target layout', lambda x: x if x else 'Not specified']
d['layout'] = ['- Layout', lambda x: x if x else 'Not specified']
d['mean_values'] = ['- Mean values', lambda x: x if x else 'Not specified'] d['mean_values'] = ['- Mean values', lambda x: x if x else 'Not specified']
d['scale_values'] = ['- Scale values', lambda x: x if x else 'Not specified'] d['scale_values'] = ['- Scale values', lambda x: x if x else 'Not specified']
d['scale'] = ['- Scale factor', lambda x: x if x else 'Not specified'] d['scale'] = ['- Scale factor', lambda x: x if x else 'Not specified']
@ -835,6 +860,137 @@ def parse_input_value(input_value: str):
return node_name, shape, value, data_type return node_name, shape, value, data_type
def split_str_avoiding_square_brackets(s: str) -> list:
"""
Splits a string by comma, but skips commas inside square brackets.
:param s: string to split
:return: list of strings split by comma
"""
res = list()
skipping = 0
last_idx = 0
for i, c in enumerate(s):
if c == '[':
skipping += 1
elif c == ']':
skipping -= 1
elif c == ',' and skipping == 0:
res.append(s[last_idx:i])
last_idx = i + 1
res.append(s[last_idx:])
return res
def split_layouts_by_arrow(s: str) -> tuple:
"""
Splits a layout string by first arrow (->).
:param s: string to split
:return: tuple containing source and target layouts
"""
arrow = s.find('->')
if arrow != -1:
source_layout = s[:arrow]
target_layout = s[arrow + 2:]
if source_layout == '':
source_layout = None
if target_layout == '':
target_layout = None
return source_layout, target_layout
else:
return s, None
def validate_layout(layout: str):
"""
Checks if layout is of valid format.
:param layout: string containing layout
:raises: if layout is incorrect
"""
valid_layout_re = re.compile(r'\[?[^\[\]\(\)\s]*\]?')
if layout and not valid_layout_re.fullmatch(layout):
raise Error('Invalid layout parsed: {}'.format(layout))
def write_found_layout(name: str, found_layout: str, parsed: dict, dest: str = None):
"""
Writes found layout data to the 'parsed' dict.
:param name: name of the node to add layout
:param found_layout: string containing layout for the node
:param parsed: dict where result will be stored
:param dest: type of the command line:
* 'source' is --source_layout
* 'target' is --target_layout
* None is --layout
"""
s_layout = None
t_layout = None
if name in parsed:
s_layout = parsed[name]['source_layout']
t_layout = parsed[name]['target_layout']
if dest == 'source':
s_layout = found_layout
elif dest == 'target':
t_layout = found_layout
else:
s_layout, t_layout = split_layouts_by_arrow(found_layout)
validate_layout(s_layout)
validate_layout(t_layout)
parsed[name] = {'source_layout': s_layout, 'target_layout': t_layout}
def parse_layouts_by_destination(s: str, parsed: dict, dest: str = None) -> None:
"""
Parses layout command line to get all names and layouts from it. Adds all found data in the 'parsed' dict.
:param s: string to parse
:param parsed: dict where result will be stored
:param dest: type of the command line:
* 'source' is --source_layout
* 'target' is --target_layout
* None is --layout
"""
list_s = split_str_avoiding_square_brackets(s)
if len(list_s) == 1 and (list_s[0][-1] not in ')]' or (list_s[0][0] == '[' and list_s[0][-1] == ']')):
# single layout case
write_found_layout('', list_s[0], parsed, dest)
else:
for layout_str in list_s:
# case for: "name1(nhwc->[n,c,h,w])"
p1 = re.compile(r'(\S+)\((\S+)\)')
m1 = p1.match(layout_str)
# case for: "name1[n,h,w,c]->[n,c,h,w]"
p2 = re.compile(r'(\S+)(\[\S*\])')
m2 = p2.match(layout_str)
if m1:
found_g = m1.groups()
elif m2:
found_g = m2.groups()
else:
raise Error("More then one layout provided for --{}layout without providing name.".format(
dest + '_' if dest else ''))
write_found_layout(found_g[0], found_g[1], parsed, dest)
def get_layout_values(argv_layout: str = '', argv_source_layout: str = '', argv_target_layout: str = ''):
"""
Parses layout string.
:param argv_layout: string with a list of layouts passed as a --layout.
:param argv_source_layout: string with a list of layouts passed as a --source_layout.
:param argv_target_layout: string with a list of layouts passed as a --target_layout.
:return: dict with names and layouts associated
"""
if argv_layout and (argv_source_layout or argv_target_layout):
raise Error("--layout is used as well as --source_layout and/or --target_layout which is not allowed, please "
"use one of them.")
res = {}
if argv_layout:
parse_layouts_by_destination(argv_layout, res)
if argv_source_layout:
parse_layouts_by_destination(argv_source_layout, res, 'source')
if argv_target_layout:
parse_layouts_by_destination(argv_target_layout, res, 'target')
return res
def get_freeze_placeholder_values(argv_input: str, argv_freeze_placeholder_with_value: str): def get_freeze_placeholder_values(argv_input: str, argv_freeze_placeholder_with_value: str):
""" """
Parses values for placeholder freezing and input node names Parses values for placeholder freezing and input node names

View File

@ -56,6 +56,9 @@ def replaceArgsHelper(log_level='DEBUG',
batch=None, batch=None,
mean_values=None, mean_values=None,
scale_values=None, scale_values=None,
layout=None,
source_layout=None,
target_layout=None,
output_dir='.', output_dir='.',
freeze_placeholder_with_value=None): freeze_placeholder_with_value=None):
return argparse.Namespace( return argparse.Namespace(
@ -72,6 +75,9 @@ def replaceArgsHelper(log_level='DEBUG',
batch=batch, batch=batch,
mean_values=mean_values, mean_values=mean_values,
scale_values=scale_values, scale_values=scale_values,
layout=layout,
source_layout=source_layout,
target_layout=target_layout,
output_dir=output_dir, output_dir=output_dir,
freeze_placeholder_with_value=freeze_placeholder_with_value, freeze_placeholder_with_value=freeze_placeholder_with_value,
use_legacy_frontend=None, use_legacy_frontend=None,

View File

@ -12,9 +12,10 @@ from unittest.mock import patch
import numpy as np import numpy as np
import numpy.testing as npt import numpy.testing as npt
from openvino.tools.mo.utils.cli_parser import get_placeholder_shapes, get_tuple_values, get_mean_scale_dictionary, get_model_name, \ from openvino.tools.mo.utils.cli_parser import get_placeholder_shapes, get_tuple_values, get_mean_scale_dictionary, \
get_model_name, \
parse_tuple_pairs, check_positive, writable_dir, readable_dirs, \ parse_tuple_pairs, check_positive, writable_dir, readable_dirs, \
readable_file, get_freeze_placeholder_values, parse_transform, check_available_transforms readable_file, get_freeze_placeholder_values, parse_transform, check_available_transforms, get_layout_values
from openvino.tools.mo.utils.error import Error from openvino.tools.mo.utils.error import Error
@ -380,6 +381,7 @@ class TestingMeanScaleGetter(unittest.TestCase):
def test_input_without_values(self): def test_input_without_values(self):
self.assertRaises(Error, parse_tuple_pairs, "input1,input2") self.assertRaises(Error, parse_tuple_pairs, "input1,input2")
class TestSingleTupleParsing(unittest.TestCase): class TestSingleTupleParsing(unittest.TestCase):
def test_get_values_ideal(self): def test_get_values_ideal(self):
values = "(1.11, 22.22, 333.333)" values = "(1.11, 22.22, 333.333)"
@ -476,7 +478,8 @@ class TestShapesParsing(unittest.TestCase):
for i in exp_res.keys(): for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i]) npt.assert_array_equal(result[i], exp_res[i])
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None) placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None)
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'])} placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']),
'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'])}
input_node_names_ref = "inp1,inp2,inp3" input_node_names_ref = "inp1,inp2,inp3"
self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys())) self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
for i in placeholder_values_ref.keys(): for i in placeholder_values_ref.keys():
@ -492,7 +495,8 @@ class TestShapesParsing(unittest.TestCase):
for i in exp_res.keys(): for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i]) npt.assert_array_equal(result[i], exp_res[i])
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None) placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None)
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'])} placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']),
'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'])}
input_node_names_ref = "inp1,inp2,inp3" input_node_names_ref = "inp1,inp2,inp3"
self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys())) self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
for i in placeholder_values_ref.keys(): for i in placeholder_values_ref.keys():
@ -510,8 +514,10 @@ class TestShapesParsing(unittest.TestCase):
self.assertEqual(list(exp_res.keys()), list(result.keys())) self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys(): for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i]) npt.assert_array_equal(result[i], exp_res[i])
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, argv_freeze_placeholder_with_value) placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input,
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'],), argv_freeze_placeholder_with_value)
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']),
'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'], ),
'inp2': np.array(['5.0', '7.0', '3.0']), 'inp4': np.array(['100.0', '200.0'])} 'inp2': np.array(['5.0', '7.0', '3.0']), 'inp4': np.array(['100.0', '200.0'])}
input_node_names_ref = "inp1,inp2,inp3" input_node_names_ref = "inp1,inp2,inp3"
self.assertEqual(sorted(list(placeholder_values_res.keys())), sorted(list(placeholder_values_ref.keys()))) self.assertEqual(sorted(list(placeholder_values_res.keys())), sorted(list(placeholder_values_ref.keys())))
@ -772,6 +778,7 @@ class TestShapesParsing(unittest.TestCase):
input_shapes = "(12,4,1),(4,-6,8)" input_shapes = "(12,4,1),(4,-6,8)"
self.assertRaises(Error, get_placeholder_shapes, argv_input, input_shapes) self.assertRaises(Error, get_placeholder_shapes, argv_input, input_shapes)
class TestModelNameParsing(unittest.TestCase): class TestModelNameParsing(unittest.TestCase):
def test_model_name_ideal(self): def test_model_name_ideal(self):
model_name = '/home/models/mymodel.caffemodel' model_name = '/home/models/mymodel.caffemodel'
@ -923,9 +930,9 @@ class TransformChecker(unittest.TestCase):
def test_multiple_passes_with_args2(self): def test_multiple_passes_with_args2(self):
self.assertEqual(parse_transform("LowLatency2[use_const_initializer=True,False],DummyPass1," self.assertEqual(parse_transform("LowLatency2[use_const_initializer=True,False],DummyPass1,"
"DummyPass2[types=ReLU,PReLU;values=1,2,3]"), "DummyPass2[types=ReLU,PReLU;values=1,2,3]"),
[("LowLatency2", {"use_const_initializer": [True, False]}), [("LowLatency2", {"use_const_initializer": [True, False]}),
("DummyPass1", {}), ("DummyPass1", {}),
("DummyPass2", {"types": ["ReLU", "PReLU"], "values": [1,2,3]})]) ("DummyPass2", {"types": ["ReLU", "PReLU"], "values": [1, 2, 3]})])
def test_multiple_passes_no_args(self): def test_multiple_passes_no_args(self):
self.assertEqual(parse_transform("DummyPass,LowLatency22"), self.assertEqual(parse_transform("DummyPass,LowLatency22"),
@ -967,3 +974,297 @@ class TransformChecker(unittest.TestCase):
def test_check_dummy_pass_is_available(self, available_transformations): def test_check_dummy_pass_is_available(self, available_transformations):
available_transformations.return_value = {"LowLatency2": None} available_transformations.return_value = {"LowLatency2": None}
self.assertRaises(Error, check_available_transforms, [("DummyPass", "")]) self.assertRaises(Error, check_available_transforms, [("DummyPass", "")])
class TestLayoutParsing(unittest.TestCase):
def test_get_layout_1(self):
argv_layout = "name1([n,h,w,c]),name2([n,h,w,c]->[n,c,h,w])"
result = get_layout_values(argv_layout)
exp_res = {'name1': {'source_layout': '[n,h,w,c]', 'target_layout': None},
'name2': {'source_layout': '[n,h,w,c]', 'target_layout': '[n,c,h,w]'}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_2(self):
argv_layout = "name1(nhwc),name2(nhwc->nchw)"
result = get_layout_values(argv_layout)
exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': None},
'name2': {'source_layout': 'nhwc', 'target_layout': 'nchw'}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_3(self):
argv_layout = "name1(n...c),name2(n...c->nc...)"
result = get_layout_values(argv_layout)
exp_res = {'name1': {'source_layout': 'n...c', 'target_layout': None},
'name2': {'source_layout': 'n...c', 'target_layout': 'nc...'}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_4(self):
argv_layout = "nhwc"
result = get_layout_values(argv_layout)
exp_res = {'': {'source_layout': 'nhwc', 'target_layout': None}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_5(self):
argv_layout = "[n,h,w,c]"
result = get_layout_values(argv_layout)
exp_res = {'': {'source_layout': '[n,h,w,c]', 'target_layout': None}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_6(self):
argv_layout = "nhwc->nchw"
result = get_layout_values(argv_layout)
exp_res = {'': {'source_layout': 'nhwc', 'target_layout': 'nchw'}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_7(self):
argv_layout = "[n,h,w,c]->[n,c,h,w]"
result = get_layout_values(argv_layout)
exp_res = {'': {'source_layout': '[n,h,w,c]', 'target_layout': '[n,c,h,w]'}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_scalar(self):
argv_layout = "name1(nhwc),name2([])"
result = get_layout_values(argv_layout)
exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': None},
'name2': {'source_layout': '[]', 'target_layout': None}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_source_layout_1(self):
argv_source_layout = "[n,h,w,c]"
result = get_layout_values(argv_source_layout=argv_source_layout)
exp_res = {'': {'source_layout': '[n,h,w,c]', 'target_layout': None}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_source_layout_2(self):
argv_source_layout = "nhwc"
result = get_layout_values(argv_source_layout=argv_source_layout)
exp_res = {'': {'source_layout': 'nhwc', 'target_layout': None}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_source_layout_3(self):
argv_source_layout = "name1(nhwc),name2(nchw)"
result = get_layout_values(argv_source_layout=argv_source_layout)
exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': None},
'name2': {'source_layout': 'nchw', 'target_layout': None}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_source_layout_4(self):
argv_source_layout = "name1([n,h,w,c]),name2([n,c,h,w])"
result = get_layout_values(argv_source_layout=argv_source_layout)
exp_res = {'name1': {'source_layout': '[n,h,w,c]', 'target_layout': None},
'name2': {'source_layout': '[n,c,h,w]', 'target_layout': None}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_source_layout_5(self):
argv_source_layout = "name1(nhwc),name2([n,c,h,w])"
result = get_layout_values(argv_source_layout=argv_source_layout)
exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': None},
'name2': {'source_layout': '[n,c,h,w]', 'target_layout': None}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_source_layout_6(self):
argv_source_layout = "name1(nhwc),name2[n,c,h,w]"
result = get_layout_values(argv_source_layout=argv_source_layout)
exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': None},
'name2': {'source_layout': '[n,c,h,w]', 'target_layout': None}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_source_layout_scalar(self):
argv_source_layout = "name1(nhwc),name2([])"
result = get_layout_values(argv_source_layout=argv_source_layout)
exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': None},
'name2': {'source_layout': '[]', 'target_layout': None}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_target_layout_1(self):
argv_target_layout = "[n,h,w,c]"
result = get_layout_values(argv_target_layout=argv_target_layout)
exp_res = {'': {'source_layout': None, 'target_layout': '[n,h,w,c]'}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_target_layout_2(self):
argv_target_layout = "nhwc"
result = get_layout_values(argv_target_layout=argv_target_layout)
exp_res = {'': {'source_layout': None, 'target_layout': 'nhwc'}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_target_layout_3(self):
argv_target_layout = "name1(nhwc),name2(nchw)"
result = get_layout_values(argv_target_layout=argv_target_layout)
exp_res = {'name1': {'source_layout': None, 'target_layout': 'nhwc'},
'name2': {'source_layout': None, 'target_layout': 'nchw'}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_target_layout_4(self):
argv_target_layout = "name1([n,h,w,c]),name2([n,c,h,w])"
result = get_layout_values(argv_target_layout=argv_target_layout)
exp_res = {'name1': {'source_layout': None, 'target_layout': '[n,h,w,c]'},
'name2': {'source_layout': None, 'target_layout': '[n,c,h,w]'}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_target_layout_5(self):
argv_target_layout = "name1(nhwc),name2([n,c,h,w])"
result = get_layout_values(argv_target_layout=argv_target_layout)
exp_res = {'name1': {'source_layout': None, 'target_layout': 'nhwc'},
'name2': {'source_layout': None, 'target_layout': '[n,c,h,w]'}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_target_layout_6(self):
argv_target_layout = "name1(nhwc),name2[n,c,h,w]"
result = get_layout_values(argv_target_layout=argv_target_layout)
exp_res = {'name1': {'source_layout': None, 'target_layout': 'nhwc'},
'name2': {'source_layout': None, 'target_layout': '[n,c,h,w]'}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_target_layout_scalar(self):
argv_target_layout = "name1(nhwc),name2[]"
result = get_layout_values(argv_target_layout=argv_target_layout)
exp_res = {'name1': {'source_layout': None, 'target_layout': 'nhwc'},
'name2': {'source_layout': None, 'target_layout': '[]'}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_source_target_layout_1(self):
argv_source_layout = "[n,h,w,c]"
argv_target_layout = "[n,c,h,w]"
result = get_layout_values(argv_source_layout=argv_source_layout, argv_target_layout=argv_target_layout)
exp_res = {'': {'source_layout': '[n,h,w,c]', 'target_layout': '[n,c,h,w]'}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_source_target_layout_2(self):
argv_source_layout = "nhwc"
argv_target_layout = "nchw"
result = get_layout_values(argv_source_layout=argv_source_layout, argv_target_layout=argv_target_layout)
exp_res = {'': {'source_layout': 'nhwc', 'target_layout': 'nchw'}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_source_target_layout_3(self):
argv_source_layout = "name1(nhwc),name2(nhwc)"
argv_target_layout = "name1(nchw),name2(nchw)"
result = get_layout_values(argv_source_layout=argv_source_layout, argv_target_layout=argv_target_layout)
exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': 'nchw'},
'name2': {'source_layout': 'nhwc', 'target_layout': 'nchw'}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_source_target_layout_4(self):
argv_source_layout = "name1([n,h,w,c]),name2([n,h,w,c])"
argv_target_layout = "name1([n,c,h,w]),name2([n,c,h,w])"
result = get_layout_values(argv_source_layout=argv_source_layout, argv_target_layout=argv_target_layout)
exp_res = {'name1': {'source_layout': '[n,h,w,c]', 'target_layout': '[n,c,h,w]'},
'name2': {'source_layout': '[n,h,w,c]', 'target_layout': '[n,c,h,w]'}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_source_target_layout_5(self):
argv_source_layout = "name1(nhwc),name2[n,h,w,c]"
argv_target_layout = "name1(nchw),name2[n,c,h,w]"
result = get_layout_values(argv_source_layout=argv_source_layout, argv_target_layout=argv_target_layout)
exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': 'nchw'},
'name2': {'source_layout': '[n,h,w,c]', 'target_layout': '[n,c,h,w]'}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_source_target_layout_scalar(self):
argv_source_layout = "name1(nhwc),name2[]"
argv_target_layout = "name1(nchw),name2[]"
result = get_layout_values(argv_source_layout=argv_source_layout, argv_target_layout=argv_target_layout)
exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': 'nchw'},
'name2': {'source_layout': '[]', 'target_layout': '[]'}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_layout_raises_if_layout_and_source_layout_provided(self):
argv_layout = "nhwc"
argv_source_layout = "nhwc"
with self.assertRaises(Error):
get_layout_values(argv_layout=argv_layout, argv_source_layout=argv_source_layout)
def test_get_layout_raises_if_layout_and_target_layout_provided(self):
argv_layout = "nhwc->nchw"
argv_target_layout = "nchw"
with self.assertRaises(Error):
get_layout_values(argv_layout=argv_layout, argv_target_layout=argv_target_layout)
def test_get_layout_raises_if_layout_with_source_and_target_layout_provided(self):
argv_layout = "nhwc->nchw"
argv_source_layout = "nhwc"
argv_target_layout = "nchw"
with self.assertRaises(Error):
get_layout_values(argv_layout=argv_layout, argv_source_layout=argv_source_layout,
argv_target_layout=argv_target_layout)
def test_get_layout_raises_incorrect_format(self):
argv_layout = "name[n,h,w,c]->nchw"
with self.assertRaises(Error):
res = get_layout_values(argv_layout=argv_layout)
print(res)
def test_get_layout_raises_multiple_layouts_without_names(self):
argv_layout = "nhwc->nchw,nhwc->nchw"
with self.assertRaises(Error):
res = get_layout_values(argv_layout=argv_layout)
print(res)
def test_get_layout_raises_multiple_layouts_without_names_source_layout(self):
argv_source_layout = "nhwc,nhwc"
with self.assertRaises(Error):
res = get_layout_values(argv_source_layout=argv_source_layout)
print(res)
def test_get_layout_raises_multiple_layouts_without_names_target_layout(self):
argv_target_layout = "nchw,nchw"
with self.assertRaises(Error):
res = get_layout_values(argv_target_layout=argv_target_layout)
print(res)