[ovc] check if input is correct in split_inputs (#19350)
This commit is contained in:
parent
02d6c1cb5d
commit
9a1726a419
@ -609,6 +609,11 @@ def parse_input_value(input_value: str):
|
|||||||
|
|
||||||
|
|
||||||
def split_inputs(input_str):
|
def split_inputs(input_str):
|
||||||
|
pattern = r'^(?:[^[\]()<]*(\[[\.)-9,\-\s?]*\])*,)*[^[\]()<]*(\[[\.0-9,\-\s?]*\])*$'
|
||||||
|
if not re.match(pattern, input_str):
|
||||||
|
raise Error(f"input value '{input_str}' is incorrect. Input should be in the following format: "
|
||||||
|
f"{get_convert_model_help_specifics()['input']['description']}")
|
||||||
|
|
||||||
brakets_count = 0
|
brakets_count = 0
|
||||||
inputs = []
|
inputs = []
|
||||||
while input_str:
|
while input_str:
|
||||||
|
@ -15,7 +15,7 @@ def convert_model(
|
|||||||
input_model: [str, pathlib.Path, Any, list], # TODO: Instead of list just accept arbitrary number of positional arguments
|
input_model: [str, pathlib.Path, Any, list], # TODO: Instead of list just accept arbitrary number of positional arguments
|
||||||
|
|
||||||
# Framework-agnostic parameters
|
# Framework-agnostic parameters
|
||||||
input: [list, dict] = None,
|
input: [list, dict, str] = None,
|
||||||
output: [str, list] = None,
|
output: [str, list] = None,
|
||||||
example_input: Any = None,
|
example_input: Any = None,
|
||||||
extension: [str, pathlib.Path, list, Any] = None,
|
extension: [str, pathlib.Path, list, Any] = None,
|
||||||
|
@ -31,6 +31,38 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
|
|||||||
_InputCutInfo(name='inp2', shape=PartialShape([-1,45,7,1]))]
|
_InputCutInfo(name='inp2', shape=PartialShape([-1,45,7,1]))]
|
||||||
self.assertEqual(inputs, inputs_ref)
|
self.assertEqual(inputs, inputs_ref)
|
||||||
|
|
||||||
|
def test_raises_get_shapes_1(self):
|
||||||
|
argv_input = "[h,y]"
|
||||||
|
self.assertRaises(Error, input_to_input_cut_info, argv_input)
|
||||||
|
|
||||||
|
def test_raises_get_shapes_2(self):
|
||||||
|
argv_input = "(2, 3)"
|
||||||
|
self.assertRaises(Error, input_to_input_cut_info, argv_input)
|
||||||
|
|
||||||
|
def test_raises_get_shapes_3(self):
|
||||||
|
argv_input = "input_1(2, 3)"
|
||||||
|
self.assertRaises(Error, input_to_input_cut_info, argv_input)
|
||||||
|
|
||||||
|
def test_raises_get_shapes_4(self):
|
||||||
|
argv_input = "(2, 3),(10, 10)"
|
||||||
|
self.assertRaises(Error, input_to_input_cut_info, argv_input)
|
||||||
|
|
||||||
|
def test_raises_get_shapes_5(self):
|
||||||
|
argv_input = "<2,3,4>"
|
||||||
|
self.assertRaises(Error, input_to_input_cut_info, argv_input)
|
||||||
|
|
||||||
|
def test_raises_get_shapes_6(self):
|
||||||
|
argv_input = "sd<2,3>"
|
||||||
|
self.assertRaises(Error, input_to_input_cut_info, argv_input)
|
||||||
|
|
||||||
|
def test_get_shapes_complex_input(self):
|
||||||
|
argv_input = "[10, -1, 100],mask[],[?,?]"
|
||||||
|
inputs = input_to_input_cut_info(argv_input)
|
||||||
|
inputs_ref = [_InputCutInfo(shape=PartialShape([10, -1, 100])),
|
||||||
|
_InputCutInfo(name='mask', shape=PartialShape([])),
|
||||||
|
_InputCutInfo(shape=PartialShape([-1, -1]))]
|
||||||
|
self.assertEqual(inputs, inputs_ref)
|
||||||
|
|
||||||
def test_get_shapes_and_freezing_with_scalar_and_without_shapes_in_input(self):
|
def test_get_shapes_and_freezing_with_scalar_and_without_shapes_in_input(self):
|
||||||
# shapes and value for freezing specified using --input command line parameter
|
# shapes and value for freezing specified using --input command line parameter
|
||||||
argv_input = "inp1,inp2"
|
argv_input = "inp1,inp2"
|
||||||
|
Loading…
Reference in New Issue
Block a user