diff --git a/tools/mo/openvino/tools/mo/back/offline_transformations.py b/tools/mo/openvino/tools/mo/back/offline_transformations.py index 1cdb4ef0455..e3796a62d75 100644 --- a/tools/mo/openvino/tools/mo/back/offline_transformations.py +++ b/tools/mo/openvino/tools/mo/back/offline_transformations.py @@ -37,10 +37,57 @@ def compress_model(func: object): from openvino.offline_transformations_pybind import compress_model_transformation # pylint: disable=import-error,no-name-in-module compress_model_transformation(func) -def apply_offline_transformations(input_model: str, framework: str, transforms: list, compress_fp16=False): + +def add_layouts(ov_function, argv: argparse.Namespace): + from openvino.preprocess import PrePostProcessor # pylint: disable=no-name-in-module,import-error + from openvino.runtime import Layout # pylint: disable=import-error,no-name-in-module + + prep = PrePostProcessor(ov_function) + layout_values = argv.layout_values + if '' in layout_values: + if len(ov_function.inputs) == 1: + layout_values = { + list(ov_function.input().get_tensor().get_names())[0]: { + 'source_layout': layout_values[''].get('source_layout'), + 'target_layout': layout_values[''].get('target_layout') + } + } + else: + input_names = [list(ov_input.get_tensor().get_names())[0] for ov_input in ov_function.inputs] + raise Error('Layout without name can be specified for models with only one input, ' + 'but provided model has {} inputs: \'{}\'. ' + 'Please specify explicitly input/output name for --layout option' + .format(len(input_names), input_names)) + + set_layout_names = set(layout_values.keys()) + for idx, ov_input in enumerate(ov_function.inputs): + found = set.intersection(set(ov_input.get_tensor().get_names()), set_layout_names) + assert len(found) <= 1, 'More then one name point to the same node' + if len(found) == 1: + node_name = list(found)[0] + found_layout = layout_values[node_name] + if found_layout['source_layout']: + prep.input(node_name).network().set_layout(Layout(found_layout['source_layout'])) + if found_layout['target_layout']: + prep.input(node_name).tensor().set_layout(Layout(found_layout['target_layout'])) + + for idx, ov_output in enumerate(ov_function.outputs): + found = set.intersection(set(ov_output.get_tensor().get_names()), set_layout_names) + assert len(found) <= 1, 'More then one name point to the same node' + if len(found) == 1: + node_name = list(found)[0] + found_layout = layout_values[node_name] + if found_layout['source_layout']: + prep.output(node_name).network().set_layout(Layout(found_layout['source_layout'])) + if found_layout['target_layout']: + prep.output(node_name).tensor().set_layout(Layout(found_layout['target_layout'])) + prep.build() + + +def apply_offline_transformations(input_model: str, argv: argparse.Namespace): # This variable is only needed by GenerateMappingFile transformation # to produce correct mapping - extract_names = framework in ['tf', 'mxnet', 'kaldi'] + extract_names = argv.framework in ['tf', 'mxnet', 'kaldi'] from openvino.offline_transformations_pybind import generate_mapping_file, serialize # pylint: disable=import-error,no-name-in-module from openvino.frontend import FrontEndManager, FrontEnd # pylint: disable=no-name-in-module,import-error @@ -57,24 +104,14 @@ def apply_offline_transformations(input_model: str, framework: str, transforms: func = read_model(input_model + "_tmp.xml") - apply_user_transformations(func, transforms) + add_layouts(func, argv) # TODO: replace with preprocessing + + apply_user_transformations(func, parse_transform(argv.transform)) apply_moc_transformations(func) - if compress_fp16: + if "compress_fp16" in argv and argv.compress_fp16: compress_model(func) serialize(func, str(input_model + ".xml").encode('utf-8'), (input_model + ".bin").encode('utf-8')) path_to_mapping = input_model + ".mapping" generate_mapping_file(func, path_to_mapping.encode('utf-8'), extract_names) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--input_model") - parser.add_argument("--framework") - parser.add_argument("--transform") - parser.add_argument("--compress_fp16", action='store_true') - - args = parser.parse_args() - - apply_offline_transformations(args.input_model, args.framework, parse_transform(args.transform), args.compress_fp16) diff --git a/tools/mo/openvino/tools/mo/main.py b/tools/mo/openvino/tools/mo/main.py index dd57be48ba6..132c04d6c88 100644 --- a/tools/mo/openvino/tools/mo/main.py +++ b/tools/mo/openvino/tools/mo/main.py @@ -6,7 +6,6 @@ import datetime import logging as log import os import platform -import subprocess import sys import traceback from collections import OrderedDict @@ -28,10 +27,10 @@ from openvino.tools.mo.middle.pattern_match import for_graph_and_each_sub_graph_ from openvino.tools.mo.pipeline.common import prepare_emit_ir, get_ir_version from openvino.tools.mo.pipeline.unified import unified_pipeline from openvino.tools.mo.utils import import_extensions -from openvino.tools.mo.utils.cli_parser import get_placeholder_shapes, get_tuple_values, get_model_name, \ - get_common_cli_options, get_caffe_cli_options, get_tf_cli_options, get_mxnet_cli_options, get_kaldi_cli_options, \ - get_onnx_cli_options, get_mean_scale_dictionary, parse_tuple_pairs, get_freeze_placeholder_values, get_meta_info, \ - parse_transform, check_available_transforms +from openvino.tools.mo.utils.cli_parser import check_available_transforms, get_caffe_cli_options, \ + get_common_cli_options, get_freeze_placeholder_values, get_kaldi_cli_options, get_layout_values, \ + get_mean_scale_dictionary, get_meta_info, get_model_name, get_mxnet_cli_options, get_onnx_cli_options, \ + get_placeholder_shapes, get_tf_cli_options, get_tuple_values, parse_transform, parse_tuple_pairs from openvino.tools.mo.utils.error import Error, FrameworkError from openvino.tools.mo.utils.find_ie_version import find_ie_version from openvino.tools.mo.utils.get_ov_update_message import get_ov_update_message @@ -268,6 +267,7 @@ def arguments_post_parsing(argv: argparse.Namespace): scale_values = parse_tuple_pairs(argv.scale_values) mean_scale = get_mean_scale_dictionary(mean_values, scale_values, argv.input) argv.mean_scale_values = mean_scale + argv.layout_values = get_layout_values(argv.layout, argv.source_layout, argv.target_layout) if not os.path.exists(argv.output_dir): try: @@ -360,22 +360,14 @@ def emit_ir(graph: Graph, argv: argparse.Namespace): orig_model_name = os.path.normpath(os.path.join(output_dir, argv.model_name)) return_code = "not executed" - # This try-except is additional reinsurance that the IE - # dependency search does not break the MO pipeline try: if not argv.legacy_ir_generation: - path_to_offline_transformations = os.path.join(os.path.realpath(os.path.dirname(__file__)), 'back', - 'offline_transformations.py') - cmd = [sys.executable, path_to_offline_transformations, - "--input_model", orig_model_name, - "--framework", argv.framework, - "--transform", argv.transform] + from openvino.tools.mo.back.offline_transformations import apply_offline_transformations + apply_offline_transformations(orig_model_name, argv) if "compress_fp16" in argv and argv.compress_fp16: - cmd += ["--compress_fp16"] # restore data_type cmd parameter argv.data_type = 'FP16' - status = subprocess.run(cmd, env=os.environ) - return_code = status.returncode + return_code = 0 except Exception as e: return_code = "failed" log.error(e) diff --git a/tools/mo/openvino/tools/mo/utils/cli_parser.py b/tools/mo/openvino/tools/mo/utils/cli_parser.py index 7ad9a8b4edf..5e79c75ef9c 100644 --- a/tools/mo/openvino/tools/mo/utils/cli_parser.py +++ b/tools/mo/openvino/tools/mo/utils/cli_parser.py @@ -309,6 +309,28 @@ def get_common_cli_parser(parser: argparse.ArgumentParser = None): 'The exact meaning and order ' + 'of channels depend on how the original model was trained.', default=()) + common_group.add_argument('--source_layout', + help='Layout of the input or output of the model in the framework. Layout can' + ' be specified in the short form, e.g. nhwc, or in complex form, e.g. [n,h,w,c].' + ' Example for many names: ' + 'in_name1([n,h,w,c]),in_name2(nc),out_name1(n),out_name2(nc). Layout can be ' + 'partially defined, "?" can be used to specify undefined layout for one dimension, ' + '"..." can be used to specify undefined layout for multiple dimensions, for example ' + '?c??, nc..., n...c, etc.', + default=()) + common_group.add_argument('--target_layout', + help='Same as --source_layout, but specifies target layout that will be in the model ' + 'after processing by ModelOptimizer.', + default=()) + common_group.add_argument('--layout', + help='Combination of --source_layout and --target_layout. Can\'t be used with either of ' + 'them. If model has one input it is sufficient to specify layout of this input, for' + ' example --layout nhwc. To specify layouts of many tensors, names must be provided,' + ' for example: --layout name1(nchw),name2(nc). It is possible to instruct ' + 'ModelOptimizer to change layout, for example: ' + '--layout name1(nhwc->nchw),name2(cn->nc). Also "*" in long layout form can be used' + ' to fuse dimensions, for example [n,c,...]->[n*c,…].', + default=()) # TODO: isn't it a weights precision type common_group.add_argument('--data_type', help='Data type for all intermediate tensors and weights. ' + @@ -417,6 +439,9 @@ def get_common_cli_options(model_name): d['input'] = ['- Input layers', lambda x: x if x else 'Not specified, inherited from the model'] d['output'] = ['- Output layers', lambda x: x if x else 'Not specified, inherited from the model'] d['input_shape'] = ['- Input shapes', lambda x: x if x else 'Not specified, inherited from the model'] + d['source_layout'] = ['- Source layout', lambda x: x if x else 'Not specified'] + d['target_layout'] = ['- Target layout', lambda x: x if x else 'Not specified'] + d['layout'] = ['- Layout', lambda x: x if x else 'Not specified'] d['mean_values'] = ['- Mean values', lambda x: x if x else 'Not specified'] d['scale_values'] = ['- Scale values', lambda x: x if x else 'Not specified'] d['scale'] = ['- Scale factor', lambda x: x if x else 'Not specified'] @@ -835,6 +860,137 @@ def parse_input_value(input_value: str): return node_name, shape, value, data_type +def split_str_avoiding_square_brackets(s: str) -> list: + """ + Splits a string by comma, but skips commas inside square brackets. + :param s: string to split + :return: list of strings split by comma + """ + res = list() + skipping = 0 + last_idx = 0 + for i, c in enumerate(s): + if c == '[': + skipping += 1 + elif c == ']': + skipping -= 1 + elif c == ',' and skipping == 0: + res.append(s[last_idx:i]) + last_idx = i + 1 + res.append(s[last_idx:]) + return res + + +def split_layouts_by_arrow(s: str) -> tuple: + """ + Splits a layout string by first arrow (->). + :param s: string to split + :return: tuple containing source and target layouts + """ + arrow = s.find('->') + if arrow != -1: + source_layout = s[:arrow] + target_layout = s[arrow + 2:] + if source_layout == '': + source_layout = None + if target_layout == '': + target_layout = None + return source_layout, target_layout + else: + return s, None + + +def validate_layout(layout: str): + """ + Checks if layout is of valid format. + :param layout: string containing layout + :raises: if layout is incorrect + """ + valid_layout_re = re.compile(r'\[?[^\[\]\(\)\s]*\]?') + if layout and not valid_layout_re.fullmatch(layout): + raise Error('Invalid layout parsed: {}'.format(layout)) + + +def write_found_layout(name: str, found_layout: str, parsed: dict, dest: str = None): + """ + Writes found layout data to the 'parsed' dict. + :param name: name of the node to add layout + :param found_layout: string containing layout for the node + :param parsed: dict where result will be stored + :param dest: type of the command line: + * 'source' is --source_layout + * 'target' is --target_layout + * None is --layout + """ + s_layout = None + t_layout = None + if name in parsed: + s_layout = parsed[name]['source_layout'] + t_layout = parsed[name]['target_layout'] + if dest == 'source': + s_layout = found_layout + elif dest == 'target': + t_layout = found_layout + else: + s_layout, t_layout = split_layouts_by_arrow(found_layout) + validate_layout(s_layout) + validate_layout(t_layout) + parsed[name] = {'source_layout': s_layout, 'target_layout': t_layout} + + +def parse_layouts_by_destination(s: str, parsed: dict, dest: str = None) -> None: + """ + Parses layout command line to get all names and layouts from it. Adds all found data in the 'parsed' dict. + :param s: string to parse + :param parsed: dict where result will be stored + :param dest: type of the command line: + * 'source' is --source_layout + * 'target' is --target_layout + * None is --layout + """ + list_s = split_str_avoiding_square_brackets(s) + if len(list_s) == 1 and (list_s[0][-1] not in ')]' or (list_s[0][0] == '[' and list_s[0][-1] == ']')): + # single layout case + write_found_layout('', list_s[0], parsed, dest) + else: + for layout_str in list_s: + # case for: "name1(nhwc->[n,c,h,w])" + p1 = re.compile(r'(\S+)\((\S+)\)') + m1 = p1.match(layout_str) + # case for: "name1[n,h,w,c]->[n,c,h,w]" + p2 = re.compile(r'(\S+)(\[\S*\])') + m2 = p2.match(layout_str) + if m1: + found_g = m1.groups() + elif m2: + found_g = m2.groups() + else: + raise Error("More then one layout provided for --{}layout without providing name.".format( + dest + '_' if dest else '')) + write_found_layout(found_g[0], found_g[1], parsed, dest) + + +def get_layout_values(argv_layout: str = '', argv_source_layout: str = '', argv_target_layout: str = ''): + """ + Parses layout string. + :param argv_layout: string with a list of layouts passed as a --layout. + :param argv_source_layout: string with a list of layouts passed as a --source_layout. + :param argv_target_layout: string with a list of layouts passed as a --target_layout. + :return: dict with names and layouts associated + """ + if argv_layout and (argv_source_layout or argv_target_layout): + raise Error("--layout is used as well as --source_layout and/or --target_layout which is not allowed, please " + "use one of them.") + res = {} + if argv_layout: + parse_layouts_by_destination(argv_layout, res) + if argv_source_layout: + parse_layouts_by_destination(argv_source_layout, res, 'source') + if argv_target_layout: + parse_layouts_by_destination(argv_target_layout, res, 'target') + return res + + def get_freeze_placeholder_values(argv_input: str, argv_freeze_placeholder_with_value: str): """ Parses values for placeholder freezing and input node names diff --git a/tools/mo/unit_tests/mo/frontend_ngraph_test_actual.py b/tools/mo/unit_tests/mo/frontend_ngraph_test_actual.py index eb3ce64e9c3..ecb6bc6a4ea 100644 --- a/tools/mo/unit_tests/mo/frontend_ngraph_test_actual.py +++ b/tools/mo/unit_tests/mo/frontend_ngraph_test_actual.py @@ -56,6 +56,9 @@ def replaceArgsHelper(log_level='DEBUG', batch=None, mean_values=None, scale_values=None, + layout=None, + source_layout=None, + target_layout=None, output_dir='.', freeze_placeholder_with_value=None): return argparse.Namespace( @@ -72,6 +75,9 @@ def replaceArgsHelper(log_level='DEBUG', batch=batch, mean_values=mean_values, scale_values=scale_values, + layout=layout, + source_layout=source_layout, + target_layout=target_layout, output_dir=output_dir, freeze_placeholder_with_value=freeze_placeholder_with_value, use_legacy_frontend=None, diff --git a/tools/mo/unit_tests/mo/utils/cli_parser_test.py b/tools/mo/unit_tests/mo/utils/cli_parser_test.py index d26f82121e9..eb6a917e2e2 100644 --- a/tools/mo/unit_tests/mo/utils/cli_parser_test.py +++ b/tools/mo/unit_tests/mo/utils/cli_parser_test.py @@ -12,9 +12,10 @@ from unittest.mock import patch import numpy as np import numpy.testing as npt -from openvino.tools.mo.utils.cli_parser import get_placeholder_shapes, get_tuple_values, get_mean_scale_dictionary, get_model_name, \ +from openvino.tools.mo.utils.cli_parser import get_placeholder_shapes, get_tuple_values, get_mean_scale_dictionary, \ + get_model_name, \ parse_tuple_pairs, check_positive, writable_dir, readable_dirs, \ - readable_file, get_freeze_placeholder_values, parse_transform, check_available_transforms + readable_file, get_freeze_placeholder_values, parse_transform, check_available_transforms, get_layout_values from openvino.tools.mo.utils.error import Error @@ -380,6 +381,7 @@ class TestingMeanScaleGetter(unittest.TestCase): def test_input_without_values(self): self.assertRaises(Error, parse_tuple_pairs, "input1,input2") + class TestSingleTupleParsing(unittest.TestCase): def test_get_values_ideal(self): values = "(1.11, 22.22, 333.333)" @@ -476,7 +478,8 @@ class TestShapesParsing(unittest.TestCase): for i in exp_res.keys(): npt.assert_array_equal(result[i], exp_res[i]) placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None) - placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'])} + placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), + 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'])} input_node_names_ref = "inp1,inp2,inp3" self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys())) for i in placeholder_values_ref.keys(): @@ -492,7 +495,8 @@ class TestShapesParsing(unittest.TestCase): for i in exp_res.keys(): npt.assert_array_equal(result[i], exp_res[i]) placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None) - placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'])} + placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), + 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'])} input_node_names_ref = "inp1,inp2,inp3" self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys())) for i in placeholder_values_ref.keys(): @@ -510,8 +514,10 @@ class TestShapesParsing(unittest.TestCase): self.assertEqual(list(exp_res.keys()), list(result.keys())) for i in exp_res.keys(): npt.assert_array_equal(result[i], exp_res[i]) - placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, argv_freeze_placeholder_with_value) - placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'],), + placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, + argv_freeze_placeholder_with_value) + placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), + 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'], ), 'inp2': np.array(['5.0', '7.0', '3.0']), 'inp4': np.array(['100.0', '200.0'])} input_node_names_ref = "inp1,inp2,inp3" self.assertEqual(sorted(list(placeholder_values_res.keys())), sorted(list(placeholder_values_ref.keys()))) @@ -772,6 +778,7 @@ class TestShapesParsing(unittest.TestCase): input_shapes = "(12,4,1),(4,-6,8)" self.assertRaises(Error, get_placeholder_shapes, argv_input, input_shapes) + class TestModelNameParsing(unittest.TestCase): def test_model_name_ideal(self): model_name = '/home/models/mymodel.caffemodel' @@ -923,9 +930,9 @@ class TransformChecker(unittest.TestCase): def test_multiple_passes_with_args2(self): self.assertEqual(parse_transform("LowLatency2[use_const_initializer=True,False],DummyPass1," "DummyPass2[types=ReLU,PReLU;values=1,2,3]"), - [("LowLatency2", {"use_const_initializer": [True, False]}), - ("DummyPass1", {}), - ("DummyPass2", {"types": ["ReLU", "PReLU"], "values": [1,2,3]})]) + [("LowLatency2", {"use_const_initializer": [True, False]}), + ("DummyPass1", {}), + ("DummyPass2", {"types": ["ReLU", "PReLU"], "values": [1, 2, 3]})]) def test_multiple_passes_no_args(self): self.assertEqual(parse_transform("DummyPass,LowLatency22"), @@ -967,3 +974,297 @@ class TransformChecker(unittest.TestCase): def test_check_dummy_pass_is_available(self, available_transformations): available_transformations.return_value = {"LowLatency2": None} self.assertRaises(Error, check_available_transforms, [("DummyPass", "")]) + + +class TestLayoutParsing(unittest.TestCase): + def test_get_layout_1(self): + argv_layout = "name1([n,h,w,c]),name2([n,h,w,c]->[n,c,h,w])" + result = get_layout_values(argv_layout) + exp_res = {'name1': {'source_layout': '[n,h,w,c]', 'target_layout': None}, + 'name2': {'source_layout': '[n,h,w,c]', 'target_layout': '[n,c,h,w]'}} + self.assertEqual(list(exp_res.keys()), list(result.keys())) + for i in exp_res.keys(): + npt.assert_array_equal(result[i], exp_res[i]) + + def test_get_layout_2(self): + argv_layout = "name1(nhwc),name2(nhwc->nchw)" + result = get_layout_values(argv_layout) + exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': None}, + 'name2': {'source_layout': 'nhwc', 'target_layout': 'nchw'}} + self.assertEqual(list(exp_res.keys()), list(result.keys())) + for i in exp_res.keys(): + npt.assert_array_equal(result[i], exp_res[i]) + + def test_get_layout_3(self): + argv_layout = "name1(n...c),name2(n...c->nc...)" + result = get_layout_values(argv_layout) + exp_res = {'name1': {'source_layout': 'n...c', 'target_layout': None}, + 'name2': {'source_layout': 'n...c', 'target_layout': 'nc...'}} + self.assertEqual(list(exp_res.keys()), list(result.keys())) + for i in exp_res.keys(): + npt.assert_array_equal(result[i], exp_res[i]) + + def test_get_layout_4(self): + argv_layout = "nhwc" + result = get_layout_values(argv_layout) + exp_res = {'': {'source_layout': 'nhwc', 'target_layout': None}} + self.assertEqual(list(exp_res.keys()), list(result.keys())) + for i in exp_res.keys(): + npt.assert_array_equal(result[i], exp_res[i]) + + def test_get_layout_5(self): + argv_layout = "[n,h,w,c]" + result = get_layout_values(argv_layout) + exp_res = {'': {'source_layout': '[n,h,w,c]', 'target_layout': None}} + self.assertEqual(list(exp_res.keys()), list(result.keys())) + for i in exp_res.keys(): + npt.assert_array_equal(result[i], exp_res[i]) + + def test_get_layout_6(self): + argv_layout = "nhwc->nchw" + result = get_layout_values(argv_layout) + exp_res = {'': {'source_layout': 'nhwc', 'target_layout': 'nchw'}} + self.assertEqual(list(exp_res.keys()), list(result.keys())) + for i in exp_res.keys(): + npt.assert_array_equal(result[i], exp_res[i]) + + def test_get_layout_7(self): + argv_layout = "[n,h,w,c]->[n,c,h,w]" + result = get_layout_values(argv_layout) + exp_res = {'': {'source_layout': '[n,h,w,c]', 'target_layout': '[n,c,h,w]'}} + self.assertEqual(list(exp_res.keys()), list(result.keys())) + for i in exp_res.keys(): + npt.assert_array_equal(result[i], exp_res[i]) + + def test_get_layout_scalar(self): + argv_layout = "name1(nhwc),name2([])" + result = get_layout_values(argv_layout) + exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': None}, + 'name2': {'source_layout': '[]', 'target_layout': None}} + self.assertEqual(list(exp_res.keys()), list(result.keys())) + for i in exp_res.keys(): + npt.assert_array_equal(result[i], exp_res[i]) + + def test_get_layout_source_layout_1(self): + argv_source_layout = "[n,h,w,c]" + result = get_layout_values(argv_source_layout=argv_source_layout) + exp_res = {'': {'source_layout': '[n,h,w,c]', 'target_layout': None}} + self.assertEqual(list(exp_res.keys()), list(result.keys())) + for i in exp_res.keys(): + npt.assert_array_equal(result[i], exp_res[i]) + + def test_get_layout_source_layout_2(self): + argv_source_layout = "nhwc" + result = get_layout_values(argv_source_layout=argv_source_layout) + exp_res = {'': {'source_layout': 'nhwc', 'target_layout': None}} + self.assertEqual(list(exp_res.keys()), list(result.keys())) + for i in exp_res.keys(): + npt.assert_array_equal(result[i], exp_res[i]) + + def test_get_layout_source_layout_3(self): + argv_source_layout = "name1(nhwc),name2(nchw)" + result = get_layout_values(argv_source_layout=argv_source_layout) + exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': None}, + 'name2': {'source_layout': 'nchw', 'target_layout': None}} + self.assertEqual(list(exp_res.keys()), list(result.keys())) + for i in exp_res.keys(): + npt.assert_array_equal(result[i], exp_res[i]) + + def test_get_layout_source_layout_4(self): + argv_source_layout = "name1([n,h,w,c]),name2([n,c,h,w])" + result = get_layout_values(argv_source_layout=argv_source_layout) + exp_res = {'name1': {'source_layout': '[n,h,w,c]', 'target_layout': None}, + 'name2': {'source_layout': '[n,c,h,w]', 'target_layout': None}} + self.assertEqual(list(exp_res.keys()), list(result.keys())) + for i in exp_res.keys(): + npt.assert_array_equal(result[i], exp_res[i]) + + def test_get_layout_source_layout_5(self): + argv_source_layout = "name1(nhwc),name2([n,c,h,w])" + result = get_layout_values(argv_source_layout=argv_source_layout) + exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': None}, + 'name2': {'source_layout': '[n,c,h,w]', 'target_layout': None}} + self.assertEqual(list(exp_res.keys()), list(result.keys())) + for i in exp_res.keys(): + npt.assert_array_equal(result[i], exp_res[i]) + + def test_get_layout_source_layout_6(self): + argv_source_layout = "name1(nhwc),name2[n,c,h,w]" + result = get_layout_values(argv_source_layout=argv_source_layout) + exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': None}, + 'name2': {'source_layout': '[n,c,h,w]', 'target_layout': None}} + self.assertEqual(list(exp_res.keys()), list(result.keys())) + for i in exp_res.keys(): + npt.assert_array_equal(result[i], exp_res[i]) + + def test_get_layout_source_layout_scalar(self): + argv_source_layout = "name1(nhwc),name2([])" + result = get_layout_values(argv_source_layout=argv_source_layout) + exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': None}, + 'name2': {'source_layout': '[]', 'target_layout': None}} + self.assertEqual(list(exp_res.keys()), list(result.keys())) + for i in exp_res.keys(): + npt.assert_array_equal(result[i], exp_res[i]) + + def test_get_layout_target_layout_1(self): + argv_target_layout = "[n,h,w,c]" + result = get_layout_values(argv_target_layout=argv_target_layout) + exp_res = {'': {'source_layout': None, 'target_layout': '[n,h,w,c]'}} + self.assertEqual(list(exp_res.keys()), list(result.keys())) + for i in exp_res.keys(): + npt.assert_array_equal(result[i], exp_res[i]) + + def test_get_layout_target_layout_2(self): + argv_target_layout = "nhwc" + result = get_layout_values(argv_target_layout=argv_target_layout) + exp_res = {'': {'source_layout': None, 'target_layout': 'nhwc'}} + self.assertEqual(list(exp_res.keys()), list(result.keys())) + for i in exp_res.keys(): + npt.assert_array_equal(result[i], exp_res[i]) + + def test_get_layout_target_layout_3(self): + argv_target_layout = "name1(nhwc),name2(nchw)" + result = get_layout_values(argv_target_layout=argv_target_layout) + exp_res = {'name1': {'source_layout': None, 'target_layout': 'nhwc'}, + 'name2': {'source_layout': None, 'target_layout': 'nchw'}} + self.assertEqual(list(exp_res.keys()), list(result.keys())) + for i in exp_res.keys(): + npt.assert_array_equal(result[i], exp_res[i]) + + def test_get_layout_target_layout_4(self): + argv_target_layout = "name1([n,h,w,c]),name2([n,c,h,w])" + result = get_layout_values(argv_target_layout=argv_target_layout) + exp_res = {'name1': {'source_layout': None, 'target_layout': '[n,h,w,c]'}, + 'name2': {'source_layout': None, 'target_layout': '[n,c,h,w]'}} + self.assertEqual(list(exp_res.keys()), list(result.keys())) + for i in exp_res.keys(): + npt.assert_array_equal(result[i], exp_res[i]) + + def test_get_layout_target_layout_5(self): + argv_target_layout = "name1(nhwc),name2([n,c,h,w])" + result = get_layout_values(argv_target_layout=argv_target_layout) + exp_res = {'name1': {'source_layout': None, 'target_layout': 'nhwc'}, + 'name2': {'source_layout': None, 'target_layout': '[n,c,h,w]'}} + self.assertEqual(list(exp_res.keys()), list(result.keys())) + for i in exp_res.keys(): + npt.assert_array_equal(result[i], exp_res[i]) + + def test_get_layout_target_layout_6(self): + argv_target_layout = "name1(nhwc),name2[n,c,h,w]" + result = get_layout_values(argv_target_layout=argv_target_layout) + exp_res = {'name1': {'source_layout': None, 'target_layout': 'nhwc'}, + 'name2': {'source_layout': None, 'target_layout': '[n,c,h,w]'}} + self.assertEqual(list(exp_res.keys()), list(result.keys())) + for i in exp_res.keys(): + npt.assert_array_equal(result[i], exp_res[i]) + + def test_get_layout_target_layout_scalar(self): + argv_target_layout = "name1(nhwc),name2[]" + result = get_layout_values(argv_target_layout=argv_target_layout) + exp_res = {'name1': {'source_layout': None, 'target_layout': 'nhwc'}, + 'name2': {'source_layout': None, 'target_layout': '[]'}} + self.assertEqual(list(exp_res.keys()), list(result.keys())) + for i in exp_res.keys(): + npt.assert_array_equal(result[i], exp_res[i]) + + def test_get_layout_source_target_layout_1(self): + argv_source_layout = "[n,h,w,c]" + argv_target_layout = "[n,c,h,w]" + result = get_layout_values(argv_source_layout=argv_source_layout, argv_target_layout=argv_target_layout) + exp_res = {'': {'source_layout': '[n,h,w,c]', 'target_layout': '[n,c,h,w]'}} + self.assertEqual(list(exp_res.keys()), list(result.keys())) + for i in exp_res.keys(): + npt.assert_array_equal(result[i], exp_res[i]) + + def test_get_layout_source_target_layout_2(self): + argv_source_layout = "nhwc" + argv_target_layout = "nchw" + result = get_layout_values(argv_source_layout=argv_source_layout, argv_target_layout=argv_target_layout) + exp_res = {'': {'source_layout': 'nhwc', 'target_layout': 'nchw'}} + self.assertEqual(list(exp_res.keys()), list(result.keys())) + for i in exp_res.keys(): + npt.assert_array_equal(result[i], exp_res[i]) + + def test_get_layout_source_target_layout_3(self): + argv_source_layout = "name1(nhwc),name2(nhwc)" + argv_target_layout = "name1(nchw),name2(nchw)" + result = get_layout_values(argv_source_layout=argv_source_layout, argv_target_layout=argv_target_layout) + exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': 'nchw'}, + 'name2': {'source_layout': 'nhwc', 'target_layout': 'nchw'}} + self.assertEqual(list(exp_res.keys()), list(result.keys())) + for i in exp_res.keys(): + npt.assert_array_equal(result[i], exp_res[i]) + + def test_get_layout_source_target_layout_4(self): + argv_source_layout = "name1([n,h,w,c]),name2([n,h,w,c])" + argv_target_layout = "name1([n,c,h,w]),name2([n,c,h,w])" + result = get_layout_values(argv_source_layout=argv_source_layout, argv_target_layout=argv_target_layout) + exp_res = {'name1': {'source_layout': '[n,h,w,c]', 'target_layout': '[n,c,h,w]'}, + 'name2': {'source_layout': '[n,h,w,c]', 'target_layout': '[n,c,h,w]'}} + self.assertEqual(list(exp_res.keys()), list(result.keys())) + for i in exp_res.keys(): + npt.assert_array_equal(result[i], exp_res[i]) + + def test_get_layout_source_target_layout_5(self): + argv_source_layout = "name1(nhwc),name2[n,h,w,c]" + argv_target_layout = "name1(nchw),name2[n,c,h,w]" + result = get_layout_values(argv_source_layout=argv_source_layout, argv_target_layout=argv_target_layout) + exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': 'nchw'}, + 'name2': {'source_layout': '[n,h,w,c]', 'target_layout': '[n,c,h,w]'}} + self.assertEqual(list(exp_res.keys()), list(result.keys())) + for i in exp_res.keys(): + npt.assert_array_equal(result[i], exp_res[i]) + + def test_get_layout_source_target_layout_scalar(self): + argv_source_layout = "name1(nhwc),name2[]" + argv_target_layout = "name1(nchw),name2[]" + result = get_layout_values(argv_source_layout=argv_source_layout, argv_target_layout=argv_target_layout) + exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': 'nchw'}, + 'name2': {'source_layout': '[]', 'target_layout': '[]'}} + self.assertEqual(list(exp_res.keys()), list(result.keys())) + for i in exp_res.keys(): + npt.assert_array_equal(result[i], exp_res[i]) + + def test_get_layout_raises_if_layout_and_source_layout_provided(self): + argv_layout = "nhwc" + argv_source_layout = "nhwc" + with self.assertRaises(Error): + get_layout_values(argv_layout=argv_layout, argv_source_layout=argv_source_layout) + + def test_get_layout_raises_if_layout_and_target_layout_provided(self): + argv_layout = "nhwc->nchw" + argv_target_layout = "nchw" + with self.assertRaises(Error): + get_layout_values(argv_layout=argv_layout, argv_target_layout=argv_target_layout) + + def test_get_layout_raises_if_layout_with_source_and_target_layout_provided(self): + argv_layout = "nhwc->nchw" + argv_source_layout = "nhwc" + argv_target_layout = "nchw" + with self.assertRaises(Error): + get_layout_values(argv_layout=argv_layout, argv_source_layout=argv_source_layout, + argv_target_layout=argv_target_layout) + + def test_get_layout_raises_incorrect_format(self): + argv_layout = "name[n,h,w,c]->nchw" + with self.assertRaises(Error): + res = get_layout_values(argv_layout=argv_layout) + print(res) + + def test_get_layout_raises_multiple_layouts_without_names(self): + argv_layout = "nhwc->nchw,nhwc->nchw" + with self.assertRaises(Error): + res = get_layout_values(argv_layout=argv_layout) + print(res) + + def test_get_layout_raises_multiple_layouts_without_names_source_layout(self): + argv_source_layout = "nhwc,nhwc" + with self.assertRaises(Error): + res = get_layout_values(argv_source_layout=argv_source_layout) + print(res) + + def test_get_layout_raises_multiple_layouts_without_names_target_layout(self): + argv_target_layout = "nchw,nchw" + with self.assertRaises(Error): + res = get_layout_values(argv_target_layout=argv_target_layout) + print(res)