Add layout commands in MO (#8829)
* Add layout support in MO * Apply review feedback
This commit is contained in:
parent
2ff4ef2e4f
commit
51947eeb3d
@ -37,10 +37,57 @@ def compress_model(func: object):
|
|||||||
from openvino.offline_transformations_pybind import compress_model_transformation # pylint: disable=import-error,no-name-in-module
|
from openvino.offline_transformations_pybind import compress_model_transformation # pylint: disable=import-error,no-name-in-module
|
||||||
compress_model_transformation(func)
|
compress_model_transformation(func)
|
||||||
|
|
||||||
def apply_offline_transformations(input_model: str, framework: str, transforms: list, compress_fp16=False):
|
|
||||||
|
def add_layouts(ov_function, argv: argparse.Namespace):
|
||||||
|
from openvino.preprocess import PrePostProcessor # pylint: disable=no-name-in-module,import-error
|
||||||
|
from openvino.runtime import Layout # pylint: disable=import-error,no-name-in-module
|
||||||
|
|
||||||
|
prep = PrePostProcessor(ov_function)
|
||||||
|
layout_values = argv.layout_values
|
||||||
|
if '' in layout_values:
|
||||||
|
if len(ov_function.inputs) == 1:
|
||||||
|
layout_values = {
|
||||||
|
list(ov_function.input().get_tensor().get_names())[0]: {
|
||||||
|
'source_layout': layout_values[''].get('source_layout'),
|
||||||
|
'target_layout': layout_values[''].get('target_layout')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
input_names = [list(ov_input.get_tensor().get_names())[0] for ov_input in ov_function.inputs]
|
||||||
|
raise Error('Layout without name can be specified for models with only one input, '
|
||||||
|
'but provided model has {} inputs: \'{}\'. '
|
||||||
|
'Please specify explicitly input/output name for --layout option'
|
||||||
|
.format(len(input_names), input_names))
|
||||||
|
|
||||||
|
set_layout_names = set(layout_values.keys())
|
||||||
|
for idx, ov_input in enumerate(ov_function.inputs):
|
||||||
|
found = set.intersection(set(ov_input.get_tensor().get_names()), set_layout_names)
|
||||||
|
assert len(found) <= 1, 'More then one name point to the same node'
|
||||||
|
if len(found) == 1:
|
||||||
|
node_name = list(found)[0]
|
||||||
|
found_layout = layout_values[node_name]
|
||||||
|
if found_layout['source_layout']:
|
||||||
|
prep.input(node_name).network().set_layout(Layout(found_layout['source_layout']))
|
||||||
|
if found_layout['target_layout']:
|
||||||
|
prep.input(node_name).tensor().set_layout(Layout(found_layout['target_layout']))
|
||||||
|
|
||||||
|
for idx, ov_output in enumerate(ov_function.outputs):
|
||||||
|
found = set.intersection(set(ov_output.get_tensor().get_names()), set_layout_names)
|
||||||
|
assert len(found) <= 1, 'More then one name point to the same node'
|
||||||
|
if len(found) == 1:
|
||||||
|
node_name = list(found)[0]
|
||||||
|
found_layout = layout_values[node_name]
|
||||||
|
if found_layout['source_layout']:
|
||||||
|
prep.output(node_name).network().set_layout(Layout(found_layout['source_layout']))
|
||||||
|
if found_layout['target_layout']:
|
||||||
|
prep.output(node_name).tensor().set_layout(Layout(found_layout['target_layout']))
|
||||||
|
prep.build()
|
||||||
|
|
||||||
|
|
||||||
|
def apply_offline_transformations(input_model: str, argv: argparse.Namespace):
|
||||||
# This variable is only needed by GenerateMappingFile transformation
|
# This variable is only needed by GenerateMappingFile transformation
|
||||||
# to produce correct mapping
|
# to produce correct mapping
|
||||||
extract_names = framework in ['tf', 'mxnet', 'kaldi']
|
extract_names = argv.framework in ['tf', 'mxnet', 'kaldi']
|
||||||
|
|
||||||
from openvino.offline_transformations_pybind import generate_mapping_file, serialize # pylint: disable=import-error,no-name-in-module
|
from openvino.offline_transformations_pybind import generate_mapping_file, serialize # pylint: disable=import-error,no-name-in-module
|
||||||
from openvino.frontend import FrontEndManager, FrontEnd # pylint: disable=no-name-in-module,import-error
|
from openvino.frontend import FrontEndManager, FrontEnd # pylint: disable=no-name-in-module,import-error
|
||||||
@ -57,24 +104,14 @@ def apply_offline_transformations(input_model: str, framework: str, transforms:
|
|||||||
|
|
||||||
func = read_model(input_model + "_tmp.xml")
|
func = read_model(input_model + "_tmp.xml")
|
||||||
|
|
||||||
apply_user_transformations(func, transforms)
|
add_layouts(func, argv) # TODO: replace with preprocessing
|
||||||
|
|
||||||
|
apply_user_transformations(func, parse_transform(argv.transform))
|
||||||
apply_moc_transformations(func)
|
apply_moc_transformations(func)
|
||||||
|
|
||||||
if compress_fp16:
|
if "compress_fp16" in argv and argv.compress_fp16:
|
||||||
compress_model(func)
|
compress_model(func)
|
||||||
|
|
||||||
serialize(func, str(input_model + ".xml").encode('utf-8'), (input_model + ".bin").encode('utf-8'))
|
serialize(func, str(input_model + ".xml").encode('utf-8'), (input_model + ".bin").encode('utf-8'))
|
||||||
path_to_mapping = input_model + ".mapping"
|
path_to_mapping = input_model + ".mapping"
|
||||||
generate_mapping_file(func, path_to_mapping.encode('utf-8'), extract_names)
|
generate_mapping_file(func, path_to_mapping.encode('utf-8'), extract_names)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument("--input_model")
|
|
||||||
parser.add_argument("--framework")
|
|
||||||
parser.add_argument("--transform")
|
|
||||||
parser.add_argument("--compress_fp16", action='store_true')
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
apply_offline_transformations(args.input_model, args.framework, parse_transform(args.transform), args.compress_fp16)
|
|
||||||
|
@ -6,7 +6,6 @@ import datetime
|
|||||||
import logging as log
|
import logging as log
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
import subprocess
|
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
@ -28,10 +27,10 @@ from openvino.tools.mo.middle.pattern_match import for_graph_and_each_sub_graph_
|
|||||||
from openvino.tools.mo.pipeline.common import prepare_emit_ir, get_ir_version
|
from openvino.tools.mo.pipeline.common import prepare_emit_ir, get_ir_version
|
||||||
from openvino.tools.mo.pipeline.unified import unified_pipeline
|
from openvino.tools.mo.pipeline.unified import unified_pipeline
|
||||||
from openvino.tools.mo.utils import import_extensions
|
from openvino.tools.mo.utils import import_extensions
|
||||||
from openvino.tools.mo.utils.cli_parser import get_placeholder_shapes, get_tuple_values, get_model_name, \
|
from openvino.tools.mo.utils.cli_parser import check_available_transforms, get_caffe_cli_options, \
|
||||||
get_common_cli_options, get_caffe_cli_options, get_tf_cli_options, get_mxnet_cli_options, get_kaldi_cli_options, \
|
get_common_cli_options, get_freeze_placeholder_values, get_kaldi_cli_options, get_layout_values, \
|
||||||
get_onnx_cli_options, get_mean_scale_dictionary, parse_tuple_pairs, get_freeze_placeholder_values, get_meta_info, \
|
get_mean_scale_dictionary, get_meta_info, get_model_name, get_mxnet_cli_options, get_onnx_cli_options, \
|
||||||
parse_transform, check_available_transforms
|
get_placeholder_shapes, get_tf_cli_options, get_tuple_values, parse_transform, parse_tuple_pairs
|
||||||
from openvino.tools.mo.utils.error import Error, FrameworkError
|
from openvino.tools.mo.utils.error import Error, FrameworkError
|
||||||
from openvino.tools.mo.utils.find_ie_version import find_ie_version
|
from openvino.tools.mo.utils.find_ie_version import find_ie_version
|
||||||
from openvino.tools.mo.utils.get_ov_update_message import get_ov_update_message
|
from openvino.tools.mo.utils.get_ov_update_message import get_ov_update_message
|
||||||
@ -268,6 +267,7 @@ def arguments_post_parsing(argv: argparse.Namespace):
|
|||||||
scale_values = parse_tuple_pairs(argv.scale_values)
|
scale_values = parse_tuple_pairs(argv.scale_values)
|
||||||
mean_scale = get_mean_scale_dictionary(mean_values, scale_values, argv.input)
|
mean_scale = get_mean_scale_dictionary(mean_values, scale_values, argv.input)
|
||||||
argv.mean_scale_values = mean_scale
|
argv.mean_scale_values = mean_scale
|
||||||
|
argv.layout_values = get_layout_values(argv.layout, argv.source_layout, argv.target_layout)
|
||||||
|
|
||||||
if not os.path.exists(argv.output_dir):
|
if not os.path.exists(argv.output_dir):
|
||||||
try:
|
try:
|
||||||
@ -360,22 +360,14 @@ def emit_ir(graph: Graph, argv: argparse.Namespace):
|
|||||||
orig_model_name = os.path.normpath(os.path.join(output_dir, argv.model_name))
|
orig_model_name = os.path.normpath(os.path.join(output_dir, argv.model_name))
|
||||||
|
|
||||||
return_code = "not executed"
|
return_code = "not executed"
|
||||||
# This try-except is additional reinsurance that the IE
|
|
||||||
# dependency search does not break the MO pipeline
|
|
||||||
try:
|
try:
|
||||||
if not argv.legacy_ir_generation:
|
if not argv.legacy_ir_generation:
|
||||||
path_to_offline_transformations = os.path.join(os.path.realpath(os.path.dirname(__file__)), 'back',
|
from openvino.tools.mo.back.offline_transformations import apply_offline_transformations
|
||||||
'offline_transformations.py')
|
apply_offline_transformations(orig_model_name, argv)
|
||||||
cmd = [sys.executable, path_to_offline_transformations,
|
|
||||||
"--input_model", orig_model_name,
|
|
||||||
"--framework", argv.framework,
|
|
||||||
"--transform", argv.transform]
|
|
||||||
if "compress_fp16" in argv and argv.compress_fp16:
|
if "compress_fp16" in argv and argv.compress_fp16:
|
||||||
cmd += ["--compress_fp16"]
|
|
||||||
# restore data_type cmd parameter
|
# restore data_type cmd parameter
|
||||||
argv.data_type = 'FP16'
|
argv.data_type = 'FP16'
|
||||||
status = subprocess.run(cmd, env=os.environ)
|
return_code = 0
|
||||||
return_code = status.returncode
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return_code = "failed"
|
return_code = "failed"
|
||||||
log.error(e)
|
log.error(e)
|
||||||
|
@ -309,6 +309,28 @@ def get_common_cli_parser(parser: argparse.ArgumentParser = None):
|
|||||||
'The exact meaning and order ' +
|
'The exact meaning and order ' +
|
||||||
'of channels depend on how the original model was trained.',
|
'of channels depend on how the original model was trained.',
|
||||||
default=())
|
default=())
|
||||||
|
common_group.add_argument('--source_layout',
|
||||||
|
help='Layout of the input or output of the model in the framework. Layout can'
|
||||||
|
' be specified in the short form, e.g. nhwc, or in complex form, e.g. [n,h,w,c].'
|
||||||
|
' Example for many names: '
|
||||||
|
'in_name1([n,h,w,c]),in_name2(nc),out_name1(n),out_name2(nc). Layout can be '
|
||||||
|
'partially defined, "?" can be used to specify undefined layout for one dimension, '
|
||||||
|
'"..." can be used to specify undefined layout for multiple dimensions, for example '
|
||||||
|
'?c??, nc..., n...c, etc.',
|
||||||
|
default=())
|
||||||
|
common_group.add_argument('--target_layout',
|
||||||
|
help='Same as --source_layout, but specifies target layout that will be in the model '
|
||||||
|
'after processing by ModelOptimizer.',
|
||||||
|
default=())
|
||||||
|
common_group.add_argument('--layout',
|
||||||
|
help='Combination of --source_layout and --target_layout. Can\'t be used with either of '
|
||||||
|
'them. If model has one input it is sufficient to specify layout of this input, for'
|
||||||
|
' example --layout nhwc. To specify layouts of many tensors, names must be provided,'
|
||||||
|
' for example: --layout name1(nchw),name2(nc). It is possible to instruct '
|
||||||
|
'ModelOptimizer to change layout, for example: '
|
||||||
|
'--layout name1(nhwc->nchw),name2(cn->nc). Also "*" in long layout form can be used'
|
||||||
|
' to fuse dimensions, for example [n,c,...]->[n*c,…].',
|
||||||
|
default=())
|
||||||
# TODO: isn't it a weights precision type
|
# TODO: isn't it a weights precision type
|
||||||
common_group.add_argument('--data_type',
|
common_group.add_argument('--data_type',
|
||||||
help='Data type for all intermediate tensors and weights. ' +
|
help='Data type for all intermediate tensors and weights. ' +
|
||||||
@ -417,6 +439,9 @@ def get_common_cli_options(model_name):
|
|||||||
d['input'] = ['- Input layers', lambda x: x if x else 'Not specified, inherited from the model']
|
d['input'] = ['- Input layers', lambda x: x if x else 'Not specified, inherited from the model']
|
||||||
d['output'] = ['- Output layers', lambda x: x if x else 'Not specified, inherited from the model']
|
d['output'] = ['- Output layers', lambda x: x if x else 'Not specified, inherited from the model']
|
||||||
d['input_shape'] = ['- Input shapes', lambda x: x if x else 'Not specified, inherited from the model']
|
d['input_shape'] = ['- Input shapes', lambda x: x if x else 'Not specified, inherited from the model']
|
||||||
|
d['source_layout'] = ['- Source layout', lambda x: x if x else 'Not specified']
|
||||||
|
d['target_layout'] = ['- Target layout', lambda x: x if x else 'Not specified']
|
||||||
|
d['layout'] = ['- Layout', lambda x: x if x else 'Not specified']
|
||||||
d['mean_values'] = ['- Mean values', lambda x: x if x else 'Not specified']
|
d['mean_values'] = ['- Mean values', lambda x: x if x else 'Not specified']
|
||||||
d['scale_values'] = ['- Scale values', lambda x: x if x else 'Not specified']
|
d['scale_values'] = ['- Scale values', lambda x: x if x else 'Not specified']
|
||||||
d['scale'] = ['- Scale factor', lambda x: x if x else 'Not specified']
|
d['scale'] = ['- Scale factor', lambda x: x if x else 'Not specified']
|
||||||
@ -835,6 +860,137 @@ def parse_input_value(input_value: str):
|
|||||||
return node_name, shape, value, data_type
|
return node_name, shape, value, data_type
|
||||||
|
|
||||||
|
|
||||||
|
def split_str_avoiding_square_brackets(s: str) -> list:
|
||||||
|
"""
|
||||||
|
Splits a string by comma, but skips commas inside square brackets.
|
||||||
|
:param s: string to split
|
||||||
|
:return: list of strings split by comma
|
||||||
|
"""
|
||||||
|
res = list()
|
||||||
|
skipping = 0
|
||||||
|
last_idx = 0
|
||||||
|
for i, c in enumerate(s):
|
||||||
|
if c == '[':
|
||||||
|
skipping += 1
|
||||||
|
elif c == ']':
|
||||||
|
skipping -= 1
|
||||||
|
elif c == ',' and skipping == 0:
|
||||||
|
res.append(s[last_idx:i])
|
||||||
|
last_idx = i + 1
|
||||||
|
res.append(s[last_idx:])
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def split_layouts_by_arrow(s: str) -> tuple:
|
||||||
|
"""
|
||||||
|
Splits a layout string by first arrow (->).
|
||||||
|
:param s: string to split
|
||||||
|
:return: tuple containing source and target layouts
|
||||||
|
"""
|
||||||
|
arrow = s.find('->')
|
||||||
|
if arrow != -1:
|
||||||
|
source_layout = s[:arrow]
|
||||||
|
target_layout = s[arrow + 2:]
|
||||||
|
if source_layout == '':
|
||||||
|
source_layout = None
|
||||||
|
if target_layout == '':
|
||||||
|
target_layout = None
|
||||||
|
return source_layout, target_layout
|
||||||
|
else:
|
||||||
|
return s, None
|
||||||
|
|
||||||
|
|
||||||
|
def validate_layout(layout: str):
|
||||||
|
"""
|
||||||
|
Checks if layout is of valid format.
|
||||||
|
:param layout: string containing layout
|
||||||
|
:raises: if layout is incorrect
|
||||||
|
"""
|
||||||
|
valid_layout_re = re.compile(r'\[?[^\[\]\(\)\s]*\]?')
|
||||||
|
if layout and not valid_layout_re.fullmatch(layout):
|
||||||
|
raise Error('Invalid layout parsed: {}'.format(layout))
|
||||||
|
|
||||||
|
|
||||||
|
def write_found_layout(name: str, found_layout: str, parsed: dict, dest: str = None):
|
||||||
|
"""
|
||||||
|
Writes found layout data to the 'parsed' dict.
|
||||||
|
:param name: name of the node to add layout
|
||||||
|
:param found_layout: string containing layout for the node
|
||||||
|
:param parsed: dict where result will be stored
|
||||||
|
:param dest: type of the command line:
|
||||||
|
* 'source' is --source_layout
|
||||||
|
* 'target' is --target_layout
|
||||||
|
* None is --layout
|
||||||
|
"""
|
||||||
|
s_layout = None
|
||||||
|
t_layout = None
|
||||||
|
if name in parsed:
|
||||||
|
s_layout = parsed[name]['source_layout']
|
||||||
|
t_layout = parsed[name]['target_layout']
|
||||||
|
if dest == 'source':
|
||||||
|
s_layout = found_layout
|
||||||
|
elif dest == 'target':
|
||||||
|
t_layout = found_layout
|
||||||
|
else:
|
||||||
|
s_layout, t_layout = split_layouts_by_arrow(found_layout)
|
||||||
|
validate_layout(s_layout)
|
||||||
|
validate_layout(t_layout)
|
||||||
|
parsed[name] = {'source_layout': s_layout, 'target_layout': t_layout}
|
||||||
|
|
||||||
|
|
||||||
|
def parse_layouts_by_destination(s: str, parsed: dict, dest: str = None) -> None:
|
||||||
|
"""
|
||||||
|
Parses layout command line to get all names and layouts from it. Adds all found data in the 'parsed' dict.
|
||||||
|
:param s: string to parse
|
||||||
|
:param parsed: dict where result will be stored
|
||||||
|
:param dest: type of the command line:
|
||||||
|
* 'source' is --source_layout
|
||||||
|
* 'target' is --target_layout
|
||||||
|
* None is --layout
|
||||||
|
"""
|
||||||
|
list_s = split_str_avoiding_square_brackets(s)
|
||||||
|
if len(list_s) == 1 and (list_s[0][-1] not in ')]' or (list_s[0][0] == '[' and list_s[0][-1] == ']')):
|
||||||
|
# single layout case
|
||||||
|
write_found_layout('', list_s[0], parsed, dest)
|
||||||
|
else:
|
||||||
|
for layout_str in list_s:
|
||||||
|
# case for: "name1(nhwc->[n,c,h,w])"
|
||||||
|
p1 = re.compile(r'(\S+)\((\S+)\)')
|
||||||
|
m1 = p1.match(layout_str)
|
||||||
|
# case for: "name1[n,h,w,c]->[n,c,h,w]"
|
||||||
|
p2 = re.compile(r'(\S+)(\[\S*\])')
|
||||||
|
m2 = p2.match(layout_str)
|
||||||
|
if m1:
|
||||||
|
found_g = m1.groups()
|
||||||
|
elif m2:
|
||||||
|
found_g = m2.groups()
|
||||||
|
else:
|
||||||
|
raise Error("More then one layout provided for --{}layout without providing name.".format(
|
||||||
|
dest + '_' if dest else ''))
|
||||||
|
write_found_layout(found_g[0], found_g[1], parsed, dest)
|
||||||
|
|
||||||
|
|
||||||
|
def get_layout_values(argv_layout: str = '', argv_source_layout: str = '', argv_target_layout: str = ''):
|
||||||
|
"""
|
||||||
|
Parses layout string.
|
||||||
|
:param argv_layout: string with a list of layouts passed as a --layout.
|
||||||
|
:param argv_source_layout: string with a list of layouts passed as a --source_layout.
|
||||||
|
:param argv_target_layout: string with a list of layouts passed as a --target_layout.
|
||||||
|
:return: dict with names and layouts associated
|
||||||
|
"""
|
||||||
|
if argv_layout and (argv_source_layout or argv_target_layout):
|
||||||
|
raise Error("--layout is used as well as --source_layout and/or --target_layout which is not allowed, please "
|
||||||
|
"use one of them.")
|
||||||
|
res = {}
|
||||||
|
if argv_layout:
|
||||||
|
parse_layouts_by_destination(argv_layout, res)
|
||||||
|
if argv_source_layout:
|
||||||
|
parse_layouts_by_destination(argv_source_layout, res, 'source')
|
||||||
|
if argv_target_layout:
|
||||||
|
parse_layouts_by_destination(argv_target_layout, res, 'target')
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
def get_freeze_placeholder_values(argv_input: str, argv_freeze_placeholder_with_value: str):
|
def get_freeze_placeholder_values(argv_input: str, argv_freeze_placeholder_with_value: str):
|
||||||
"""
|
"""
|
||||||
Parses values for placeholder freezing and input node names
|
Parses values for placeholder freezing and input node names
|
||||||
|
@ -56,6 +56,9 @@ def replaceArgsHelper(log_level='DEBUG',
|
|||||||
batch=None,
|
batch=None,
|
||||||
mean_values=None,
|
mean_values=None,
|
||||||
scale_values=None,
|
scale_values=None,
|
||||||
|
layout=None,
|
||||||
|
source_layout=None,
|
||||||
|
target_layout=None,
|
||||||
output_dir='.',
|
output_dir='.',
|
||||||
freeze_placeholder_with_value=None):
|
freeze_placeholder_with_value=None):
|
||||||
return argparse.Namespace(
|
return argparse.Namespace(
|
||||||
@ -72,6 +75,9 @@ def replaceArgsHelper(log_level='DEBUG',
|
|||||||
batch=batch,
|
batch=batch,
|
||||||
mean_values=mean_values,
|
mean_values=mean_values,
|
||||||
scale_values=scale_values,
|
scale_values=scale_values,
|
||||||
|
layout=layout,
|
||||||
|
source_layout=source_layout,
|
||||||
|
target_layout=target_layout,
|
||||||
output_dir=output_dir,
|
output_dir=output_dir,
|
||||||
freeze_placeholder_with_value=freeze_placeholder_with_value,
|
freeze_placeholder_with_value=freeze_placeholder_with_value,
|
||||||
use_legacy_frontend=None,
|
use_legacy_frontend=None,
|
||||||
|
@ -12,9 +12,10 @@ from unittest.mock import patch
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.testing as npt
|
import numpy.testing as npt
|
||||||
|
|
||||||
from openvino.tools.mo.utils.cli_parser import get_placeholder_shapes, get_tuple_values, get_mean_scale_dictionary, get_model_name, \
|
from openvino.tools.mo.utils.cli_parser import get_placeholder_shapes, get_tuple_values, get_mean_scale_dictionary, \
|
||||||
|
get_model_name, \
|
||||||
parse_tuple_pairs, check_positive, writable_dir, readable_dirs, \
|
parse_tuple_pairs, check_positive, writable_dir, readable_dirs, \
|
||||||
readable_file, get_freeze_placeholder_values, parse_transform, check_available_transforms
|
readable_file, get_freeze_placeholder_values, parse_transform, check_available_transforms, get_layout_values
|
||||||
from openvino.tools.mo.utils.error import Error
|
from openvino.tools.mo.utils.error import Error
|
||||||
|
|
||||||
|
|
||||||
@ -380,6 +381,7 @@ class TestingMeanScaleGetter(unittest.TestCase):
|
|||||||
def test_input_without_values(self):
|
def test_input_without_values(self):
|
||||||
self.assertRaises(Error, parse_tuple_pairs, "input1,input2")
|
self.assertRaises(Error, parse_tuple_pairs, "input1,input2")
|
||||||
|
|
||||||
|
|
||||||
class TestSingleTupleParsing(unittest.TestCase):
|
class TestSingleTupleParsing(unittest.TestCase):
|
||||||
def test_get_values_ideal(self):
|
def test_get_values_ideal(self):
|
||||||
values = "(1.11, 22.22, 333.333)"
|
values = "(1.11, 22.22, 333.333)"
|
||||||
@ -476,7 +478,8 @@ class TestShapesParsing(unittest.TestCase):
|
|||||||
for i in exp_res.keys():
|
for i in exp_res.keys():
|
||||||
npt.assert_array_equal(result[i], exp_res[i])
|
npt.assert_array_equal(result[i], exp_res[i])
|
||||||
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None)
|
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None)
|
||||||
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'])}
|
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']),
|
||||||
|
'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'])}
|
||||||
input_node_names_ref = "inp1,inp2,inp3"
|
input_node_names_ref = "inp1,inp2,inp3"
|
||||||
self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
|
self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
|
||||||
for i in placeholder_values_ref.keys():
|
for i in placeholder_values_ref.keys():
|
||||||
@ -492,7 +495,8 @@ class TestShapesParsing(unittest.TestCase):
|
|||||||
for i in exp_res.keys():
|
for i in exp_res.keys():
|
||||||
npt.assert_array_equal(result[i], exp_res[i])
|
npt.assert_array_equal(result[i], exp_res[i])
|
||||||
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None)
|
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None)
|
||||||
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'])}
|
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']),
|
||||||
|
'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'])}
|
||||||
input_node_names_ref = "inp1,inp2,inp3"
|
input_node_names_ref = "inp1,inp2,inp3"
|
||||||
self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
|
self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
|
||||||
for i in placeholder_values_ref.keys():
|
for i in placeholder_values_ref.keys():
|
||||||
@ -510,8 +514,10 @@ class TestShapesParsing(unittest.TestCase):
|
|||||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||||
for i in exp_res.keys():
|
for i in exp_res.keys():
|
||||||
npt.assert_array_equal(result[i], exp_res[i])
|
npt.assert_array_equal(result[i], exp_res[i])
|
||||||
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, argv_freeze_placeholder_with_value)
|
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input,
|
||||||
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'],),
|
argv_freeze_placeholder_with_value)
|
||||||
|
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']),
|
||||||
|
'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'], ),
|
||||||
'inp2': np.array(['5.0', '7.0', '3.0']), 'inp4': np.array(['100.0', '200.0'])}
|
'inp2': np.array(['5.0', '7.0', '3.0']), 'inp4': np.array(['100.0', '200.0'])}
|
||||||
input_node_names_ref = "inp1,inp2,inp3"
|
input_node_names_ref = "inp1,inp2,inp3"
|
||||||
self.assertEqual(sorted(list(placeholder_values_res.keys())), sorted(list(placeholder_values_ref.keys())))
|
self.assertEqual(sorted(list(placeholder_values_res.keys())), sorted(list(placeholder_values_ref.keys())))
|
||||||
@ -772,6 +778,7 @@ class TestShapesParsing(unittest.TestCase):
|
|||||||
input_shapes = "(12,4,1),(4,-6,8)"
|
input_shapes = "(12,4,1),(4,-6,8)"
|
||||||
self.assertRaises(Error, get_placeholder_shapes, argv_input, input_shapes)
|
self.assertRaises(Error, get_placeholder_shapes, argv_input, input_shapes)
|
||||||
|
|
||||||
|
|
||||||
class TestModelNameParsing(unittest.TestCase):
|
class TestModelNameParsing(unittest.TestCase):
|
||||||
def test_model_name_ideal(self):
|
def test_model_name_ideal(self):
|
||||||
model_name = '/home/models/mymodel.caffemodel'
|
model_name = '/home/models/mymodel.caffemodel'
|
||||||
@ -923,9 +930,9 @@ class TransformChecker(unittest.TestCase):
|
|||||||
def test_multiple_passes_with_args2(self):
|
def test_multiple_passes_with_args2(self):
|
||||||
self.assertEqual(parse_transform("LowLatency2[use_const_initializer=True,False],DummyPass1,"
|
self.assertEqual(parse_transform("LowLatency2[use_const_initializer=True,False],DummyPass1,"
|
||||||
"DummyPass2[types=ReLU,PReLU;values=1,2,3]"),
|
"DummyPass2[types=ReLU,PReLU;values=1,2,3]"),
|
||||||
[("LowLatency2", {"use_const_initializer": [True, False]}),
|
[("LowLatency2", {"use_const_initializer": [True, False]}),
|
||||||
("DummyPass1", {}),
|
("DummyPass1", {}),
|
||||||
("DummyPass2", {"types": ["ReLU", "PReLU"], "values": [1,2,3]})])
|
("DummyPass2", {"types": ["ReLU", "PReLU"], "values": [1, 2, 3]})])
|
||||||
|
|
||||||
def test_multiple_passes_no_args(self):
|
def test_multiple_passes_no_args(self):
|
||||||
self.assertEqual(parse_transform("DummyPass,LowLatency22"),
|
self.assertEqual(parse_transform("DummyPass,LowLatency22"),
|
||||||
@ -967,3 +974,297 @@ class TransformChecker(unittest.TestCase):
|
|||||||
def test_check_dummy_pass_is_available(self, available_transformations):
|
def test_check_dummy_pass_is_available(self, available_transformations):
|
||||||
available_transformations.return_value = {"LowLatency2": None}
|
available_transformations.return_value = {"LowLatency2": None}
|
||||||
self.assertRaises(Error, check_available_transforms, [("DummyPass", "")])
|
self.assertRaises(Error, check_available_transforms, [("DummyPass", "")])
|
||||||
|
|
||||||
|
|
||||||
|
class TestLayoutParsing(unittest.TestCase):
|
||||||
|
def test_get_layout_1(self):
|
||||||
|
argv_layout = "name1([n,h,w,c]),name2([n,h,w,c]->[n,c,h,w])"
|
||||||
|
result = get_layout_values(argv_layout)
|
||||||
|
exp_res = {'name1': {'source_layout': '[n,h,w,c]', 'target_layout': None},
|
||||||
|
'name2': {'source_layout': '[n,h,w,c]', 'target_layout': '[n,c,h,w]'}}
|
||||||
|
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||||
|
for i in exp_res.keys():
|
||||||
|
npt.assert_array_equal(result[i], exp_res[i])
|
||||||
|
|
||||||
|
def test_get_layout_2(self):
|
||||||
|
argv_layout = "name1(nhwc),name2(nhwc->nchw)"
|
||||||
|
result = get_layout_values(argv_layout)
|
||||||
|
exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': None},
|
||||||
|
'name2': {'source_layout': 'nhwc', 'target_layout': 'nchw'}}
|
||||||
|
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||||
|
for i in exp_res.keys():
|
||||||
|
npt.assert_array_equal(result[i], exp_res[i])
|
||||||
|
|
||||||
|
def test_get_layout_3(self):
|
||||||
|
argv_layout = "name1(n...c),name2(n...c->nc...)"
|
||||||
|
result = get_layout_values(argv_layout)
|
||||||
|
exp_res = {'name1': {'source_layout': 'n...c', 'target_layout': None},
|
||||||
|
'name2': {'source_layout': 'n...c', 'target_layout': 'nc...'}}
|
||||||
|
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||||
|
for i in exp_res.keys():
|
||||||
|
npt.assert_array_equal(result[i], exp_res[i])
|
||||||
|
|
||||||
|
def test_get_layout_4(self):
|
||||||
|
argv_layout = "nhwc"
|
||||||
|
result = get_layout_values(argv_layout)
|
||||||
|
exp_res = {'': {'source_layout': 'nhwc', 'target_layout': None}}
|
||||||
|
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||||
|
for i in exp_res.keys():
|
||||||
|
npt.assert_array_equal(result[i], exp_res[i])
|
||||||
|
|
||||||
|
def test_get_layout_5(self):
|
||||||
|
argv_layout = "[n,h,w,c]"
|
||||||
|
result = get_layout_values(argv_layout)
|
||||||
|
exp_res = {'': {'source_layout': '[n,h,w,c]', 'target_layout': None}}
|
||||||
|
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||||
|
for i in exp_res.keys():
|
||||||
|
npt.assert_array_equal(result[i], exp_res[i])
|
||||||
|
|
||||||
|
def test_get_layout_6(self):
|
||||||
|
argv_layout = "nhwc->nchw"
|
||||||
|
result = get_layout_values(argv_layout)
|
||||||
|
exp_res = {'': {'source_layout': 'nhwc', 'target_layout': 'nchw'}}
|
||||||
|
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||||
|
for i in exp_res.keys():
|
||||||
|
npt.assert_array_equal(result[i], exp_res[i])
|
||||||
|
|
||||||
|
def test_get_layout_7(self):
|
||||||
|
argv_layout = "[n,h,w,c]->[n,c,h,w]"
|
||||||
|
result = get_layout_values(argv_layout)
|
||||||
|
exp_res = {'': {'source_layout': '[n,h,w,c]', 'target_layout': '[n,c,h,w]'}}
|
||||||
|
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||||
|
for i in exp_res.keys():
|
||||||
|
npt.assert_array_equal(result[i], exp_res[i])
|
||||||
|
|
||||||
|
def test_get_layout_scalar(self):
|
||||||
|
argv_layout = "name1(nhwc),name2([])"
|
||||||
|
result = get_layout_values(argv_layout)
|
||||||
|
exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': None},
|
||||||
|
'name2': {'source_layout': '[]', 'target_layout': None}}
|
||||||
|
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||||
|
for i in exp_res.keys():
|
||||||
|
npt.assert_array_equal(result[i], exp_res[i])
|
||||||
|
|
||||||
|
def test_get_layout_source_layout_1(self):
|
||||||
|
argv_source_layout = "[n,h,w,c]"
|
||||||
|
result = get_layout_values(argv_source_layout=argv_source_layout)
|
||||||
|
exp_res = {'': {'source_layout': '[n,h,w,c]', 'target_layout': None}}
|
||||||
|
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||||
|
for i in exp_res.keys():
|
||||||
|
npt.assert_array_equal(result[i], exp_res[i])
|
||||||
|
|
||||||
|
def test_get_layout_source_layout_2(self):
|
||||||
|
argv_source_layout = "nhwc"
|
||||||
|
result = get_layout_values(argv_source_layout=argv_source_layout)
|
||||||
|
exp_res = {'': {'source_layout': 'nhwc', 'target_layout': None}}
|
||||||
|
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||||
|
for i in exp_res.keys():
|
||||||
|
npt.assert_array_equal(result[i], exp_res[i])
|
||||||
|
|
||||||
|
def test_get_layout_source_layout_3(self):
|
||||||
|
argv_source_layout = "name1(nhwc),name2(nchw)"
|
||||||
|
result = get_layout_values(argv_source_layout=argv_source_layout)
|
||||||
|
exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': None},
|
||||||
|
'name2': {'source_layout': 'nchw', 'target_layout': None}}
|
||||||
|
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||||
|
for i in exp_res.keys():
|
||||||
|
npt.assert_array_equal(result[i], exp_res[i])
|
||||||
|
|
||||||
|
def test_get_layout_source_layout_4(self):
|
||||||
|
argv_source_layout = "name1([n,h,w,c]),name2([n,c,h,w])"
|
||||||
|
result = get_layout_values(argv_source_layout=argv_source_layout)
|
||||||
|
exp_res = {'name1': {'source_layout': '[n,h,w,c]', 'target_layout': None},
|
||||||
|
'name2': {'source_layout': '[n,c,h,w]', 'target_layout': None}}
|
||||||
|
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||||
|
for i in exp_res.keys():
|
||||||
|
npt.assert_array_equal(result[i], exp_res[i])
|
||||||
|
|
||||||
|
def test_get_layout_source_layout_5(self):
|
||||||
|
argv_source_layout = "name1(nhwc),name2([n,c,h,w])"
|
||||||
|
result = get_layout_values(argv_source_layout=argv_source_layout)
|
||||||
|
exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': None},
|
||||||
|
'name2': {'source_layout': '[n,c,h,w]', 'target_layout': None}}
|
||||||
|
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||||
|
for i in exp_res.keys():
|
||||||
|
npt.assert_array_equal(result[i], exp_res[i])
|
||||||
|
|
||||||
|
def test_get_layout_source_layout_6(self):
|
||||||
|
argv_source_layout = "name1(nhwc),name2[n,c,h,w]"
|
||||||
|
result = get_layout_values(argv_source_layout=argv_source_layout)
|
||||||
|
exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': None},
|
||||||
|
'name2': {'source_layout': '[n,c,h,w]', 'target_layout': None}}
|
||||||
|
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||||
|
for i in exp_res.keys():
|
||||||
|
npt.assert_array_equal(result[i], exp_res[i])
|
||||||
|
|
||||||
|
def test_get_layout_source_layout_scalar(self):
|
||||||
|
argv_source_layout = "name1(nhwc),name2([])"
|
||||||
|
result = get_layout_values(argv_source_layout=argv_source_layout)
|
||||||
|
exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': None},
|
||||||
|
'name2': {'source_layout': '[]', 'target_layout': None}}
|
||||||
|
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||||
|
for i in exp_res.keys():
|
||||||
|
npt.assert_array_equal(result[i], exp_res[i])
|
||||||
|
|
||||||
|
def test_get_layout_target_layout_1(self):
|
||||||
|
argv_target_layout = "[n,h,w,c]"
|
||||||
|
result = get_layout_values(argv_target_layout=argv_target_layout)
|
||||||
|
exp_res = {'': {'source_layout': None, 'target_layout': '[n,h,w,c]'}}
|
||||||
|
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||||
|
for i in exp_res.keys():
|
||||||
|
npt.assert_array_equal(result[i], exp_res[i])
|
||||||
|
|
||||||
|
def test_get_layout_target_layout_2(self):
|
||||||
|
argv_target_layout = "nhwc"
|
||||||
|
result = get_layout_values(argv_target_layout=argv_target_layout)
|
||||||
|
exp_res = {'': {'source_layout': None, 'target_layout': 'nhwc'}}
|
||||||
|
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||||
|
for i in exp_res.keys():
|
||||||
|
npt.assert_array_equal(result[i], exp_res[i])
|
||||||
|
|
||||||
|
def test_get_layout_target_layout_3(self):
|
||||||
|
argv_target_layout = "name1(nhwc),name2(nchw)"
|
||||||
|
result = get_layout_values(argv_target_layout=argv_target_layout)
|
||||||
|
exp_res = {'name1': {'source_layout': None, 'target_layout': 'nhwc'},
|
||||||
|
'name2': {'source_layout': None, 'target_layout': 'nchw'}}
|
||||||
|
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||||
|
for i in exp_res.keys():
|
||||||
|
npt.assert_array_equal(result[i], exp_res[i])
|
||||||
|
|
||||||
|
def test_get_layout_target_layout_4(self):
|
||||||
|
argv_target_layout = "name1([n,h,w,c]),name2([n,c,h,w])"
|
||||||
|
result = get_layout_values(argv_target_layout=argv_target_layout)
|
||||||
|
exp_res = {'name1': {'source_layout': None, 'target_layout': '[n,h,w,c]'},
|
||||||
|
'name2': {'source_layout': None, 'target_layout': '[n,c,h,w]'}}
|
||||||
|
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||||
|
for i in exp_res.keys():
|
||||||
|
npt.assert_array_equal(result[i], exp_res[i])
|
||||||
|
|
||||||
|
def test_get_layout_target_layout_5(self):
|
||||||
|
argv_target_layout = "name1(nhwc),name2([n,c,h,w])"
|
||||||
|
result = get_layout_values(argv_target_layout=argv_target_layout)
|
||||||
|
exp_res = {'name1': {'source_layout': None, 'target_layout': 'nhwc'},
|
||||||
|
'name2': {'source_layout': None, 'target_layout': '[n,c,h,w]'}}
|
||||||
|
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||||
|
for i in exp_res.keys():
|
||||||
|
npt.assert_array_equal(result[i], exp_res[i])
|
||||||
|
|
||||||
|
def test_get_layout_target_layout_6(self):
|
||||||
|
argv_target_layout = "name1(nhwc),name2[n,c,h,w]"
|
||||||
|
result = get_layout_values(argv_target_layout=argv_target_layout)
|
||||||
|
exp_res = {'name1': {'source_layout': None, 'target_layout': 'nhwc'},
|
||||||
|
'name2': {'source_layout': None, 'target_layout': '[n,c,h,w]'}}
|
||||||
|
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||||
|
for i in exp_res.keys():
|
||||||
|
npt.assert_array_equal(result[i], exp_res[i])
|
||||||
|
|
||||||
|
def test_get_layout_target_layout_scalar(self):
|
||||||
|
argv_target_layout = "name1(nhwc),name2[]"
|
||||||
|
result = get_layout_values(argv_target_layout=argv_target_layout)
|
||||||
|
exp_res = {'name1': {'source_layout': None, 'target_layout': 'nhwc'},
|
||||||
|
'name2': {'source_layout': None, 'target_layout': '[]'}}
|
||||||
|
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||||
|
for i in exp_res.keys():
|
||||||
|
npt.assert_array_equal(result[i], exp_res[i])
|
||||||
|
|
||||||
|
def test_get_layout_source_target_layout_1(self):
|
||||||
|
argv_source_layout = "[n,h,w,c]"
|
||||||
|
argv_target_layout = "[n,c,h,w]"
|
||||||
|
result = get_layout_values(argv_source_layout=argv_source_layout, argv_target_layout=argv_target_layout)
|
||||||
|
exp_res = {'': {'source_layout': '[n,h,w,c]', 'target_layout': '[n,c,h,w]'}}
|
||||||
|
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||||
|
for i in exp_res.keys():
|
||||||
|
npt.assert_array_equal(result[i], exp_res[i])
|
||||||
|
|
||||||
|
def test_get_layout_source_target_layout_2(self):
|
||||||
|
argv_source_layout = "nhwc"
|
||||||
|
argv_target_layout = "nchw"
|
||||||
|
result = get_layout_values(argv_source_layout=argv_source_layout, argv_target_layout=argv_target_layout)
|
||||||
|
exp_res = {'': {'source_layout': 'nhwc', 'target_layout': 'nchw'}}
|
||||||
|
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||||
|
for i in exp_res.keys():
|
||||||
|
npt.assert_array_equal(result[i], exp_res[i])
|
||||||
|
|
||||||
|
def test_get_layout_source_target_layout_3(self):
|
||||||
|
argv_source_layout = "name1(nhwc),name2(nhwc)"
|
||||||
|
argv_target_layout = "name1(nchw),name2(nchw)"
|
||||||
|
result = get_layout_values(argv_source_layout=argv_source_layout, argv_target_layout=argv_target_layout)
|
||||||
|
exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': 'nchw'},
|
||||||
|
'name2': {'source_layout': 'nhwc', 'target_layout': 'nchw'}}
|
||||||
|
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||||
|
for i in exp_res.keys():
|
||||||
|
npt.assert_array_equal(result[i], exp_res[i])
|
||||||
|
|
||||||
|
def test_get_layout_source_target_layout_4(self):
|
||||||
|
argv_source_layout = "name1([n,h,w,c]),name2([n,h,w,c])"
|
||||||
|
argv_target_layout = "name1([n,c,h,w]),name2([n,c,h,w])"
|
||||||
|
result = get_layout_values(argv_source_layout=argv_source_layout, argv_target_layout=argv_target_layout)
|
||||||
|
exp_res = {'name1': {'source_layout': '[n,h,w,c]', 'target_layout': '[n,c,h,w]'},
|
||||||
|
'name2': {'source_layout': '[n,h,w,c]', 'target_layout': '[n,c,h,w]'}}
|
||||||
|
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||||
|
for i in exp_res.keys():
|
||||||
|
npt.assert_array_equal(result[i], exp_res[i])
|
||||||
|
|
||||||
|
def test_get_layout_source_target_layout_5(self):
|
||||||
|
argv_source_layout = "name1(nhwc),name2[n,h,w,c]"
|
||||||
|
argv_target_layout = "name1(nchw),name2[n,c,h,w]"
|
||||||
|
result = get_layout_values(argv_source_layout=argv_source_layout, argv_target_layout=argv_target_layout)
|
||||||
|
exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': 'nchw'},
|
||||||
|
'name2': {'source_layout': '[n,h,w,c]', 'target_layout': '[n,c,h,w]'}}
|
||||||
|
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||||
|
for i in exp_res.keys():
|
||||||
|
npt.assert_array_equal(result[i], exp_res[i])
|
||||||
|
|
||||||
|
def test_get_layout_source_target_layout_scalar(self):
|
||||||
|
argv_source_layout = "name1(nhwc),name2[]"
|
||||||
|
argv_target_layout = "name1(nchw),name2[]"
|
||||||
|
result = get_layout_values(argv_source_layout=argv_source_layout, argv_target_layout=argv_target_layout)
|
||||||
|
exp_res = {'name1': {'source_layout': 'nhwc', 'target_layout': 'nchw'},
|
||||||
|
'name2': {'source_layout': '[]', 'target_layout': '[]'}}
|
||||||
|
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||||
|
for i in exp_res.keys():
|
||||||
|
npt.assert_array_equal(result[i], exp_res[i])
|
||||||
|
|
||||||
|
def test_get_layout_raises_if_layout_and_source_layout_provided(self):
|
||||||
|
argv_layout = "nhwc"
|
||||||
|
argv_source_layout = "nhwc"
|
||||||
|
with self.assertRaises(Error):
|
||||||
|
get_layout_values(argv_layout=argv_layout, argv_source_layout=argv_source_layout)
|
||||||
|
|
||||||
|
def test_get_layout_raises_if_layout_and_target_layout_provided(self):
|
||||||
|
argv_layout = "nhwc->nchw"
|
||||||
|
argv_target_layout = "nchw"
|
||||||
|
with self.assertRaises(Error):
|
||||||
|
get_layout_values(argv_layout=argv_layout, argv_target_layout=argv_target_layout)
|
||||||
|
|
||||||
|
def test_get_layout_raises_if_layout_with_source_and_target_layout_provided(self):
|
||||||
|
argv_layout = "nhwc->nchw"
|
||||||
|
argv_source_layout = "nhwc"
|
||||||
|
argv_target_layout = "nchw"
|
||||||
|
with self.assertRaises(Error):
|
||||||
|
get_layout_values(argv_layout=argv_layout, argv_source_layout=argv_source_layout,
|
||||||
|
argv_target_layout=argv_target_layout)
|
||||||
|
|
||||||
|
def test_get_layout_raises_incorrect_format(self):
|
||||||
|
argv_layout = "name[n,h,w,c]->nchw"
|
||||||
|
with self.assertRaises(Error):
|
||||||
|
res = get_layout_values(argv_layout=argv_layout)
|
||||||
|
print(res)
|
||||||
|
|
||||||
|
def test_get_layout_raises_multiple_layouts_without_names(self):
|
||||||
|
argv_layout = "nhwc->nchw,nhwc->nchw"
|
||||||
|
with self.assertRaises(Error):
|
||||||
|
res = get_layout_values(argv_layout=argv_layout)
|
||||||
|
print(res)
|
||||||
|
|
||||||
|
def test_get_layout_raises_multiple_layouts_without_names_source_layout(self):
|
||||||
|
argv_source_layout = "nhwc,nhwc"
|
||||||
|
with self.assertRaises(Error):
|
||||||
|
res = get_layout_values(argv_source_layout=argv_source_layout)
|
||||||
|
print(res)
|
||||||
|
|
||||||
|
def test_get_layout_raises_multiple_layouts_without_names_target_layout(self):
|
||||||
|
argv_target_layout = "nchw,nchw"
|
||||||
|
with self.assertRaises(Error):
|
||||||
|
res = get_layout_values(argv_target_layout=argv_target_layout)
|
||||||
|
print(res)
|
||||||
|
Loading…
Reference in New Issue
Block a user