diff --git a/tests/layer_tests/mo_python_api_tests/test_mo_convert_pytorch.py b/tests/layer_tests/mo_python_api_tests/test_mo_convert_pytorch.py index b519a0bc599..10a54d3ceba 100644 --- a/tests/layer_tests/mo_python_api_tests/test_mo_convert_pytorch.py +++ b/tests/layer_tests/mo_python_api_tests/test_mo_convert_pytorch.py @@ -160,6 +160,18 @@ def create_pytorch_nn_module_case2(tmp_dir): 'example_input': sample_input} +def create_pytorch_nn_module_with_scalar_input(tmp_dir): + pt_model = make_pt_model_two_inputs() + ref_model = make_ref_pt_model_two_inputs([[], [-1, 3, -1, -1]]) + + sample_input1 = torch.tensor(0.66) + sample_input2 = torch.zeros(1, 3, 10, 10) + sample_input = sample_input1, sample_input2 + + return pt_model, ref_model, {'input_shape': ["[]", PartialShape([-1, 3, -1, -1])], + 'example_input': sample_input} + + def create_pytorch_nn_module_case3(tmp_dir): pt_model = make_pt_model_two_inputs() ref_model = make_ref_pt_model_two_inputs([-1, 3, -1, -1]) @@ -710,7 +722,8 @@ class TestMoConvertPyTorch(CommonMOConvertTest): create_pytorch_module_with_optional_inputs_case2, create_pytorch_module_with_optional_inputs_case3, create_pytorch_module_with_optional_inputs_case4, - create_pytorch_module_with_optional_inputs_case5 + create_pytorch_module_with_optional_inputs_case5, + create_pytorch_nn_module_with_scalar_input, ] @ pytest.mark.parametrize("create_model", test_data) diff --git a/tools/ovc/openvino/tools/ovc/cli_parser.py b/tools/ovc/openvino/tools/ovc/cli_parser.py index a3d7790182a..86d0281b135 100644 --- a/tools/ovc/openvino/tools/ovc/cli_parser.py +++ b/tools/ovc/openvino/tools/ovc/cli_parser.py @@ -1492,7 +1492,7 @@ def split_inputs(input_str): def split_shapes(argv_input_shape: str): range_reg = r'([0-9]*\.\.[0-9]*)' - first_digit_reg = r'([0-9 ]+|-1|\?|{})'.format(range_reg) + first_digit_reg = r'([0-9 ]*|-1|\?|{})'.format(range_reg) next_digits_reg = r'(,{})*'.format(first_digit_reg) tuple_reg = r'((\({}{}\))|(\[{}{}\]))'.format(first_digit_reg, next_digits_reg, first_digit_reg, next_digits_reg) @@ -1500,7 +1500,7 @@ def split_shapes(argv_input_shape: str): full_reg = r'^{}(\s*,\s*{})*$|^$'.format(tuple_reg, tuple_reg) if not re.match(full_reg, argv_input_shape): raise Error('Input shape "{}" cannot be parsed. ' + refer_to_faq_msg(57), argv_input_shape) - return re.findall(r'[(\[]([0-9,\.\? -]+)[)\]]', argv_input_shape) + return re.findall(r'[(\[]([0-9,\.\? -]*)[)\]]', argv_input_shape) def get_placeholder_shapes(argv_input: str, argv_input_shape: str, argv_batch=None): """ @@ -1581,7 +1581,7 @@ def get_placeholder_shapes(argv_input: str, argv_input_shape: str, argv_batch=No # clean inputs from values for freezing inputs_without_value = list(map(lambda x: x.split('->')[0], inputs)) placeholder_shapes = dict(zip_longest(inputs_without_value, - map(lambda x: PartialShape(x) if x else None, shapes))) + map(lambda x: PartialShape(x) if x is not None else None, shapes))) for inp in inputs: if '->' not in inp: inputs_list.append(inp) diff --git a/tools/ovc/unit_tests/ovc/utils/cli_parser_test.py b/tools/ovc/unit_tests/ovc/utils/cli_parser_test.py index d0e4d03dcdc..0f9aef0ce67 100644 --- a/tools/ovc/unit_tests/ovc/utils/cli_parser_test.py +++ b/tools/ovc/unit_tests/ovc/utils/cli_parser_test.py @@ -772,6 +772,39 @@ class TestShapesParsing(UnitTestWithMockedTelemetry): exp_res = np.array([12, 4, 1]) assert np.array_equal(result, exp_res) + def test_get_shapes_for_scalar_inputs(self): + argv_input = "" + input_shapes = "[]" + _, result, _ = get_placeholder_shapes(argv_input, input_shapes) + ref_result = np.array([]) + assert np.array_equal(result, ref_result) + + def test_get_shapes_two_input_shapes_with_scalar(self): + argv_input = "" + input_shapes = "[12,4,1],[]" + _, result, _ = get_placeholder_shapes(argv_input, input_shapes) + ref_result = [np.array([12, 4, 1]), np.array([])] + for shape, ref_shape in zip(result, ref_result): + assert np.array_equal(shape, ref_shape) + + def test_get_shapes_two_input_shapes(self): + argv_input = "" + input_shapes = "[12,4,1],[10]" + _, result, _ = get_placeholder_shapes(argv_input, input_shapes) + ref_result = [np.array([12, 4, 1]), np.array([10])] + for shape, ref_shape in zip(result, ref_result): + assert np.array_equal(shape, ref_shape) + + def test_get_shapes_two_named_input_shapes_with_scalar(self): + argv_input = "inp1,inp2" + input_shapes = "[12,4,1],[]" + inputs_list, result, _ = get_placeholder_shapes(argv_input, input_shapes) + + exp_res = {'inp1': np.array([12, 4, 1]), 'inp2': np.array([])} + self.assertEqual(list(exp_res.keys()), list(result.keys())) + self.assertEqual(inputs_list, ["inp1","inp2"]) + for i in exp_res.keys(): + assert np.array_equal(result[i], exp_res[i]) def test_get_shapes_one_input_no_shape(self): argv_input = "inp1"