Fixed parsing of 'layout' param (#16999)

* Fixed layout parsing.

* Small correction.

* Removed wrong change.
This commit is contained in:
Anastasiia Pnevskaia
2023-04-18 20:43:38 +02:00
committed by GitHub
parent e93c8e1b1c
commit 078f28911b
2 changed files with 21 additions and 2 deletions

View File

@@ -1354,10 +1354,10 @@ def parse_layouts_by_destination(s: str, parsed: dict, parsed_list: list, dest:
else:
for idx, layout_str in enumerate(list_s):
# case for: "name1(nhwc->[n,c,h,w])"
p1 = re.compile(r'([\w.:/\\]*)\((\S+)\)')
p1 = re.compile(r'([^\[\]\(\)]*)\((\S+)\)')
m1 = p1.match(layout_str)
# case for: "name1[n,h,w,c]->[n,c,h,w]"
p2 = re.compile(r'([\w.:/\\]*)(\[\S*\])')
p2 = re.compile(r'([^\[\]\(\)]*)(\[\S*\])')
m2 = p2.match(layout_str)
if m1:
found_g = m1.groups()

View File

@@ -1386,6 +1386,15 @@ class TestLayoutParsing(unittest.TestCase):
for i in exp_res.keys():
assert np.array_equal(result[i], exp_res[i])
def test_get_layout_8(self):
argv_layout = "name1-0(n...c),name2-0(n...c->nc...)"
result = get_layout_values(argv_layout)
exp_res = {'name1-0': {'source_layout': 'n...c', 'target_layout': None},
'name2-0': {'source_layout': 'n...c', 'target_layout': 'nc...'}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
assert np.array_equal(result[i], exp_res[i])
def test_get_layout_scalar(self):
argv_layout = "name1(nhwc),name2([])"
result = get_layout_values(argv_layout)
@@ -1575,6 +1584,16 @@ class TestLayoutParsing(unittest.TestCase):
for i in exp_res.keys():
assert np.array_equal(result[i], exp_res[i])
def test_get_layout_source_target_layout_7(self):
argv_source_layout = "name1-0[n,h,w,c],name2-1(?c??)"
argv_target_layout = "name1-0(nchw),name2-1[?,?,?,c]"
result = get_layout_values(argv_source_layout=argv_source_layout, argv_target_layout=argv_target_layout)
exp_res = {'name1-0': {'source_layout': '[n,h,w,c]', 'target_layout': 'nchw'},
'name2-1': {'source_layout': '?c??', 'target_layout': '[?,?,?,c]'}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
assert np.array_equal(result[i], exp_res[i])
def test_get_layout_source_target_layout_scalar(self):
argv_source_layout = "name1(nhwc),name2[]"
argv_target_layout = "name1(nchw),name2[]"