diff --git a/src/bindings/python/src/openvino/frontend/tensorflow/utils.py b/src/bindings/python/src/openvino/frontend/tensorflow/utils.py index f4ac5b0a121..b75f371d0c1 100644 --- a/src/bindings/python/src/openvino/frontend/tensorflow/utils.py +++ b/src/bindings/python/src/openvino/frontend/tensorflow/utils.py @@ -6,9 +6,11 @@ import logging as log -import numpy as np import sys from distutils.version import LooseVersion +from typing import List, Dict, Union + +import numpy as np from openvino.runtime import PartialShape, Dimension @@ -264,10 +266,22 @@ def trace_tf_model(model, input_shapes, input_types, example_input): input_needs_packing = True + def are_shapes_defined(shape: Union[List, Dict]): + if shape is None: + return False + assert hasattr(shape, '__len__') + if len(shape) == 0: + return False + + if isinstance(shape, list): + return np.all([shape is not None for shape in input_shapes]) + elif isinstance(shape, dict): + return np.all([shape is not None for name, shape in input_shapes.items()]) + if example_input is not None: concrete_func = get_concrete_func(tf_function, example_input, input_needs_packing, "Could not trace the TF model with the following error: {}") - elif input_shapes is not None: + elif are_shapes_defined(input_shapes): inp = create_example_input_by_user_shapes(input_shapes, input_types) concrete_func = get_concrete_func(tf_function, inp, input_needs_packing, "Could not trace the TF model with the following error: {}") diff --git a/tests/layer_tests/ovc_python_api_tests/test_pytorch.py b/tests/layer_tests/ovc_python_api_tests/test_pytorch.py index 268f69d13f0..52d66690459 100644 --- a/tests/layer_tests/ovc_python_api_tests/test_pytorch.py +++ b/tests/layer_tests/ovc_python_api_tests/test_pytorch.py @@ -187,8 +187,7 @@ def create_pytorch_nn_module_case3(tmp_dir): sample_input1 = torch.zeros(1, 3, 10, 10) sample_input2 = torch.zeros(1, 3, 10, 10) sample_input = tuple([sample_input1, sample_input2]) - - return pt_model, ref_model, {'input': "[?,3,?,?],[?,3,?,?]", + return pt_model, ref_model, {'input': [[-1, 3, -1, -1], [-1, 3, -1, -1]], 'example_input': sample_input} @@ -1093,6 +1092,37 @@ class ConvertRaises(unittest.TestCase): with self.assertRaisesRegex(TypeError, ".*got an unexpected keyword argument 'example_inputs'.*"): convert_model(pytorch_model, example_inputs=(torch.tensor(1),)) + def test_incorrect_inputs_1(self): + from openvino.tools.ovc import convert_model + pytorch_model, _, _ = create_pytorch_nn_module_case1('') + + with self.assertRaisesRegex(Exception, ".*No node with name.*"): + convert_model(pytorch_model, input='input1[1, 10]') + + def test_incorrect_inputs_2(self): + from openvino.tools.ovc import convert_model + pytorch_model, _, _ = create_pytorch_nn_module_case1('') + + # check that it accepts specified names as is, without parsing into 2 different inputs + with self.assertRaisesRegex(Exception, 'No node with name input1,input2'): + convert_model(pytorch_model, input='input1,input2') + + def test_incorrect_inputs_3(self): + from openvino.tools.ovc import convert_model + pytorch_model, _, _ = create_pytorch_nn_module_case1('') + + # check that it accepts specified names as is, without parsing into 2 different inputs + with self.assertRaisesRegex(Exception, 'No node with name input1\[1, 10\],input2\[2, 100\]'): + convert_model(pytorch_model, input='input1[1, 10],input2[2, 100]') + + def test_incorrect_inputs_4(self): + from openvino.tools.ovc import convert_model + pytorch_model, _, _ = create_pytorch_nn_module_case1('') + + # check that it accepts specified names as is, without parsing into 2 different inputs + with self.assertRaisesRegex(Exception, 'No node with name input1\[1, 10\]'): + convert_model(pytorch_model, input=['input1[1, 10]', 'input2[2, 100]']) + def test_failed_extension(self): from openvino.tools.ovc import convert_model from openvino.frontend.pytorch import ConversionExtension diff --git a/tests/layer_tests/ovc_python_api_tests/test_tf.py b/tests/layer_tests/ovc_python_api_tests/test_tf.py index a311666eb83..e8c098ff9ca 100644 --- a/tests/layer_tests/ovc_python_api_tests/test_tf.py +++ b/tests/layer_tests/ovc_python_api_tests/test_tf.py @@ -938,3 +938,20 @@ class TestTFLoadByModel(unittest.TestCase): fe = fem.load_by_model(model) assert fe is not None assert fe.get_name() == "tf" + + +class TestTFConvertRaises(unittest.TestCase): + def test_incorrect_inputs_1(self): + from openvino.tools.ovc import convert_model + tf_model, _, _ = create_keras_model('') + + with self.assertRaisesRegex(Exception, ".*No node with name.*"): + convert_model(tf_model, input='Input1[1, 2, 3]') + + def test_incorrect_inputs_2(self): + from openvino.tools.ovc import convert_model + tf_model, _, _ = create_keras_model('') + + # check that it accepts specified names as is without parsing into 2 different inputs + with self.assertRaisesRegex(Exception, 'No node with name Input1\[1, 2, 3\],Input2\[1, 2, 3\]'): + convert_model(tf_model, input='Input1[1, 2, 3],Input2[1, 2, 3]') diff --git a/tools/ovc/openvino/tools/ovc/cli_parser.py b/tools/ovc/openvino/tools/ovc/cli_parser.py index 8568c12a47a..9931fa539cb 100644 --- a/tools/ovc/openvino/tools/ovc/cli_parser.py +++ b/tools/ovc/openvino/tools/ovc/cli_parser.py @@ -44,11 +44,8 @@ def single_input_to_input_cut_info(input: [str, tuple, list, PartialShape, Type, :return: InputCutInfo """ if isinstance(input, str): - # Parse params from string - node_name, shape = parse_input_value(input) # pylint: disable=no-member - return _InputCutInfo(node_name, - PartialShape(shape) if shape is not None else None) + return _InputCutInfo(input, None) if isinstance(input, (tuple, list)) or is_shape_type(input): # If input represents list with shape, wrap it to list. Single PartialShape also goes to this condition. # Check of all dimensions will be in is_shape_type(val) method below @@ -68,10 +65,10 @@ def single_input_to_input_cut_info(input: [str, tuple, list, PartialShape, Type, if inp_type is not None: raise Exception("More than one input type provided: {}".format(input)) inp_type = val - elif is_shape_type(val): + elif is_shape_type(val) or val is None: if shape is not None: raise Exception("More than one input shape provided: {}".format(input)) - shape = PartialShape(val) + shape = PartialShape(val) if val is not None else None else: raise Exception("Incorrect input parameters provided. Expected tuple with input name, " "input type or input shape. Got unknown object: {}".format(val)) @@ -116,7 +113,20 @@ def is_single_input(input: [tuple, list]): return True -def input_to_input_cut_info(input: [str, tuple, list]): +def parse_inputs(inputs: str): + inputs_list = [] + # Split to list of string + for input_value in split_inputs(inputs): + + # Parse string with parameters for single input + node_name, shape = parse_input_value(input_value) + # pylint: disable=no-member + inputs_list.append((node_name, shape)) + return inputs_list + + +def input_to_input_cut_info(input: [dict, tuple, list]): + """ Parses 'input' to list of InputCutInfo. :param input: input cut parameters passed by user @@ -124,18 +134,10 @@ def input_to_input_cut_info(input: [str, tuple, list]): """ if input is None: return [] - if isinstance(input, str): - inputs = [] - # Split to list of string - for input_value in split_inputs(input): - # Parse string with parameters for single input - node_name, shape = parse_input_value(input_value) - # pylint: disable=no-member - inputs.append(_InputCutInfo(node_name, - PartialShape(shape) if shape is not None else None)) - return inputs if isinstance(input, (tuple, list)): + if len(input) == 0: + return [] # Case when input is single shape set in tuple if len(input) > 0 and isinstance(input[0], (int, Dimension)): input = [input] diff --git a/tools/ovc/openvino/tools/ovc/convert_impl.py b/tools/ovc/openvino/tools/ovc/convert_impl.py index fabf3601a49..cf09a2abfe2 100644 --- a/tools/ovc/openvino/tools/ovc/convert_impl.py +++ b/tools/ovc/openvino/tools/ovc/convert_impl.py @@ -22,7 +22,7 @@ from openvino.tools.ovc.moc_frontend.pipeline import moc_pipeline from openvino.tools.ovc.moc_frontend.moc_emit_ir import moc_emit_ir from openvino.tools.ovc.convert_data_type import destination_type_to_np_data_type from openvino.tools.ovc.cli_parser import get_available_front_ends, get_common_cli_options, depersonalize, \ - get_mo_convert_params, input_to_input_cut_info + get_mo_convert_params, input_to_input_cut_info, parse_inputs from openvino.tools.ovc.help import get_convert_model_help_specifics from openvino.tools.ovc.error import Error, FrameworkError @@ -93,7 +93,14 @@ def arguments_post_parsing(argv: argparse.Namespace): if is_verbose(argv): print_argv(argv) - params_parsing(argv) + import re + if argv.is_python_api_used and isinstance(argv.input, str): + argv.input = [argv.input] + + if not argv.is_python_api_used and isinstance(argv.input, str): + argv.input = parse_inputs(argv.input) + + normalize_inputs(argv) log.debug("Placeholder shapes : {}".format(argv.placeholder_shapes)) if not hasattr(argv, 'output') or argv.output is None: @@ -297,9 +304,9 @@ def input_model_is_object(input_model): return True -def params_parsing(argv: argparse.Namespace): +def normalize_inputs(argv: argparse.Namespace): """ - Parses params passed to convert_model and wraps resulting values into dictionaries or lists. + repacks params passed to convert_model and wraps resulting values into dictionaries or lists. After working of this method following values are set in argv: argv.input, argv.inputs_list - list of input names. Both values are used in some parts of MO. diff --git a/tools/ovc/unit_tests/moc_tf_fe/conversion_basic_models_test.py b/tools/ovc/unit_tests/moc_tf_fe/conversion_basic_models_test.py index ba2a1a8a652..d1306523b83 100644 --- a/tools/ovc/unit_tests/moc_tf_fe/conversion_basic_models_test.py +++ b/tools/ovc/unit_tests/moc_tf_fe/conversion_basic_models_test.py @@ -144,14 +144,14 @@ class TestMoFreezePlaceholderTFFE(unittest.TestCase): # new frontend ( "model_add_with_undefined_constant.pbtxt", - "x[2,3]", + ("x", [2, 3]), {"x": np.array([[12, 13, 10], [11, 14, 16]], dtype=np.float32)}, np.array([[12, 13, 10], [11, 14, 16]], dtype=np.float32), np.float32 ), ( "model_mul_with_undefined_constant.pbtxt", - "x[2]", + ("x", [2]), {"x": np.array([11, -12], dtype=np.int32)}, np.array([0, 0], dtype=np.int32), np.int32 diff --git a/tools/ovc/unit_tests/ovc/utils/cli_parser_test.py b/tools/ovc/unit_tests/ovc/utils/cli_parser_test.py index c7ecf0e90b5..8d7edc8b214 100644 --- a/tools/ovc/unit_tests/ovc/utils/cli_parser_test.py +++ b/tools/ovc/unit_tests/ovc/utils/cli_parser_test.py @@ -13,7 +13,7 @@ from unittest.mock import patch import numpy as np from openvino.tools.ovc.cli_parser import input_to_input_cut_info, check_positive, writable_dir, \ - readable_file_or_object, get_all_cli_parser, get_mo_convert_params + readable_file_or_object, get_all_cli_parser, get_mo_convert_params, parse_inputs from openvino.tools.ovc.convert_impl import pack_params_to_args_namespace, arguments_post_parsing, args_to_argv from openvino.tools.ovc.error import Error from unit_tests.ovc.unit_test_with_mocked_telemetry import UnitTestWithMockedTelemetry @@ -26,6 +26,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry): def test_get_shapes_several_inputs_several_shapes2(self): # shapes specified using --input command line parameter and no values argv_input = "inp1[1,22,333,123],inp2[-1,45,7,1]" + argv_input = parse_inputs(argv_input) inputs = input_to_input_cut_info(argv_input) inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([1,22,333,123])), _InputCutInfo(name='inp2', shape=PartialShape([-1,45,7,1]))] @@ -33,30 +34,31 @@ class TestShapesParsing(UnitTestWithMockedTelemetry): def test_raises_get_shapes_1(self): argv_input = "[h,y]" - self.assertRaises(Error, input_to_input_cut_info, argv_input) + self.assertRaises(Error, parse_inputs, argv_input) def test_raises_get_shapes_2(self): argv_input = "(2, 3)" - self.assertRaises(Error, input_to_input_cut_info, argv_input) + self.assertRaises(Error, parse_inputs, argv_input) def test_raises_get_shapes_3(self): argv_input = "input_1(2, 3)" - self.assertRaises(Error, input_to_input_cut_info, argv_input) + self.assertRaises(Error, parse_inputs, argv_input) def test_raises_get_shapes_4(self): argv_input = "(2, 3),(10, 10)" - self.assertRaises(Error, input_to_input_cut_info, argv_input) + self.assertRaises(Error, parse_inputs, argv_input) def test_raises_get_shapes_5(self): argv_input = "<2,3,4>" - self.assertRaises(Error, input_to_input_cut_info, argv_input) + self.assertRaises(Error, parse_inputs, argv_input) def test_raises_get_shapes_6(self): argv_input = "sd<2,3>" - self.assertRaises(Error, input_to_input_cut_info, argv_input) + self.assertRaises(Error, parse_inputs, argv_input) def test_get_shapes_complex_input(self): argv_input = "[10, -1, 100],mask[],[?,?]" + argv_input = parse_inputs(argv_input) inputs = input_to_input_cut_info(argv_input) inputs_ref = [_InputCutInfo(shape=PartialShape([10, -1, 100])), _InputCutInfo(name='mask', shape=PartialShape([])), @@ -66,6 +68,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry): def test_get_shapes_and_freezing_with_scalar_and_without_shapes_in_input(self): # shapes and value for freezing specified using --input command line parameter argv_input = "inp1,inp2" + argv_input = parse_inputs(argv_input) inputs = input_to_input_cut_info(argv_input) inputs_ref = [_InputCutInfo(name='inp1'), _InputCutInfo(name='inp2')] @@ -75,6 +78,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry): def test_get_shapes_and_freezing_with_scalar(self): # shapes and value for freezing specified using --input command line parameter argv_input = "inp1,inp2[]" + argv_input = parse_inputs(argv_input) inputs = input_to_input_cut_info(argv_input) inputs_ref = [_InputCutInfo(name='inp1'), _InputCutInfo(name='inp2', shape=PartialShape([]))] @@ -83,6 +87,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry): def test_get_shapes_several_inputs_several_shapes3(self): # shapes and value for freezing specified using --input command line parameter argv_input = "inp1[3 1],inp2[3,2,3],inp3[5]" + argv_input = parse_inputs(argv_input) inputs = input_to_input_cut_info(argv_input) inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])), _InputCutInfo(name='inp2', shape=PartialShape([3,2,3])), @@ -92,6 +97,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry): def test_get_shapes_several_inputs_several_shapes3_comma_sep(self): # shapes and value for freezing specified using --input command line parameter argv_input = "inp1[3 1],inp2[3 2 3],inp3[5]" + argv_input = parse_inputs(argv_input) inputs = input_to_input_cut_info(argv_input) inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])), _InputCutInfo(name='inp2', shape=PartialShape([3,2,3])), @@ -101,6 +107,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry): def test_get_shapes_several_inputs_several_shapes6(self): # 0D value for freezing specified using --input command line parameter without shape argv_input = "inp1[3,1],inp2[3,2,3],inp3" + argv_input = parse_inputs(argv_input) inputs_list, result, _ = input_to_input_cut_info(argv_input) inputs = input_to_input_cut_info(argv_input) inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])), @@ -111,6 +118,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry): def test_get_shapes_several_inputs_several_shapes7(self): # 0D shape and value for freezing specified using --input command line parameter argv_input = "inp1[3,1],inp2[3,2,3],inp3[]" + argv_input = parse_inputs(argv_input) inputs_list, result, _ = input_to_input_cut_info(argv_input) inputs = input_to_input_cut_info(argv_input) inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])), @@ -121,6 +129,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry): def test_get_shapes_and_data_types_shape_only(self): argv_input = "placeholder1[3 1],placeholder2,placeholder3" + argv_input = parse_inputs(argv_input) inputs = input_to_input_cut_info(argv_input) inputs_ref = [_InputCutInfo(name='placeholder1', shape=PartialShape([3,1])), _InputCutInfo(name='placeholder2'), @@ -129,6 +138,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry): def test_get_shapes_and_data_types_shape_with_ports_only(self): argv_input = "placeholder1:4[3 1],placeholder2,2:placeholder3" + argv_input = parse_inputs(argv_input) inputs = input_to_input_cut_info(argv_input) inputs_ref = [_InputCutInfo(name='placeholder1:4', shape=PartialShape([3,1])), _InputCutInfo(name='placeholder2'), @@ -137,15 +147,16 @@ class TestShapesParsing(UnitTestWithMockedTelemetry): def test_wrong_data_types(self): argv_input = "inp1[3 1]->[1.0 2.0 3.0],inp2[3 2 3]{abracadabra},inp3[5]{f32}->[1.0 1.0 2.0 3.0 5.0]" - self.assertRaises(Error, input_to_input_cut_info, argv_input) + self.assertRaises(Error, parse_inputs, argv_input) def test_shape_and_value_shape_mismatch(self): # size of value tensor does not correspond to specified shape for the third node argv_input = "inp1[3 1]->[1.0 2.0 3.0],inp2[3 2 3],inp3[5 3]->[2.0 3.0 5.0]" - self.assertRaises(Error, input_to_input_cut_info, argv_input) + self.assertRaises(Error, parse_inputs, argv_input) def test_get_shapes_no_input_no_shape(self): argv_input = "" + argv_input = parse_inputs(argv_input) inputs = input_to_input_cut_info(argv_input) inputs_ref = [] self.assertEqual(inputs, inputs_ref) @@ -153,18 +164,21 @@ class TestShapesParsing(UnitTestWithMockedTelemetry): def test_get_shapes_no_input_one_shape2(self): argv_input = "[12,4,1]" + argv_input = parse_inputs(argv_input) inputs = input_to_input_cut_info(argv_input) inputs_ref = [_InputCutInfo(shape=PartialShape([12,4,1]))] self.assertEqual(inputs, inputs_ref) def test_get_shapes_for_scalar_inputs(self): argv_input = "[]" + argv_input = parse_inputs(argv_input) inputs = input_to_input_cut_info(argv_input) inputs_ref = [_InputCutInfo(shape=PartialShape([]))] self.assertEqual(inputs, inputs_ref) def test_get_shapes_two_input_shapes_with_scalar(self): argv_input = "[12,4,1],[]" + argv_input = parse_inputs(argv_input) inputs = input_to_input_cut_info(argv_input) inputs_ref = [_InputCutInfo(shape=PartialShape([12,4,1])), _InputCutInfo(shape=PartialShape([]))] @@ -172,6 +186,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry): def test_get_shapes_two_input_shapes(self): argv_input = "[12,4,1],[10]" + argv_input = parse_inputs(argv_input) inputs = input_to_input_cut_info(argv_input) inputs_ref = [_InputCutInfo(shape=PartialShape([12,4,1])), _InputCutInfo(shape=PartialShape([10])),] @@ -179,6 +194,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry): def test_get_shapes_one_input_no_shape(self): argv_input = "inp1" + argv_input = parse_inputs(argv_input) inputs = input_to_input_cut_info(argv_input) inputs_ref = [_InputCutInfo(name='inp1')] self.assertEqual(inputs, inputs_ref) @@ -186,6 +202,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry): def test_get_shapes_several_inputs_several_partial_shapes2(self): # shapes specified using --input command line parameter and no values argv_input = "inp1[1,?,50..100,123],inp2[-1,45..,..7,1]" + argv_input = parse_inputs(argv_input) inputs = input_to_input_cut_info(argv_input) inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape("[1,?,50..100,123]")), _InputCutInfo(name='inp2', shape=PartialShape("[-1,45..,..7,1]"))] @@ -194,6 +211,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry): def test_get_shapes_several_inputs_several_partial_shapes3(self): # shapes and value for freezing specified using --input command line parameter argv_input = "inp1[3,1],inp2[3..,..2,5..10,?,-1],inp3[5]" + argv_input = parse_inputs(argv_input) inputs = input_to_input_cut_info(argv_input) inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])), _InputCutInfo(name='inp2', shape=PartialShape("[3..,..2,5..10,?,-1]")), @@ -203,6 +221,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry): def test_get_shapes_several_inputs_several_partial_shapes6(self): # 0D value for freezing specified using --input command line parameter without shape argv_input = "inp1[3 1],inp2[3.. ..2 5..10 ? -1],inp3" + argv_input = parse_inputs(argv_input) inputs = input_to_input_cut_info(argv_input) inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])), _InputCutInfo(name='inp2', shape=PartialShape("[3..,..2,5..10,?,-1]")), @@ -212,6 +231,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry): def test_get_shapes_several_inputs_several_partial_shapes7(self): # 0D shape and value for freezing specified using --input command line parameter argv_input = "inp1[3 1],inp2[3.. ..2 5..10 ? -1],inp3[]" + argv_input = parse_inputs(argv_input) inputs = input_to_input_cut_info(argv_input) inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])), _InputCutInfo(name='inp2', shape=PartialShape("[3..,..2,5..10,?,-1]")), @@ -220,15 +240,16 @@ class TestShapesParsing(UnitTestWithMockedTelemetry): def test_partial_shapes_freeze_dynamic_negative_case1(self): argv_input = "inp1:1[3 1..10]->[1.0 2.0 3.0]" - self.assertRaises(Error, input_to_input_cut_info, argv_input) + self.assertRaises(Error, parse_inputs, argv_input) def test_partial_shapes_freeze_dynamic_negative_case2(self): argv_input = "inp1:1[1 2 -1]->[1.0 2.0 3.0]" - self.assertRaises(Error, input_to_input_cut_info, argv_input) + self.assertRaises(Error, parse_inputs, argv_input) def test_get_shapes_several_inputs_several_partial_shapes2_comma_separator(self): # shapes specified using --input command line parameter and no values argv_input = "inp1[1,?,50..100,123],inp2[-1,45..,..7,1]" + argv_input = parse_inputs(argv_input) inputs = input_to_input_cut_info(argv_input) inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape("[1,?,50..100,123]")), _InputCutInfo(name='inp2', shape=PartialShape("[-1,45..,..7,1]"))] @@ -237,6 +258,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry): def test_get_shapes_several_inputs_several_partial_shapes3_comma_separator(self): # shapes and value for freezing specified using --input command line parameter argv_input = "inp1[3,1],inp2[3..,..2,5..10,?,-1],inp3[5]" + argv_input = parse_inputs(argv_input) inputs = input_to_input_cut_info(argv_input) inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])), _InputCutInfo(name='inp2', shape=PartialShape("[3..,..2,5..10,?,-1]")), @@ -246,6 +268,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry): def test_get_shapes_several_inputs_several_partial_shapes6_comma_separator(self): # 0D value for freezing specified using --input command line parameter without shape argv_input = "inp1[3, 1],inp2[3.., ..2, 5..10, ?,-1],inp3" + argv_input = parse_inputs(argv_input) inputs = input_to_input_cut_info(argv_input) inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])), _InputCutInfo(name='inp2', shape=PartialShape("[3..,..2,5..10,?,-1]")), @@ -255,6 +278,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry): def test_get_shapes_several_inputs_several_partial_shapes7_comma_separator(self): # 0D shape and value for freezing specified using --input command line parameter argv_input = "inp1[3,1],inp2[3.., ..2,5..10, ?,-1],inp3[]" + argv_input = parse_inputs(argv_input) inputs = input_to_input_cut_info(argv_input) inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])), _InputCutInfo(name='inp2', shape=PartialShape("[3..,..2,5..10,?,-1]")), @@ -263,24 +287,25 @@ class TestShapesParsing(UnitTestWithMockedTelemetry): def test_partial_shapes_freeze_dynamic_negative_case1_comma_separator(self): argv_input = "inp1:1[3,1..10]->[1.0 2.0 3.0]" - self.assertRaises(Error, input_to_input_cut_info, argv_input) + self.assertRaises(Error, parse_inputs, argv_input) def test_partial_shapes_freeze_dynamic_negative_case2_comma_separator(self): argv_input = "inp1:1[1,2,-1]->[1.0 2.0 3.0]" - self.assertRaises(Error, input_to_input_cut_info, argv_input) + self.assertRaises(Error, parse_inputs, argv_input) def test_partial_shapes_freeze_dynamic_negative_case3_comma_separator(self): argv_input = "inp1:1[3,1..10]->[1.0 2.0 3.0]" - self.assertRaises(Error, input_to_input_cut_info, argv_input) + self.assertRaises(Error, parse_inputs, argv_input) def test_partial_shapes_freeze_dynamic_negative_case4_comma_separator(self): argv_input = "inp1:1[1, 2, -1]->[1.0 2.0 3.0]" - self.assertRaises(Error, input_to_input_cut_info, argv_input) + self.assertRaises(Error, parse_inputs, argv_input) def test_not_supported_arrow(self): with self.assertRaisesRegex(Exception, "Incorrect format of input."): - input_to_input_cut_info("inp1->[1.0]") + argv_input = parse_inputs("inp1->[1.0]") + input_to_input_cut_info(argv_input) class PositiveChecker(unittest.TestCase): @@ -526,4 +551,3 @@ class TestConvertModelParamsParsing(unittest.TestCase): assert param_name not in cli_parser._option_string_actions else: assert param_name in cli_parser._option_string_actions -