Fixed layout parsing. (#14157)
* Fixed layout parsing. * Small fix. * Test fixed.
This commit is contained in:
committed by
GitHub
parent
580c0c6b90
commit
83e41d5d88
@@ -1679,10 +1679,10 @@ def parse_layouts_by_destination(s: str, parsed: dict, parsed_list: list, dest:
|
||||
else:
|
||||
for idx, layout_str in enumerate(list_s):
|
||||
# case for: "name1(nhwc->[n,c,h,w])"
|
||||
p1 = re.compile(r'(\w*)\((\S+)\)')
|
||||
p1 = re.compile(r'([\w.:/\\]*)\((\S+)\)')
|
||||
m1 = p1.match(layout_str)
|
||||
# case for: "name1[n,h,w,c]->[n,c,h,w]"
|
||||
p2 = re.compile(r'(\w*)(\[\S*\])')
|
||||
p2 = re.compile(r'([\w.:/\\]*)(\[\S*\])')
|
||||
m2 = p2.match(layout_str)
|
||||
if m1:
|
||||
found_g = m1.groups()
|
||||
|
||||
@@ -1561,6 +1561,16 @@ class TestLayoutParsing(unittest.TestCase):
|
||||
for i in exp_res.keys():
|
||||
assert np.array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_get_layout_source_target_layout_6(self):
|
||||
argv_source_layout = "name1.0:a/b(nhwc),name2\\d\\[n,h,w,c]"
|
||||
argv_target_layout = "name1.0:a/b(nchw),name2\\d\\[n,c,h,w]"
|
||||
result = get_layout_values(argv_source_layout=argv_source_layout, argv_target_layout=argv_target_layout)
|
||||
exp_res = {'name1.0:a/b': {'source_layout': 'nhwc', 'target_layout': 'nchw'},
|
||||
'name2\\d\\': {'source_layout': '[n,h,w,c]', 'target_layout': '[n,c,h,w]'}}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
assert np.array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_get_layout_source_target_layout_scalar(self):
|
||||
argv_source_layout = "name1(nhwc),name2[]"
|
||||
argv_target_layout = "name1(nchw),name2[]"
|
||||
|
||||
Reference in New Issue
Block a user