[MO] add support for scalar shapes into cli_parser.py (#18312)

* add support for scalar shapes into cli_parser.py

* add test-case with scalar shapes for convert_model

* reordered inputs in test-case with scalar shapes for convert_model

* minor clarifications

---------

Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
Pavel Esir 2023-07-08 19:27:07 +02:00 committed by GitHub
parent af9a8cbbd7
commit 1cb4595727
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 50 additions and 4 deletions

View File

@ -160,6 +160,18 @@ def create_pytorch_nn_module_case2(tmp_dir):
'example_input': sample_input}
def create_pytorch_nn_module_with_scalar_input(tmp_dir):
pt_model = make_pt_model_two_inputs()
ref_model = make_ref_pt_model_two_inputs([[], [-1, 3, -1, -1]])
sample_input1 = torch.tensor(0.66)
sample_input2 = torch.zeros(1, 3, 10, 10)
sample_input = sample_input1, sample_input2
return pt_model, ref_model, {'input_shape': ["[]", PartialShape([-1, 3, -1, -1])],
'example_input': sample_input}
def create_pytorch_nn_module_case3(tmp_dir):
pt_model = make_pt_model_two_inputs()
ref_model = make_ref_pt_model_two_inputs([-1, 3, -1, -1])
@ -710,7 +722,8 @@ class TestMoConvertPyTorch(CommonMOConvertTest):
create_pytorch_module_with_optional_inputs_case2,
create_pytorch_module_with_optional_inputs_case3,
create_pytorch_module_with_optional_inputs_case4,
create_pytorch_module_with_optional_inputs_case5
create_pytorch_module_with_optional_inputs_case5,
create_pytorch_nn_module_with_scalar_input,
]
@ pytest.mark.parametrize("create_model", test_data)

View File

@ -1492,7 +1492,7 @@ def split_inputs(input_str):
def split_shapes(argv_input_shape: str):
range_reg = r'([0-9]*\.\.[0-9]*)'
first_digit_reg = r'([0-9 ]+|-1|\?|{})'.format(range_reg)
first_digit_reg = r'([0-9 ]*|-1|\?|{})'.format(range_reg)
next_digits_reg = r'(,{})*'.format(first_digit_reg)
tuple_reg = r'((\({}{}\))|(\[{}{}\]))'.format(first_digit_reg, next_digits_reg,
first_digit_reg, next_digits_reg)
@ -1500,7 +1500,7 @@ def split_shapes(argv_input_shape: str):
full_reg = r'^{}(\s*,\s*{})*$|^$'.format(tuple_reg, tuple_reg)
if not re.match(full_reg, argv_input_shape):
raise Error('Input shape "{}" cannot be parsed. ' + refer_to_faq_msg(57), argv_input_shape)
return re.findall(r'[(\[]([0-9,\.\? -]+)[)\]]', argv_input_shape)
return re.findall(r'[(\[]([0-9,\.\? -]*)[)\]]', argv_input_shape)
def get_placeholder_shapes(argv_input: str, argv_input_shape: str, argv_batch=None):
"""
@ -1581,7 +1581,7 @@ def get_placeholder_shapes(argv_input: str, argv_input_shape: str, argv_batch=No
# clean inputs from values for freezing
inputs_without_value = list(map(lambda x: x.split('->')[0], inputs))
placeholder_shapes = dict(zip_longest(inputs_without_value,
map(lambda x: PartialShape(x) if x else None, shapes)))
map(lambda x: PartialShape(x) if x is not None else None, shapes)))
for inp in inputs:
if '->' not in inp:
inputs_list.append(inp)

View File

@ -772,6 +772,39 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
exp_res = np.array([12, 4, 1])
assert np.array_equal(result, exp_res)
def test_get_shapes_for_scalar_inputs(self):
argv_input = ""
input_shapes = "[]"
_, result, _ = get_placeholder_shapes(argv_input, input_shapes)
ref_result = np.array([])
assert np.array_equal(result, ref_result)
def test_get_shapes_two_input_shapes_with_scalar(self):
argv_input = ""
input_shapes = "[12,4,1],[]"
_, result, _ = get_placeholder_shapes(argv_input, input_shapes)
ref_result = [np.array([12, 4, 1]), np.array([])]
for shape, ref_shape in zip(result, ref_result):
assert np.array_equal(shape, ref_shape)
def test_get_shapes_two_input_shapes(self):
argv_input = ""
input_shapes = "[12,4,1],[10]"
_, result, _ = get_placeholder_shapes(argv_input, input_shapes)
ref_result = [np.array([12, 4, 1]), np.array([10])]
for shape, ref_shape in zip(result, ref_result):
assert np.array_equal(shape, ref_shape)
def test_get_shapes_two_named_input_shapes_with_scalar(self):
argv_input = "inp1,inp2"
input_shapes = "[12,4,1],[]"
inputs_list, result, _ = get_placeholder_shapes(argv_input, input_shapes)
exp_res = {'inp1': np.array([12, 4, 1]), 'inp2': np.array([])}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
self.assertEqual(inputs_list, ["inp1","inp2"])
for i in exp_res.keys():
assert np.array_equal(result[i], exp_res[i])
def test_get_shapes_one_input_no_shape(self):
argv_input = "inp1"