Fixed parsing of 'layout' param (#16999)
* Fixed layout parsing. * Small correction. * Removed wrong change.
This commit is contained in:
committed by
GitHub
parent
e93c8e1b1c
commit
078f28911b
@@ -1354,10 +1354,10 @@ def parse_layouts_by_destination(s: str, parsed: dict, parsed_list: list, dest:
|
||||
else:
|
||||
for idx, layout_str in enumerate(list_s):
|
||||
# case for: "name1(nhwc->[n,c,h,w])"
|
||||
p1 = re.compile(r'([\w.:/\\]*)\((\S+)\)')
|
||||
p1 = re.compile(r'([^\[\]\(\)]*)\((\S+)\)')
|
||||
m1 = p1.match(layout_str)
|
||||
# case for: "name1[n,h,w,c]->[n,c,h,w]"
|
||||
p2 = re.compile(r'([\w.:/\\]*)(\[\S*\])')
|
||||
p2 = re.compile(r'([^\[\]\(\)]*)(\[\S*\])')
|
||||
m2 = p2.match(layout_str)
|
||||
if m1:
|
||||
found_g = m1.groups()
|
||||
|
||||
@@ -1386,6 +1386,15 @@ class TestLayoutParsing(unittest.TestCase):
|
||||
for i in exp_res.keys():
|
||||
assert np.array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_get_layout_8(self):
|
||||
argv_layout = "name1-0(n...c),name2-0(n...c->nc...)"
|
||||
result = get_layout_values(argv_layout)
|
||||
exp_res = {'name1-0': {'source_layout': 'n...c', 'target_layout': None},
|
||||
'name2-0': {'source_layout': 'n...c', 'target_layout': 'nc...'}}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
assert np.array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_get_layout_scalar(self):
|
||||
argv_layout = "name1(nhwc),name2([])"
|
||||
result = get_layout_values(argv_layout)
|
||||
@@ -1575,6 +1584,16 @@ class TestLayoutParsing(unittest.TestCase):
|
||||
for i in exp_res.keys():
|
||||
assert np.array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_get_layout_source_target_layout_7(self):
|
||||
argv_source_layout = "name1-0[n,h,w,c],name2-1(?c??)"
|
||||
argv_target_layout = "name1-0(nchw),name2-1[?,?,?,c]"
|
||||
result = get_layout_values(argv_source_layout=argv_source_layout, argv_target_layout=argv_target_layout)
|
||||
exp_res = {'name1-0': {'source_layout': '[n,h,w,c]', 'target_layout': 'nchw'},
|
||||
'name2-1': {'source_layout': '?c??', 'target_layout': '[?,?,?,c]'}}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
assert np.array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_get_layout_source_target_layout_scalar(self):
|
||||
argv_source_layout = "name1(nhwc),name2[]"
|
||||
argv_target_layout = "name1(nchw),name2[]"
|
||||
|
||||
Reference in New Issue
Block a user