Fixed layout parsing. (#14157)

* Fixed layout parsing.

* Small fix.

* Test fixed.
This commit is contained in:
Anastasiia Pnevskaia
2022-11-22 19:06:22 +01:00
committed by GitHub
parent 580c0c6b90
commit 83e41d5d88
2 changed files with 12 additions and 2 deletions

View File

@@ -1679,10 +1679,10 @@ def parse_layouts_by_destination(s: str, parsed: dict, parsed_list: list, dest:
else:
for idx, layout_str in enumerate(list_s):
# case for: "name1(nhwc->[n,c,h,w])"
p1 = re.compile(r'(\w*)\((\S+)\)')
p1 = re.compile(r'([\w.:/\\]*)\((\S+)\)')
m1 = p1.match(layout_str)
# case for: "name1[n,h,w,c]->[n,c,h,w]"
p2 = re.compile(r'(\w*)(\[\S*\])')
p2 = re.compile(r'([\w.:/\\]*)(\[\S*\])')
m2 = p2.match(layout_str)
if m1:
found_g = m1.groups()

View File

@@ -1561,6 +1561,16 @@ class TestLayoutParsing(unittest.TestCase):
for i in exp_res.keys():
assert np.array_equal(result[i], exp_res[i])
def test_get_layout_source_target_layout_6(self):
argv_source_layout = "name1.0:a/b(nhwc),name2\\d\\[n,h,w,c]"
argv_target_layout = "name1.0:a/b(nchw),name2\\d\\[n,c,h,w]"
result = get_layout_values(argv_source_layout=argv_source_layout, argv_target_layout=argv_target_layout)
exp_res = {'name1.0:a/b': {'source_layout': 'nhwc', 'target_layout': 'nchw'},
'name2\\d\\': {'source_layout': '[n,h,w,c]', 'target_layout': '[n,c,h,w]'}}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
assert np.array_equal(result[i], exp_res[i])
def test_get_layout_source_target_layout_scalar(self):
argv_source_layout = "name1(nhwc),name2[]"
argv_target_layout = "name1(nchw),name2[]"