[ovc] check if input is correct in split_inputs (#19350)
This commit is contained in:
parent
02d6c1cb5d
commit
9a1726a419
@ -609,6 +609,11 @@ def parse_input_value(input_value: str):
|
||||
|
||||
|
||||
def split_inputs(input_str):
|
||||
pattern = r'^(?:[^[\]()<]*(\[[\.)-9,\-\s?]*\])*,)*[^[\]()<]*(\[[\.0-9,\-\s?]*\])*$'
|
||||
if not re.match(pattern, input_str):
|
||||
raise Error(f"input value '{input_str}' is incorrect. Input should be in the following format: "
|
||||
f"{get_convert_model_help_specifics()['input']['description']}")
|
||||
|
||||
brakets_count = 0
|
||||
inputs = []
|
||||
while input_str:
|
||||
|
@ -15,7 +15,7 @@ def convert_model(
|
||||
input_model: [str, pathlib.Path, Any, list], # TODO: Instead of list just accept arbitrary number of positional arguments
|
||||
|
||||
# Framework-agnostic parameters
|
||||
input: [list, dict] = None,
|
||||
input: [list, dict, str] = None,
|
||||
output: [str, list] = None,
|
||||
example_input: Any = None,
|
||||
extension: [str, pathlib.Path, list, Any] = None,
|
||||
|
@ -31,6 +31,38 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
|
||||
_InputCutInfo(name='inp2', shape=PartialShape([-1,45,7,1]))]
|
||||
self.assertEqual(inputs, inputs_ref)
|
||||
|
||||
def test_raises_get_shapes_1(self):
|
||||
argv_input = "[h,y]"
|
||||
self.assertRaises(Error, input_to_input_cut_info, argv_input)
|
||||
|
||||
def test_raises_get_shapes_2(self):
|
||||
argv_input = "(2, 3)"
|
||||
self.assertRaises(Error, input_to_input_cut_info, argv_input)
|
||||
|
||||
def test_raises_get_shapes_3(self):
|
||||
argv_input = "input_1(2, 3)"
|
||||
self.assertRaises(Error, input_to_input_cut_info, argv_input)
|
||||
|
||||
def test_raises_get_shapes_4(self):
|
||||
argv_input = "(2, 3),(10, 10)"
|
||||
self.assertRaises(Error, input_to_input_cut_info, argv_input)
|
||||
|
||||
def test_raises_get_shapes_5(self):
|
||||
argv_input = "<2,3,4>"
|
||||
self.assertRaises(Error, input_to_input_cut_info, argv_input)
|
||||
|
||||
def test_raises_get_shapes_6(self):
|
||||
argv_input = "sd<2,3>"
|
||||
self.assertRaises(Error, input_to_input_cut_info, argv_input)
|
||||
|
||||
def test_get_shapes_complex_input(self):
|
||||
argv_input = "[10, -1, 100],mask[],[?,?]"
|
||||
inputs = input_to_input_cut_info(argv_input)
|
||||
inputs_ref = [_InputCutInfo(shape=PartialShape([10, -1, 100])),
|
||||
_InputCutInfo(name='mask', shape=PartialShape([])),
|
||||
_InputCutInfo(shape=PartialShape([-1, -1]))]
|
||||
self.assertEqual(inputs, inputs_ref)
|
||||
|
||||
def test_get_shapes_and_freezing_with_scalar_and_without_shapes_in_input(self):
|
||||
# shapes and value for freezing specified using --input command line parameter
|
||||
argv_input = "inp1,inp2"
|
||||
|
Loading…
Reference in New Issue
Block a user