[MO] add support for scalar shapes into cli_parser.py (#18312)
* add support for scalar shapes into cli_parser.py * add test-case with scalar shapes for convert_model * reordered inputs in test-case with scalar shapes for convert_model * minor clarifications --------- Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
parent
af9a8cbbd7
commit
1cb4595727
@ -160,6 +160,18 @@ def create_pytorch_nn_module_case2(tmp_dir):
|
||||
'example_input': sample_input}
|
||||
|
||||
|
||||
def create_pytorch_nn_module_with_scalar_input(tmp_dir):
|
||||
pt_model = make_pt_model_two_inputs()
|
||||
ref_model = make_ref_pt_model_two_inputs([[], [-1, 3, -1, -1]])
|
||||
|
||||
sample_input1 = torch.tensor(0.66)
|
||||
sample_input2 = torch.zeros(1, 3, 10, 10)
|
||||
sample_input = sample_input1, sample_input2
|
||||
|
||||
return pt_model, ref_model, {'input_shape': ["[]", PartialShape([-1, 3, -1, -1])],
|
||||
'example_input': sample_input}
|
||||
|
||||
|
||||
def create_pytorch_nn_module_case3(tmp_dir):
|
||||
pt_model = make_pt_model_two_inputs()
|
||||
ref_model = make_ref_pt_model_two_inputs([-1, 3, -1, -1])
|
||||
@ -710,7 +722,8 @@ class TestMoConvertPyTorch(CommonMOConvertTest):
|
||||
create_pytorch_module_with_optional_inputs_case2,
|
||||
create_pytorch_module_with_optional_inputs_case3,
|
||||
create_pytorch_module_with_optional_inputs_case4,
|
||||
create_pytorch_module_with_optional_inputs_case5
|
||||
create_pytorch_module_with_optional_inputs_case5,
|
||||
create_pytorch_nn_module_with_scalar_input,
|
||||
]
|
||||
|
||||
@ pytest.mark.parametrize("create_model", test_data)
|
||||
|
@ -1492,7 +1492,7 @@ def split_inputs(input_str):
|
||||
|
||||
def split_shapes(argv_input_shape: str):
|
||||
range_reg = r'([0-9]*\.\.[0-9]*)'
|
||||
first_digit_reg = r'([0-9 ]+|-1|\?|{})'.format(range_reg)
|
||||
first_digit_reg = r'([0-9 ]*|-1|\?|{})'.format(range_reg)
|
||||
next_digits_reg = r'(,{})*'.format(first_digit_reg)
|
||||
tuple_reg = r'((\({}{}\))|(\[{}{}\]))'.format(first_digit_reg, next_digits_reg,
|
||||
first_digit_reg, next_digits_reg)
|
||||
@ -1500,7 +1500,7 @@ def split_shapes(argv_input_shape: str):
|
||||
full_reg = r'^{}(\s*,\s*{})*$|^$'.format(tuple_reg, tuple_reg)
|
||||
if not re.match(full_reg, argv_input_shape):
|
||||
raise Error('Input shape "{}" cannot be parsed. ' + refer_to_faq_msg(57), argv_input_shape)
|
||||
return re.findall(r'[(\[]([0-9,\.\? -]+)[)\]]', argv_input_shape)
|
||||
return re.findall(r'[(\[]([0-9,\.\? -]*)[)\]]', argv_input_shape)
|
||||
|
||||
def get_placeholder_shapes(argv_input: str, argv_input_shape: str, argv_batch=None):
|
||||
"""
|
||||
@ -1581,7 +1581,7 @@ def get_placeholder_shapes(argv_input: str, argv_input_shape: str, argv_batch=No
|
||||
# clean inputs from values for freezing
|
||||
inputs_without_value = list(map(lambda x: x.split('->')[0], inputs))
|
||||
placeholder_shapes = dict(zip_longest(inputs_without_value,
|
||||
map(lambda x: PartialShape(x) if x else None, shapes)))
|
||||
map(lambda x: PartialShape(x) if x is not None else None, shapes)))
|
||||
for inp in inputs:
|
||||
if '->' not in inp:
|
||||
inputs_list.append(inp)
|
||||
|
@ -772,6 +772,39 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
|
||||
exp_res = np.array([12, 4, 1])
|
||||
assert np.array_equal(result, exp_res)
|
||||
|
||||
def test_get_shapes_for_scalar_inputs(self):
|
||||
argv_input = ""
|
||||
input_shapes = "[]"
|
||||
_, result, _ = get_placeholder_shapes(argv_input, input_shapes)
|
||||
ref_result = np.array([])
|
||||
assert np.array_equal(result, ref_result)
|
||||
|
||||
def test_get_shapes_two_input_shapes_with_scalar(self):
|
||||
argv_input = ""
|
||||
input_shapes = "[12,4,1],[]"
|
||||
_, result, _ = get_placeholder_shapes(argv_input, input_shapes)
|
||||
ref_result = [np.array([12, 4, 1]), np.array([])]
|
||||
for shape, ref_shape in zip(result, ref_result):
|
||||
assert np.array_equal(shape, ref_shape)
|
||||
|
||||
def test_get_shapes_two_input_shapes(self):
|
||||
argv_input = ""
|
||||
input_shapes = "[12,4,1],[10]"
|
||||
_, result, _ = get_placeholder_shapes(argv_input, input_shapes)
|
||||
ref_result = [np.array([12, 4, 1]), np.array([10])]
|
||||
for shape, ref_shape in zip(result, ref_result):
|
||||
assert np.array_equal(shape, ref_shape)
|
||||
|
||||
def test_get_shapes_two_named_input_shapes_with_scalar(self):
|
||||
argv_input = "inp1,inp2"
|
||||
input_shapes = "[12,4,1],[]"
|
||||
inputs_list, result, _ = get_placeholder_shapes(argv_input, input_shapes)
|
||||
|
||||
exp_res = {'inp1': np.array([12, 4, 1]), 'inp2': np.array([])}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
self.assertEqual(inputs_list, ["inp1","inp2"])
|
||||
for i in exp_res.keys():
|
||||
assert np.array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_get_shapes_one_input_no_shape(self):
|
||||
argv_input = "inp1"
|
||||
|
Loading…
Reference in New Issue
Block a user