[ovc] check if input is correct in split_inputs (#19350)

This commit is contained in:
Pavel Esir 2023-08-30 15:13:43 +02:00 committed by GitHub
parent 02d6c1cb5d
commit 9a1726a419
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 38 additions and 1 deletions

View File

@ -609,6 +609,11 @@ def parse_input_value(input_value: str):
def split_inputs(input_str):
pattern = r'^(?:[^[\]()<]*(\[[\.)-9,\-\s?]*\])*,)*[^[\]()<]*(\[[\.0-9,\-\s?]*\])*$'
if not re.match(pattern, input_str):
raise Error(f"input value '{input_str}' is incorrect. Input should be in the following format: "
f"{get_convert_model_help_specifics()['input']['description']}")
brakets_count = 0
inputs = []
while input_str:

View File

@ -15,7 +15,7 @@ def convert_model(
input_model: [str, pathlib.Path, Any, list], # TODO: Instead of list just accept arbitrary number of positional arguments
# Framework-agnostic parameters
input: [list, dict] = None,
input: [list, dict, str] = None,
output: [str, list] = None,
example_input: Any = None,
extension: [str, pathlib.Path, list, Any] = None,

View File

@ -31,6 +31,38 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
_InputCutInfo(name='inp2', shape=PartialShape([-1,45,7,1]))]
self.assertEqual(inputs, inputs_ref)
def test_raises_get_shapes_1(self):
argv_input = "[h,y]"
self.assertRaises(Error, input_to_input_cut_info, argv_input)
def test_raises_get_shapes_2(self):
argv_input = "(2, 3)"
self.assertRaises(Error, input_to_input_cut_info, argv_input)
def test_raises_get_shapes_3(self):
argv_input = "input_1(2, 3)"
self.assertRaises(Error, input_to_input_cut_info, argv_input)
def test_raises_get_shapes_4(self):
argv_input = "(2, 3),(10, 10)"
self.assertRaises(Error, input_to_input_cut_info, argv_input)
def test_raises_get_shapes_5(self):
argv_input = "<2,3,4>"
self.assertRaises(Error, input_to_input_cut_info, argv_input)
def test_raises_get_shapes_6(self):
argv_input = "sd<2,3>"
self.assertRaises(Error, input_to_input_cut_info, argv_input)
def test_get_shapes_complex_input(self):
argv_input = "[10, -1, 100],mask[],[?,?]"
inputs = input_to_input_cut_info(argv_input)
inputs_ref = [_InputCutInfo(shape=PartialShape([10, -1, 100])),
_InputCutInfo(name='mask', shape=PartialShape([])),
_InputCutInfo(shape=PartialShape([-1, -1]))]
self.assertEqual(inputs, inputs_ref)
def test_get_shapes_and_freezing_with_scalar_and_without_shapes_in_input(self):
# shapes and value for freezing specified using --input command line parameter
argv_input = "inp1,inp2"