[OVC] do not parse inputs for py_api (#19742)

* [OVC] do not parse inputs

* fix unit-tests

* remove redundant lines, add test case

* add one more unit-test

* skip None values

* replace str with List in test_mo_import_from_memory

* corrected type hints, added a safety assert
This commit is contained in:
Pavel Esir 2023-09-22 13:05:21 +02:00 committed by GitHub
parent c1271d1217
commit 9271b79540
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 138 additions and 44 deletions

View File

@ -6,9 +6,11 @@
import logging as log import logging as log
import numpy as np
import sys import sys
from distutils.version import LooseVersion from distutils.version import LooseVersion
from typing import List, Dict, Union
import numpy as np
from openvino.runtime import PartialShape, Dimension from openvino.runtime import PartialShape, Dimension
@ -264,10 +266,22 @@ def trace_tf_model(model, input_shapes, input_types, example_input):
input_needs_packing = True input_needs_packing = True
def are_shapes_defined(shape: Union[List, Dict]):
if shape is None:
return False
assert hasattr(shape, '__len__')
if len(shape) == 0:
return False
if isinstance(shape, list):
return np.all([shape is not None for shape in input_shapes])
elif isinstance(shape, dict):
return np.all([shape is not None for name, shape in input_shapes.items()])
if example_input is not None: if example_input is not None:
concrete_func = get_concrete_func(tf_function, example_input, input_needs_packing, concrete_func = get_concrete_func(tf_function, example_input, input_needs_packing,
"Could not trace the TF model with the following error: {}") "Could not trace the TF model with the following error: {}")
elif input_shapes is not None: elif are_shapes_defined(input_shapes):
inp = create_example_input_by_user_shapes(input_shapes, input_types) inp = create_example_input_by_user_shapes(input_shapes, input_types)
concrete_func = get_concrete_func(tf_function, inp, input_needs_packing, concrete_func = get_concrete_func(tf_function, inp, input_needs_packing,
"Could not trace the TF model with the following error: {}") "Could not trace the TF model with the following error: {}")

View File

@ -187,8 +187,7 @@ def create_pytorch_nn_module_case3(tmp_dir):
sample_input1 = torch.zeros(1, 3, 10, 10) sample_input1 = torch.zeros(1, 3, 10, 10)
sample_input2 = torch.zeros(1, 3, 10, 10) sample_input2 = torch.zeros(1, 3, 10, 10)
sample_input = tuple([sample_input1, sample_input2]) sample_input = tuple([sample_input1, sample_input2])
return pt_model, ref_model, {'input': [[-1, 3, -1, -1], [-1, 3, -1, -1]],
return pt_model, ref_model, {'input': "[?,3,?,?],[?,3,?,?]",
'example_input': sample_input} 'example_input': sample_input}
@ -1093,6 +1092,37 @@ class ConvertRaises(unittest.TestCase):
with self.assertRaisesRegex(TypeError, ".*got an unexpected keyword argument 'example_inputs'.*"): with self.assertRaisesRegex(TypeError, ".*got an unexpected keyword argument 'example_inputs'.*"):
convert_model(pytorch_model, example_inputs=(torch.tensor(1),)) convert_model(pytorch_model, example_inputs=(torch.tensor(1),))
def test_incorrect_inputs_1(self):
from openvino.tools.ovc import convert_model
pytorch_model, _, _ = create_pytorch_nn_module_case1('')
with self.assertRaisesRegex(Exception, ".*No node with name.*"):
convert_model(pytorch_model, input='input1[1, 10]')
def test_incorrect_inputs_2(self):
from openvino.tools.ovc import convert_model
pytorch_model, _, _ = create_pytorch_nn_module_case1('')
# check that it accepts specified names as is, without parsing into 2 different inputs
with self.assertRaisesRegex(Exception, 'No node with name input1,input2'):
convert_model(pytorch_model, input='input1,input2')
def test_incorrect_inputs_3(self):
from openvino.tools.ovc import convert_model
pytorch_model, _, _ = create_pytorch_nn_module_case1('')
# check that it accepts specified names as is, without parsing into 2 different inputs
with self.assertRaisesRegex(Exception, 'No node with name input1\[1, 10\],input2\[2, 100\]'):
convert_model(pytorch_model, input='input1[1, 10],input2[2, 100]')
def test_incorrect_inputs_4(self):
from openvino.tools.ovc import convert_model
pytorch_model, _, _ = create_pytorch_nn_module_case1('')
# check that it accepts specified names as is, without parsing into 2 different inputs
with self.assertRaisesRegex(Exception, 'No node with name input1\[1, 10\]'):
convert_model(pytorch_model, input=['input1[1, 10]', 'input2[2, 100]'])
def test_failed_extension(self): def test_failed_extension(self):
from openvino.tools.ovc import convert_model from openvino.tools.ovc import convert_model
from openvino.frontend.pytorch import ConversionExtension from openvino.frontend.pytorch import ConversionExtension

View File

@ -938,3 +938,20 @@ class TestTFLoadByModel(unittest.TestCase):
fe = fem.load_by_model(model) fe = fem.load_by_model(model)
assert fe is not None assert fe is not None
assert fe.get_name() == "tf" assert fe.get_name() == "tf"
class TestTFConvertRaises(unittest.TestCase):
def test_incorrect_inputs_1(self):
from openvino.tools.ovc import convert_model
tf_model, _, _ = create_keras_model('')
with self.assertRaisesRegex(Exception, ".*No node with name.*"):
convert_model(tf_model, input='Input1[1, 2, 3]')
def test_incorrect_inputs_2(self):
from openvino.tools.ovc import convert_model
tf_model, _, _ = create_keras_model('')
# check that it accepts specified names as is without parsing into 2 different inputs
with self.assertRaisesRegex(Exception, 'No node with name Input1\[1, 2, 3\],Input2\[1, 2, 3\]'):
convert_model(tf_model, input='Input1[1, 2, 3],Input2[1, 2, 3]')

View File

@ -44,11 +44,8 @@ def single_input_to_input_cut_info(input: [str, tuple, list, PartialShape, Type,
:return: InputCutInfo :return: InputCutInfo
""" """
if isinstance(input, str): if isinstance(input, str):
# Parse params from string
node_name, shape = parse_input_value(input)
# pylint: disable=no-member # pylint: disable=no-member
return _InputCutInfo(node_name, return _InputCutInfo(input, None)
PartialShape(shape) if shape is not None else None)
if isinstance(input, (tuple, list)) or is_shape_type(input): if isinstance(input, (tuple, list)) or is_shape_type(input):
# If input represents list with shape, wrap it to list. Single PartialShape also goes to this condition. # If input represents list with shape, wrap it to list. Single PartialShape also goes to this condition.
# Check of all dimensions will be in is_shape_type(val) method below # Check of all dimensions will be in is_shape_type(val) method below
@ -68,10 +65,10 @@ def single_input_to_input_cut_info(input: [str, tuple, list, PartialShape, Type,
if inp_type is not None: if inp_type is not None:
raise Exception("More than one input type provided: {}".format(input)) raise Exception("More than one input type provided: {}".format(input))
inp_type = val inp_type = val
elif is_shape_type(val): elif is_shape_type(val) or val is None:
if shape is not None: if shape is not None:
raise Exception("More than one input shape provided: {}".format(input)) raise Exception("More than one input shape provided: {}".format(input))
shape = PartialShape(val) shape = PartialShape(val) if val is not None else None
else: else:
raise Exception("Incorrect input parameters provided. Expected tuple with input name, " raise Exception("Incorrect input parameters provided. Expected tuple with input name, "
"input type or input shape. Got unknown object: {}".format(val)) "input type or input shape. Got unknown object: {}".format(val))
@ -116,7 +113,20 @@ def is_single_input(input: [tuple, list]):
return True return True
def input_to_input_cut_info(input: [str, tuple, list]): def parse_inputs(inputs: str):
inputs_list = []
# Split to list of string
for input_value in split_inputs(inputs):
# Parse string with parameters for single input
node_name, shape = parse_input_value(input_value)
# pylint: disable=no-member
inputs_list.append((node_name, shape))
return inputs_list
def input_to_input_cut_info(input: [dict, tuple, list]):
""" """
Parses 'input' to list of InputCutInfo. Parses 'input' to list of InputCutInfo.
:param input: input cut parameters passed by user :param input: input cut parameters passed by user
@ -124,18 +134,10 @@ def input_to_input_cut_info(input: [str, tuple, list]):
""" """
if input is None: if input is None:
return [] return []
if isinstance(input, str):
inputs = []
# Split to list of string
for input_value in split_inputs(input):
# Parse string with parameters for single input
node_name, shape = parse_input_value(input_value)
# pylint: disable=no-member
inputs.append(_InputCutInfo(node_name,
PartialShape(shape) if shape is not None else None))
return inputs
if isinstance(input, (tuple, list)): if isinstance(input, (tuple, list)):
if len(input) == 0:
return []
# Case when input is single shape set in tuple # Case when input is single shape set in tuple
if len(input) > 0 and isinstance(input[0], (int, Dimension)): if len(input) > 0 and isinstance(input[0], (int, Dimension)):
input = [input] input = [input]

View File

@ -22,7 +22,7 @@ from openvino.tools.ovc.moc_frontend.pipeline import moc_pipeline
from openvino.tools.ovc.moc_frontend.moc_emit_ir import moc_emit_ir from openvino.tools.ovc.moc_frontend.moc_emit_ir import moc_emit_ir
from openvino.tools.ovc.convert_data_type import destination_type_to_np_data_type from openvino.tools.ovc.convert_data_type import destination_type_to_np_data_type
from openvino.tools.ovc.cli_parser import get_available_front_ends, get_common_cli_options, depersonalize, \ from openvino.tools.ovc.cli_parser import get_available_front_ends, get_common_cli_options, depersonalize, \
get_mo_convert_params, input_to_input_cut_info get_mo_convert_params, input_to_input_cut_info, parse_inputs
from openvino.tools.ovc.help import get_convert_model_help_specifics from openvino.tools.ovc.help import get_convert_model_help_specifics
from openvino.tools.ovc.error import Error, FrameworkError from openvino.tools.ovc.error import Error, FrameworkError
@ -93,7 +93,14 @@ def arguments_post_parsing(argv: argparse.Namespace):
if is_verbose(argv): if is_verbose(argv):
print_argv(argv) print_argv(argv)
params_parsing(argv) import re
if argv.is_python_api_used and isinstance(argv.input, str):
argv.input = [argv.input]
if not argv.is_python_api_used and isinstance(argv.input, str):
argv.input = parse_inputs(argv.input)
normalize_inputs(argv)
log.debug("Placeholder shapes : {}".format(argv.placeholder_shapes)) log.debug("Placeholder shapes : {}".format(argv.placeholder_shapes))
if not hasattr(argv, 'output') or argv.output is None: if not hasattr(argv, 'output') or argv.output is None:
@ -297,9 +304,9 @@ def input_model_is_object(input_model):
return True return True
def params_parsing(argv: argparse.Namespace): def normalize_inputs(argv: argparse.Namespace):
""" """
Parses params passed to convert_model and wraps resulting values into dictionaries or lists. repacks params passed to convert_model and wraps resulting values into dictionaries or lists.
After working of this method following values are set in argv: After working of this method following values are set in argv:
argv.input, argv.inputs_list - list of input names. Both values are used in some parts of MO. argv.input, argv.inputs_list - list of input names. Both values are used in some parts of MO.

View File

@ -144,14 +144,14 @@ class TestMoFreezePlaceholderTFFE(unittest.TestCase):
# new frontend # new frontend
( (
"model_add_with_undefined_constant.pbtxt", "model_add_with_undefined_constant.pbtxt",
"x[2,3]", ("x", [2, 3]),
{"x": np.array([[12, 13, 10], [11, 14, 16]], dtype=np.float32)}, {"x": np.array([[12, 13, 10], [11, 14, 16]], dtype=np.float32)},
np.array([[12, 13, 10], [11, 14, 16]], dtype=np.float32), np.array([[12, 13, 10], [11, 14, 16]], dtype=np.float32),
np.float32 np.float32
), ),
( (
"model_mul_with_undefined_constant.pbtxt", "model_mul_with_undefined_constant.pbtxt",
"x[2]", ("x", [2]),
{"x": np.array([11, -12], dtype=np.int32)}, {"x": np.array([11, -12], dtype=np.int32)},
np.array([0, 0], dtype=np.int32), np.array([0, 0], dtype=np.int32),
np.int32 np.int32

View File

@ -13,7 +13,7 @@ from unittest.mock import patch
import numpy as np import numpy as np
from openvino.tools.ovc.cli_parser import input_to_input_cut_info, check_positive, writable_dir, \ from openvino.tools.ovc.cli_parser import input_to_input_cut_info, check_positive, writable_dir, \
readable_file_or_object, get_all_cli_parser, get_mo_convert_params readable_file_or_object, get_all_cli_parser, get_mo_convert_params, parse_inputs
from openvino.tools.ovc.convert_impl import pack_params_to_args_namespace, arguments_post_parsing, args_to_argv from openvino.tools.ovc.convert_impl import pack_params_to_args_namespace, arguments_post_parsing, args_to_argv
from openvino.tools.ovc.error import Error from openvino.tools.ovc.error import Error
from unit_tests.ovc.unit_test_with_mocked_telemetry import UnitTestWithMockedTelemetry from unit_tests.ovc.unit_test_with_mocked_telemetry import UnitTestWithMockedTelemetry
@ -26,6 +26,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
def test_get_shapes_several_inputs_several_shapes2(self): def test_get_shapes_several_inputs_several_shapes2(self):
# shapes specified using --input command line parameter and no values # shapes specified using --input command line parameter and no values
argv_input = "inp1[1,22,333,123],inp2[-1,45,7,1]" argv_input = "inp1[1,22,333,123],inp2[-1,45,7,1]"
argv_input = parse_inputs(argv_input)
inputs = input_to_input_cut_info(argv_input) inputs = input_to_input_cut_info(argv_input)
inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([1,22,333,123])), inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([1,22,333,123])),
_InputCutInfo(name='inp2', shape=PartialShape([-1,45,7,1]))] _InputCutInfo(name='inp2', shape=PartialShape([-1,45,7,1]))]
@ -33,30 +34,31 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
def test_raises_get_shapes_1(self): def test_raises_get_shapes_1(self):
argv_input = "[h,y]" argv_input = "[h,y]"
self.assertRaises(Error, input_to_input_cut_info, argv_input) self.assertRaises(Error, parse_inputs, argv_input)
def test_raises_get_shapes_2(self): def test_raises_get_shapes_2(self):
argv_input = "(2, 3)" argv_input = "(2, 3)"
self.assertRaises(Error, input_to_input_cut_info, argv_input) self.assertRaises(Error, parse_inputs, argv_input)
def test_raises_get_shapes_3(self): def test_raises_get_shapes_3(self):
argv_input = "input_1(2, 3)" argv_input = "input_1(2, 3)"
self.assertRaises(Error, input_to_input_cut_info, argv_input) self.assertRaises(Error, parse_inputs, argv_input)
def test_raises_get_shapes_4(self): def test_raises_get_shapes_4(self):
argv_input = "(2, 3),(10, 10)" argv_input = "(2, 3),(10, 10)"
self.assertRaises(Error, input_to_input_cut_info, argv_input) self.assertRaises(Error, parse_inputs, argv_input)
def test_raises_get_shapes_5(self): def test_raises_get_shapes_5(self):
argv_input = "<2,3,4>" argv_input = "<2,3,4>"
self.assertRaises(Error, input_to_input_cut_info, argv_input) self.assertRaises(Error, parse_inputs, argv_input)
def test_raises_get_shapes_6(self): def test_raises_get_shapes_6(self):
argv_input = "sd<2,3>" argv_input = "sd<2,3>"
self.assertRaises(Error, input_to_input_cut_info, argv_input) self.assertRaises(Error, parse_inputs, argv_input)
def test_get_shapes_complex_input(self): def test_get_shapes_complex_input(self):
argv_input = "[10, -1, 100],mask[],[?,?]" argv_input = "[10, -1, 100],mask[],[?,?]"
argv_input = parse_inputs(argv_input)
inputs = input_to_input_cut_info(argv_input) inputs = input_to_input_cut_info(argv_input)
inputs_ref = [_InputCutInfo(shape=PartialShape([10, -1, 100])), inputs_ref = [_InputCutInfo(shape=PartialShape([10, -1, 100])),
_InputCutInfo(name='mask', shape=PartialShape([])), _InputCutInfo(name='mask', shape=PartialShape([])),
@ -66,6 +68,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
def test_get_shapes_and_freezing_with_scalar_and_without_shapes_in_input(self): def test_get_shapes_and_freezing_with_scalar_and_without_shapes_in_input(self):
# shapes and value for freezing specified using --input command line parameter # shapes and value for freezing specified using --input command line parameter
argv_input = "inp1,inp2" argv_input = "inp1,inp2"
argv_input = parse_inputs(argv_input)
inputs = input_to_input_cut_info(argv_input) inputs = input_to_input_cut_info(argv_input)
inputs_ref = [_InputCutInfo(name='inp1'), inputs_ref = [_InputCutInfo(name='inp1'),
_InputCutInfo(name='inp2')] _InputCutInfo(name='inp2')]
@ -75,6 +78,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
def test_get_shapes_and_freezing_with_scalar(self): def test_get_shapes_and_freezing_with_scalar(self):
# shapes and value for freezing specified using --input command line parameter # shapes and value for freezing specified using --input command line parameter
argv_input = "inp1,inp2[]" argv_input = "inp1,inp2[]"
argv_input = parse_inputs(argv_input)
inputs = input_to_input_cut_info(argv_input) inputs = input_to_input_cut_info(argv_input)
inputs_ref = [_InputCutInfo(name='inp1'), inputs_ref = [_InputCutInfo(name='inp1'),
_InputCutInfo(name='inp2', shape=PartialShape([]))] _InputCutInfo(name='inp2', shape=PartialShape([]))]
@ -83,6 +87,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
def test_get_shapes_several_inputs_several_shapes3(self): def test_get_shapes_several_inputs_several_shapes3(self):
# shapes and value for freezing specified using --input command line parameter # shapes and value for freezing specified using --input command line parameter
argv_input = "inp1[3 1],inp2[3,2,3],inp3[5]" argv_input = "inp1[3 1],inp2[3,2,3],inp3[5]"
argv_input = parse_inputs(argv_input)
inputs = input_to_input_cut_info(argv_input) inputs = input_to_input_cut_info(argv_input)
inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])), inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])),
_InputCutInfo(name='inp2', shape=PartialShape([3,2,3])), _InputCutInfo(name='inp2', shape=PartialShape([3,2,3])),
@ -92,6 +97,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
def test_get_shapes_several_inputs_several_shapes3_comma_sep(self): def test_get_shapes_several_inputs_several_shapes3_comma_sep(self):
# shapes and value for freezing specified using --input command line parameter # shapes and value for freezing specified using --input command line parameter
argv_input = "inp1[3 1],inp2[3 2 3],inp3[5]" argv_input = "inp1[3 1],inp2[3 2 3],inp3[5]"
argv_input = parse_inputs(argv_input)
inputs = input_to_input_cut_info(argv_input) inputs = input_to_input_cut_info(argv_input)
inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])), inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])),
_InputCutInfo(name='inp2', shape=PartialShape([3,2,3])), _InputCutInfo(name='inp2', shape=PartialShape([3,2,3])),
@ -101,6 +107,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
def test_get_shapes_several_inputs_several_shapes6(self): def test_get_shapes_several_inputs_several_shapes6(self):
# 0D value for freezing specified using --input command line parameter without shape # 0D value for freezing specified using --input command line parameter without shape
argv_input = "inp1[3,1],inp2[3,2,3],inp3" argv_input = "inp1[3,1],inp2[3,2,3],inp3"
argv_input = parse_inputs(argv_input)
inputs_list, result, _ = input_to_input_cut_info(argv_input) inputs_list, result, _ = input_to_input_cut_info(argv_input)
inputs = input_to_input_cut_info(argv_input) inputs = input_to_input_cut_info(argv_input)
inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])), inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])),
@ -111,6 +118,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
def test_get_shapes_several_inputs_several_shapes7(self): def test_get_shapes_several_inputs_several_shapes7(self):
# 0D shape and value for freezing specified using --input command line parameter # 0D shape and value for freezing specified using --input command line parameter
argv_input = "inp1[3,1],inp2[3,2,3],inp3[]" argv_input = "inp1[3,1],inp2[3,2,3],inp3[]"
argv_input = parse_inputs(argv_input)
inputs_list, result, _ = input_to_input_cut_info(argv_input) inputs_list, result, _ = input_to_input_cut_info(argv_input)
inputs = input_to_input_cut_info(argv_input) inputs = input_to_input_cut_info(argv_input)
inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])), inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])),
@ -121,6 +129,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
def test_get_shapes_and_data_types_shape_only(self): def test_get_shapes_and_data_types_shape_only(self):
argv_input = "placeholder1[3 1],placeholder2,placeholder3" argv_input = "placeholder1[3 1],placeholder2,placeholder3"
argv_input = parse_inputs(argv_input)
inputs = input_to_input_cut_info(argv_input) inputs = input_to_input_cut_info(argv_input)
inputs_ref = [_InputCutInfo(name='placeholder1', shape=PartialShape([3,1])), inputs_ref = [_InputCutInfo(name='placeholder1', shape=PartialShape([3,1])),
_InputCutInfo(name='placeholder2'), _InputCutInfo(name='placeholder2'),
@ -129,6 +138,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
def test_get_shapes_and_data_types_shape_with_ports_only(self): def test_get_shapes_and_data_types_shape_with_ports_only(self):
argv_input = "placeholder1:4[3 1],placeholder2,2:placeholder3" argv_input = "placeholder1:4[3 1],placeholder2,2:placeholder3"
argv_input = parse_inputs(argv_input)
inputs = input_to_input_cut_info(argv_input) inputs = input_to_input_cut_info(argv_input)
inputs_ref = [_InputCutInfo(name='placeholder1:4', shape=PartialShape([3,1])), inputs_ref = [_InputCutInfo(name='placeholder1:4', shape=PartialShape([3,1])),
_InputCutInfo(name='placeholder2'), _InputCutInfo(name='placeholder2'),
@ -137,15 +147,16 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
def test_wrong_data_types(self): def test_wrong_data_types(self):
argv_input = "inp1[3 1]->[1.0 2.0 3.0],inp2[3 2 3]{abracadabra},inp3[5]{f32}->[1.0 1.0 2.0 3.0 5.0]" argv_input = "inp1[3 1]->[1.0 2.0 3.0],inp2[3 2 3]{abracadabra},inp3[5]{f32}->[1.0 1.0 2.0 3.0 5.0]"
self.assertRaises(Error, input_to_input_cut_info, argv_input) self.assertRaises(Error, parse_inputs, argv_input)
def test_shape_and_value_shape_mismatch(self): def test_shape_and_value_shape_mismatch(self):
# size of value tensor does not correspond to specified shape for the third node # size of value tensor does not correspond to specified shape for the third node
argv_input = "inp1[3 1]->[1.0 2.0 3.0],inp2[3 2 3],inp3[5 3]->[2.0 3.0 5.0]" argv_input = "inp1[3 1]->[1.0 2.0 3.0],inp2[3 2 3],inp3[5 3]->[2.0 3.0 5.0]"
self.assertRaises(Error, input_to_input_cut_info, argv_input) self.assertRaises(Error, parse_inputs, argv_input)
def test_get_shapes_no_input_no_shape(self): def test_get_shapes_no_input_no_shape(self):
argv_input = "" argv_input = ""
argv_input = parse_inputs(argv_input)
inputs = input_to_input_cut_info(argv_input) inputs = input_to_input_cut_info(argv_input)
inputs_ref = [] inputs_ref = []
self.assertEqual(inputs, inputs_ref) self.assertEqual(inputs, inputs_ref)
@ -153,18 +164,21 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
def test_get_shapes_no_input_one_shape2(self): def test_get_shapes_no_input_one_shape2(self):
argv_input = "[12,4,1]" argv_input = "[12,4,1]"
argv_input = parse_inputs(argv_input)
inputs = input_to_input_cut_info(argv_input) inputs = input_to_input_cut_info(argv_input)
inputs_ref = [_InputCutInfo(shape=PartialShape([12,4,1]))] inputs_ref = [_InputCutInfo(shape=PartialShape([12,4,1]))]
self.assertEqual(inputs, inputs_ref) self.assertEqual(inputs, inputs_ref)
def test_get_shapes_for_scalar_inputs(self): def test_get_shapes_for_scalar_inputs(self):
argv_input = "[]" argv_input = "[]"
argv_input = parse_inputs(argv_input)
inputs = input_to_input_cut_info(argv_input) inputs = input_to_input_cut_info(argv_input)
inputs_ref = [_InputCutInfo(shape=PartialShape([]))] inputs_ref = [_InputCutInfo(shape=PartialShape([]))]
self.assertEqual(inputs, inputs_ref) self.assertEqual(inputs, inputs_ref)
def test_get_shapes_two_input_shapes_with_scalar(self): def test_get_shapes_two_input_shapes_with_scalar(self):
argv_input = "[12,4,1],[]" argv_input = "[12,4,1],[]"
argv_input = parse_inputs(argv_input)
inputs = input_to_input_cut_info(argv_input) inputs = input_to_input_cut_info(argv_input)
inputs_ref = [_InputCutInfo(shape=PartialShape([12,4,1])), inputs_ref = [_InputCutInfo(shape=PartialShape([12,4,1])),
_InputCutInfo(shape=PartialShape([]))] _InputCutInfo(shape=PartialShape([]))]
@ -172,6 +186,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
def test_get_shapes_two_input_shapes(self): def test_get_shapes_two_input_shapes(self):
argv_input = "[12,4,1],[10]" argv_input = "[12,4,1],[10]"
argv_input = parse_inputs(argv_input)
inputs = input_to_input_cut_info(argv_input) inputs = input_to_input_cut_info(argv_input)
inputs_ref = [_InputCutInfo(shape=PartialShape([12,4,1])), inputs_ref = [_InputCutInfo(shape=PartialShape([12,4,1])),
_InputCutInfo(shape=PartialShape([10])),] _InputCutInfo(shape=PartialShape([10])),]
@ -179,6 +194,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
def test_get_shapes_one_input_no_shape(self): def test_get_shapes_one_input_no_shape(self):
argv_input = "inp1" argv_input = "inp1"
argv_input = parse_inputs(argv_input)
inputs = input_to_input_cut_info(argv_input) inputs = input_to_input_cut_info(argv_input)
inputs_ref = [_InputCutInfo(name='inp1')] inputs_ref = [_InputCutInfo(name='inp1')]
self.assertEqual(inputs, inputs_ref) self.assertEqual(inputs, inputs_ref)
@ -186,6 +202,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
def test_get_shapes_several_inputs_several_partial_shapes2(self): def test_get_shapes_several_inputs_several_partial_shapes2(self):
# shapes specified using --input command line parameter and no values # shapes specified using --input command line parameter and no values
argv_input = "inp1[1,?,50..100,123],inp2[-1,45..,..7,1]" argv_input = "inp1[1,?,50..100,123],inp2[-1,45..,..7,1]"
argv_input = parse_inputs(argv_input)
inputs = input_to_input_cut_info(argv_input) inputs = input_to_input_cut_info(argv_input)
inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape("[1,?,50..100,123]")), inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape("[1,?,50..100,123]")),
_InputCutInfo(name='inp2', shape=PartialShape("[-1,45..,..7,1]"))] _InputCutInfo(name='inp2', shape=PartialShape("[-1,45..,..7,1]"))]
@ -194,6 +211,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
def test_get_shapes_several_inputs_several_partial_shapes3(self): def test_get_shapes_several_inputs_several_partial_shapes3(self):
# shapes and value for freezing specified using --input command line parameter # shapes and value for freezing specified using --input command line parameter
argv_input = "inp1[3,1],inp2[3..,..2,5..10,?,-1],inp3[5]" argv_input = "inp1[3,1],inp2[3..,..2,5..10,?,-1],inp3[5]"
argv_input = parse_inputs(argv_input)
inputs = input_to_input_cut_info(argv_input) inputs = input_to_input_cut_info(argv_input)
inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])), inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])),
_InputCutInfo(name='inp2', shape=PartialShape("[3..,..2,5..10,?,-1]")), _InputCutInfo(name='inp2', shape=PartialShape("[3..,..2,5..10,?,-1]")),
@ -203,6 +221,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
def test_get_shapes_several_inputs_several_partial_shapes6(self): def test_get_shapes_several_inputs_several_partial_shapes6(self):
# 0D value for freezing specified using --input command line parameter without shape # 0D value for freezing specified using --input command line parameter without shape
argv_input = "inp1[3 1],inp2[3.. ..2 5..10 ? -1],inp3" argv_input = "inp1[3 1],inp2[3.. ..2 5..10 ? -1],inp3"
argv_input = parse_inputs(argv_input)
inputs = input_to_input_cut_info(argv_input) inputs = input_to_input_cut_info(argv_input)
inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])), inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])),
_InputCutInfo(name='inp2', shape=PartialShape("[3..,..2,5..10,?,-1]")), _InputCutInfo(name='inp2', shape=PartialShape("[3..,..2,5..10,?,-1]")),
@ -212,6 +231,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
def test_get_shapes_several_inputs_several_partial_shapes7(self): def test_get_shapes_several_inputs_several_partial_shapes7(self):
# 0D shape and value for freezing specified using --input command line parameter # 0D shape and value for freezing specified using --input command line parameter
argv_input = "inp1[3 1],inp2[3.. ..2 5..10 ? -1],inp3[]" argv_input = "inp1[3 1],inp2[3.. ..2 5..10 ? -1],inp3[]"
argv_input = parse_inputs(argv_input)
inputs = input_to_input_cut_info(argv_input) inputs = input_to_input_cut_info(argv_input)
inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])), inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])),
_InputCutInfo(name='inp2', shape=PartialShape("[3..,..2,5..10,?,-1]")), _InputCutInfo(name='inp2', shape=PartialShape("[3..,..2,5..10,?,-1]")),
@ -220,15 +240,16 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
def test_partial_shapes_freeze_dynamic_negative_case1(self): def test_partial_shapes_freeze_dynamic_negative_case1(self):
argv_input = "inp1:1[3 1..10]->[1.0 2.0 3.0]" argv_input = "inp1:1[3 1..10]->[1.0 2.0 3.0]"
self.assertRaises(Error, input_to_input_cut_info, argv_input) self.assertRaises(Error, parse_inputs, argv_input)
def test_partial_shapes_freeze_dynamic_negative_case2(self): def test_partial_shapes_freeze_dynamic_negative_case2(self):
argv_input = "inp1:1[1 2 -1]->[1.0 2.0 3.0]" argv_input = "inp1:1[1 2 -1]->[1.0 2.0 3.0]"
self.assertRaises(Error, input_to_input_cut_info, argv_input) self.assertRaises(Error, parse_inputs, argv_input)
def test_get_shapes_several_inputs_several_partial_shapes2_comma_separator(self): def test_get_shapes_several_inputs_several_partial_shapes2_comma_separator(self):
# shapes specified using --input command line parameter and no values # shapes specified using --input command line parameter and no values
argv_input = "inp1[1,?,50..100,123],inp2[-1,45..,..7,1]" argv_input = "inp1[1,?,50..100,123],inp2[-1,45..,..7,1]"
argv_input = parse_inputs(argv_input)
inputs = input_to_input_cut_info(argv_input) inputs = input_to_input_cut_info(argv_input)
inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape("[1,?,50..100,123]")), inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape("[1,?,50..100,123]")),
_InputCutInfo(name='inp2', shape=PartialShape("[-1,45..,..7,1]"))] _InputCutInfo(name='inp2', shape=PartialShape("[-1,45..,..7,1]"))]
@ -237,6 +258,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
def test_get_shapes_several_inputs_several_partial_shapes3_comma_separator(self): def test_get_shapes_several_inputs_several_partial_shapes3_comma_separator(self):
# shapes and value for freezing specified using --input command line parameter # shapes and value for freezing specified using --input command line parameter
argv_input = "inp1[3,1],inp2[3..,..2,5..10,?,-1],inp3[5]" argv_input = "inp1[3,1],inp2[3..,..2,5..10,?,-1],inp3[5]"
argv_input = parse_inputs(argv_input)
inputs = input_to_input_cut_info(argv_input) inputs = input_to_input_cut_info(argv_input)
inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])), inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])),
_InputCutInfo(name='inp2', shape=PartialShape("[3..,..2,5..10,?,-1]")), _InputCutInfo(name='inp2', shape=PartialShape("[3..,..2,5..10,?,-1]")),
@ -246,6 +268,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
def test_get_shapes_several_inputs_several_partial_shapes6_comma_separator(self): def test_get_shapes_several_inputs_several_partial_shapes6_comma_separator(self):
# 0D value for freezing specified using --input command line parameter without shape # 0D value for freezing specified using --input command line parameter without shape
argv_input = "inp1[3, 1],inp2[3.., ..2, 5..10, ?,-1],inp3" argv_input = "inp1[3, 1],inp2[3.., ..2, 5..10, ?,-1],inp3"
argv_input = parse_inputs(argv_input)
inputs = input_to_input_cut_info(argv_input) inputs = input_to_input_cut_info(argv_input)
inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])), inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])),
_InputCutInfo(name='inp2', shape=PartialShape("[3..,..2,5..10,?,-1]")), _InputCutInfo(name='inp2', shape=PartialShape("[3..,..2,5..10,?,-1]")),
@ -255,6 +278,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
def test_get_shapes_several_inputs_several_partial_shapes7_comma_separator(self): def test_get_shapes_several_inputs_several_partial_shapes7_comma_separator(self):
# 0D shape and value for freezing specified using --input command line parameter # 0D shape and value for freezing specified using --input command line parameter
argv_input = "inp1[3,1],inp2[3.., ..2,5..10, ?,-1],inp3[]" argv_input = "inp1[3,1],inp2[3.., ..2,5..10, ?,-1],inp3[]"
argv_input = parse_inputs(argv_input)
inputs = input_to_input_cut_info(argv_input) inputs = input_to_input_cut_info(argv_input)
inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])), inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])),
_InputCutInfo(name='inp2', shape=PartialShape("[3..,..2,5..10,?,-1]")), _InputCutInfo(name='inp2', shape=PartialShape("[3..,..2,5..10,?,-1]")),
@ -263,24 +287,25 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
def test_partial_shapes_freeze_dynamic_negative_case1_comma_separator(self): def test_partial_shapes_freeze_dynamic_negative_case1_comma_separator(self):
argv_input = "inp1:1[3,1..10]->[1.0 2.0 3.0]" argv_input = "inp1:1[3,1..10]->[1.0 2.0 3.0]"
self.assertRaises(Error, input_to_input_cut_info, argv_input) self.assertRaises(Error, parse_inputs, argv_input)
def test_partial_shapes_freeze_dynamic_negative_case2_comma_separator(self): def test_partial_shapes_freeze_dynamic_negative_case2_comma_separator(self):
argv_input = "inp1:1[1,2,-1]->[1.0 2.0 3.0]" argv_input = "inp1:1[1,2,-1]->[1.0 2.0 3.0]"
self.assertRaises(Error, input_to_input_cut_info, argv_input) self.assertRaises(Error, parse_inputs, argv_input)
def test_partial_shapes_freeze_dynamic_negative_case3_comma_separator(self): def test_partial_shapes_freeze_dynamic_negative_case3_comma_separator(self):
argv_input = "inp1:1[3,1..10]->[1.0 2.0 3.0]" argv_input = "inp1:1[3,1..10]->[1.0 2.0 3.0]"
self.assertRaises(Error, input_to_input_cut_info, argv_input) self.assertRaises(Error, parse_inputs, argv_input)
def test_partial_shapes_freeze_dynamic_negative_case4_comma_separator(self): def test_partial_shapes_freeze_dynamic_negative_case4_comma_separator(self):
argv_input = "inp1:1[1, 2, -1]->[1.0 2.0 3.0]" argv_input = "inp1:1[1, 2, -1]->[1.0 2.0 3.0]"
self.assertRaises(Error, input_to_input_cut_info, argv_input) self.assertRaises(Error, parse_inputs, argv_input)
def test_not_supported_arrow(self): def test_not_supported_arrow(self):
with self.assertRaisesRegex(Exception, with self.assertRaisesRegex(Exception,
"Incorrect format of input."): "Incorrect format of input."):
input_to_input_cut_info("inp1->[1.0]") argv_input = parse_inputs("inp1->[1.0]")
input_to_input_cut_info(argv_input)
class PositiveChecker(unittest.TestCase): class PositiveChecker(unittest.TestCase):
@ -526,4 +551,3 @@ class TestConvertModelParamsParsing(unittest.TestCase):
assert param_name not in cli_parser._option_string_actions assert param_name not in cli_parser._option_string_actions
else: else:
assert param_name in cli_parser._option_string_actions assert param_name in cli_parser._option_string_actions