[OVC] do not parse inputs for py_api (#19742)
* [OVC] do not parse inputs * fix unit-tests * remove redundant lines, add test case * add one more unit-test * skip None values * replace str with List in test_mo_import_from_memory * corrected type hints, added a safety assert
This commit is contained in:
parent
c1271d1217
commit
9271b79540
@ -6,9 +6,11 @@
|
||||
|
||||
|
||||
import logging as log
|
||||
import numpy as np
|
||||
import sys
|
||||
from distutils.version import LooseVersion
|
||||
from typing import List, Dict, Union
|
||||
|
||||
import numpy as np
|
||||
from openvino.runtime import PartialShape, Dimension
|
||||
|
||||
|
||||
@ -264,10 +266,22 @@ def trace_tf_model(model, input_shapes, input_types, example_input):
|
||||
|
||||
input_needs_packing = True
|
||||
|
||||
def are_shapes_defined(shape: Union[List, Dict]):
|
||||
if shape is None:
|
||||
return False
|
||||
assert hasattr(shape, '__len__')
|
||||
if len(shape) == 0:
|
||||
return False
|
||||
|
||||
if isinstance(shape, list):
|
||||
return np.all([shape is not None for shape in input_shapes])
|
||||
elif isinstance(shape, dict):
|
||||
return np.all([shape is not None for name, shape in input_shapes.items()])
|
||||
|
||||
if example_input is not None:
|
||||
concrete_func = get_concrete_func(tf_function, example_input, input_needs_packing,
|
||||
"Could not trace the TF model with the following error: {}")
|
||||
elif input_shapes is not None:
|
||||
elif are_shapes_defined(input_shapes):
|
||||
inp = create_example_input_by_user_shapes(input_shapes, input_types)
|
||||
concrete_func = get_concrete_func(tf_function, inp, input_needs_packing,
|
||||
"Could not trace the TF model with the following error: {}")
|
||||
|
@ -187,8 +187,7 @@ def create_pytorch_nn_module_case3(tmp_dir):
|
||||
sample_input1 = torch.zeros(1, 3, 10, 10)
|
||||
sample_input2 = torch.zeros(1, 3, 10, 10)
|
||||
sample_input = tuple([sample_input1, sample_input2])
|
||||
|
||||
return pt_model, ref_model, {'input': "[?,3,?,?],[?,3,?,?]",
|
||||
return pt_model, ref_model, {'input': [[-1, 3, -1, -1], [-1, 3, -1, -1]],
|
||||
'example_input': sample_input}
|
||||
|
||||
|
||||
@ -1093,6 +1092,37 @@ class ConvertRaises(unittest.TestCase):
|
||||
with self.assertRaisesRegex(TypeError, ".*got an unexpected keyword argument 'example_inputs'.*"):
|
||||
convert_model(pytorch_model, example_inputs=(torch.tensor(1),))
|
||||
|
||||
def test_incorrect_inputs_1(self):
|
||||
from openvino.tools.ovc import convert_model
|
||||
pytorch_model, _, _ = create_pytorch_nn_module_case1('')
|
||||
|
||||
with self.assertRaisesRegex(Exception, ".*No node with name.*"):
|
||||
convert_model(pytorch_model, input='input1[1, 10]')
|
||||
|
||||
def test_incorrect_inputs_2(self):
|
||||
from openvino.tools.ovc import convert_model
|
||||
pytorch_model, _, _ = create_pytorch_nn_module_case1('')
|
||||
|
||||
# check that it accepts specified names as is, without parsing into 2 different inputs
|
||||
with self.assertRaisesRegex(Exception, 'No node with name input1,input2'):
|
||||
convert_model(pytorch_model, input='input1,input2')
|
||||
|
||||
def test_incorrect_inputs_3(self):
|
||||
from openvino.tools.ovc import convert_model
|
||||
pytorch_model, _, _ = create_pytorch_nn_module_case1('')
|
||||
|
||||
# check that it accepts specified names as is, without parsing into 2 different inputs
|
||||
with self.assertRaisesRegex(Exception, 'No node with name input1\[1, 10\],input2\[2, 100\]'):
|
||||
convert_model(pytorch_model, input='input1[1, 10],input2[2, 100]')
|
||||
|
||||
def test_incorrect_inputs_4(self):
|
||||
from openvino.tools.ovc import convert_model
|
||||
pytorch_model, _, _ = create_pytorch_nn_module_case1('')
|
||||
|
||||
# check that it accepts specified names as is, without parsing into 2 different inputs
|
||||
with self.assertRaisesRegex(Exception, 'No node with name input1\[1, 10\]'):
|
||||
convert_model(pytorch_model, input=['input1[1, 10]', 'input2[2, 100]'])
|
||||
|
||||
def test_failed_extension(self):
|
||||
from openvino.tools.ovc import convert_model
|
||||
from openvino.frontend.pytorch import ConversionExtension
|
||||
|
@ -938,3 +938,20 @@ class TestTFLoadByModel(unittest.TestCase):
|
||||
fe = fem.load_by_model(model)
|
||||
assert fe is not None
|
||||
assert fe.get_name() == "tf"
|
||||
|
||||
|
||||
class TestTFConvertRaises(unittest.TestCase):
|
||||
def test_incorrect_inputs_1(self):
|
||||
from openvino.tools.ovc import convert_model
|
||||
tf_model, _, _ = create_keras_model('')
|
||||
|
||||
with self.assertRaisesRegex(Exception, ".*No node with name.*"):
|
||||
convert_model(tf_model, input='Input1[1, 2, 3]')
|
||||
|
||||
def test_incorrect_inputs_2(self):
|
||||
from openvino.tools.ovc import convert_model
|
||||
tf_model, _, _ = create_keras_model('')
|
||||
|
||||
# check that it accepts specified names as is without parsing into 2 different inputs
|
||||
with self.assertRaisesRegex(Exception, 'No node with name Input1\[1, 2, 3\],Input2\[1, 2, 3\]'):
|
||||
convert_model(tf_model, input='Input1[1, 2, 3],Input2[1, 2, 3]')
|
||||
|
@ -44,11 +44,8 @@ def single_input_to_input_cut_info(input: [str, tuple, list, PartialShape, Type,
|
||||
:return: InputCutInfo
|
||||
"""
|
||||
if isinstance(input, str):
|
||||
# Parse params from string
|
||||
node_name, shape = parse_input_value(input)
|
||||
# pylint: disable=no-member
|
||||
return _InputCutInfo(node_name,
|
||||
PartialShape(shape) if shape is not None else None)
|
||||
return _InputCutInfo(input, None)
|
||||
if isinstance(input, (tuple, list)) or is_shape_type(input):
|
||||
# If input represents list with shape, wrap it to list. Single PartialShape also goes to this condition.
|
||||
# Check of all dimensions will be in is_shape_type(val) method below
|
||||
@ -68,10 +65,10 @@ def single_input_to_input_cut_info(input: [str, tuple, list, PartialShape, Type,
|
||||
if inp_type is not None:
|
||||
raise Exception("More than one input type provided: {}".format(input))
|
||||
inp_type = val
|
||||
elif is_shape_type(val):
|
||||
elif is_shape_type(val) or val is None:
|
||||
if shape is not None:
|
||||
raise Exception("More than one input shape provided: {}".format(input))
|
||||
shape = PartialShape(val)
|
||||
shape = PartialShape(val) if val is not None else None
|
||||
else:
|
||||
raise Exception("Incorrect input parameters provided. Expected tuple with input name, "
|
||||
"input type or input shape. Got unknown object: {}".format(val))
|
||||
@ -116,7 +113,20 @@ def is_single_input(input: [tuple, list]):
|
||||
return True
|
||||
|
||||
|
||||
def input_to_input_cut_info(input: [str, tuple, list]):
|
||||
def parse_inputs(inputs: str):
|
||||
inputs_list = []
|
||||
# Split to list of string
|
||||
for input_value in split_inputs(inputs):
|
||||
|
||||
# Parse string with parameters for single input
|
||||
node_name, shape = parse_input_value(input_value)
|
||||
# pylint: disable=no-member
|
||||
inputs_list.append((node_name, shape))
|
||||
return inputs_list
|
||||
|
||||
|
||||
def input_to_input_cut_info(input: [dict, tuple, list]):
|
||||
|
||||
"""
|
||||
Parses 'input' to list of InputCutInfo.
|
||||
:param input: input cut parameters passed by user
|
||||
@ -124,18 +134,10 @@ def input_to_input_cut_info(input: [str, tuple, list]):
|
||||
"""
|
||||
if input is None:
|
||||
return []
|
||||
if isinstance(input, str):
|
||||
inputs = []
|
||||
# Split to list of string
|
||||
for input_value in split_inputs(input):
|
||||
|
||||
# Parse string with parameters for single input
|
||||
node_name, shape = parse_input_value(input_value)
|
||||
# pylint: disable=no-member
|
||||
inputs.append(_InputCutInfo(node_name,
|
||||
PartialShape(shape) if shape is not None else None))
|
||||
return inputs
|
||||
if isinstance(input, (tuple, list)):
|
||||
if len(input) == 0:
|
||||
return []
|
||||
# Case when input is single shape set in tuple
|
||||
if len(input) > 0 and isinstance(input[0], (int, Dimension)):
|
||||
input = [input]
|
||||
|
@ -22,7 +22,7 @@ from openvino.tools.ovc.moc_frontend.pipeline import moc_pipeline
|
||||
from openvino.tools.ovc.moc_frontend.moc_emit_ir import moc_emit_ir
|
||||
from openvino.tools.ovc.convert_data_type import destination_type_to_np_data_type
|
||||
from openvino.tools.ovc.cli_parser import get_available_front_ends, get_common_cli_options, depersonalize, \
|
||||
get_mo_convert_params, input_to_input_cut_info
|
||||
get_mo_convert_params, input_to_input_cut_info, parse_inputs
|
||||
from openvino.tools.ovc.help import get_convert_model_help_specifics
|
||||
|
||||
from openvino.tools.ovc.error import Error, FrameworkError
|
||||
@ -93,7 +93,14 @@ def arguments_post_parsing(argv: argparse.Namespace):
|
||||
if is_verbose(argv):
|
||||
print_argv(argv)
|
||||
|
||||
params_parsing(argv)
|
||||
import re
|
||||
if argv.is_python_api_used and isinstance(argv.input, str):
|
||||
argv.input = [argv.input]
|
||||
|
||||
if not argv.is_python_api_used and isinstance(argv.input, str):
|
||||
argv.input = parse_inputs(argv.input)
|
||||
|
||||
normalize_inputs(argv)
|
||||
log.debug("Placeholder shapes : {}".format(argv.placeholder_shapes))
|
||||
|
||||
if not hasattr(argv, 'output') or argv.output is None:
|
||||
@ -297,9 +304,9 @@ def input_model_is_object(input_model):
|
||||
return True
|
||||
|
||||
|
||||
def params_parsing(argv: argparse.Namespace):
|
||||
def normalize_inputs(argv: argparse.Namespace):
|
||||
"""
|
||||
Parses params passed to convert_model and wraps resulting values into dictionaries or lists.
|
||||
repacks params passed to convert_model and wraps resulting values into dictionaries or lists.
|
||||
After working of this method following values are set in argv:
|
||||
|
||||
argv.input, argv.inputs_list - list of input names. Both values are used in some parts of MO.
|
||||
|
@ -144,14 +144,14 @@ class TestMoFreezePlaceholderTFFE(unittest.TestCase):
|
||||
# new frontend
|
||||
(
|
||||
"model_add_with_undefined_constant.pbtxt",
|
||||
"x[2,3]",
|
||||
("x", [2, 3]),
|
||||
{"x": np.array([[12, 13, 10], [11, 14, 16]], dtype=np.float32)},
|
||||
np.array([[12, 13, 10], [11, 14, 16]], dtype=np.float32),
|
||||
np.float32
|
||||
),
|
||||
(
|
||||
"model_mul_with_undefined_constant.pbtxt",
|
||||
"x[2]",
|
||||
("x", [2]),
|
||||
{"x": np.array([11, -12], dtype=np.int32)},
|
||||
np.array([0, 0], dtype=np.int32),
|
||||
np.int32
|
||||
|
@ -13,7 +13,7 @@ from unittest.mock import patch
|
||||
import numpy as np
|
||||
|
||||
from openvino.tools.ovc.cli_parser import input_to_input_cut_info, check_positive, writable_dir, \
|
||||
readable_file_or_object, get_all_cli_parser, get_mo_convert_params
|
||||
readable_file_or_object, get_all_cli_parser, get_mo_convert_params, parse_inputs
|
||||
from openvino.tools.ovc.convert_impl import pack_params_to_args_namespace, arguments_post_parsing, args_to_argv
|
||||
from openvino.tools.ovc.error import Error
|
||||
from unit_tests.ovc.unit_test_with_mocked_telemetry import UnitTestWithMockedTelemetry
|
||||
@ -26,6 +26,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
|
||||
def test_get_shapes_several_inputs_several_shapes2(self):
|
||||
# shapes specified using --input command line parameter and no values
|
||||
argv_input = "inp1[1,22,333,123],inp2[-1,45,7,1]"
|
||||
argv_input = parse_inputs(argv_input)
|
||||
inputs = input_to_input_cut_info(argv_input)
|
||||
inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([1,22,333,123])),
|
||||
_InputCutInfo(name='inp2', shape=PartialShape([-1,45,7,1]))]
|
||||
@ -33,30 +34,31 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
|
||||
|
||||
def test_raises_get_shapes_1(self):
|
||||
argv_input = "[h,y]"
|
||||
self.assertRaises(Error, input_to_input_cut_info, argv_input)
|
||||
self.assertRaises(Error, parse_inputs, argv_input)
|
||||
|
||||
def test_raises_get_shapes_2(self):
|
||||
argv_input = "(2, 3)"
|
||||
self.assertRaises(Error, input_to_input_cut_info, argv_input)
|
||||
self.assertRaises(Error, parse_inputs, argv_input)
|
||||
|
||||
def test_raises_get_shapes_3(self):
|
||||
argv_input = "input_1(2, 3)"
|
||||
self.assertRaises(Error, input_to_input_cut_info, argv_input)
|
||||
self.assertRaises(Error, parse_inputs, argv_input)
|
||||
|
||||
def test_raises_get_shapes_4(self):
|
||||
argv_input = "(2, 3),(10, 10)"
|
||||
self.assertRaises(Error, input_to_input_cut_info, argv_input)
|
||||
self.assertRaises(Error, parse_inputs, argv_input)
|
||||
|
||||
def test_raises_get_shapes_5(self):
|
||||
argv_input = "<2,3,4>"
|
||||
self.assertRaises(Error, input_to_input_cut_info, argv_input)
|
||||
self.assertRaises(Error, parse_inputs, argv_input)
|
||||
|
||||
def test_raises_get_shapes_6(self):
|
||||
argv_input = "sd<2,3>"
|
||||
self.assertRaises(Error, input_to_input_cut_info, argv_input)
|
||||
self.assertRaises(Error, parse_inputs, argv_input)
|
||||
|
||||
def test_get_shapes_complex_input(self):
|
||||
argv_input = "[10, -1, 100],mask[],[?,?]"
|
||||
argv_input = parse_inputs(argv_input)
|
||||
inputs = input_to_input_cut_info(argv_input)
|
||||
inputs_ref = [_InputCutInfo(shape=PartialShape([10, -1, 100])),
|
||||
_InputCutInfo(name='mask', shape=PartialShape([])),
|
||||
@ -66,6 +68,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
|
||||
def test_get_shapes_and_freezing_with_scalar_and_without_shapes_in_input(self):
|
||||
# shapes and value for freezing specified using --input command line parameter
|
||||
argv_input = "inp1,inp2"
|
||||
argv_input = parse_inputs(argv_input)
|
||||
inputs = input_to_input_cut_info(argv_input)
|
||||
inputs_ref = [_InputCutInfo(name='inp1'),
|
||||
_InputCutInfo(name='inp2')]
|
||||
@ -75,6 +78,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
|
||||
def test_get_shapes_and_freezing_with_scalar(self):
|
||||
# shapes and value for freezing specified using --input command line parameter
|
||||
argv_input = "inp1,inp2[]"
|
||||
argv_input = parse_inputs(argv_input)
|
||||
inputs = input_to_input_cut_info(argv_input)
|
||||
inputs_ref = [_InputCutInfo(name='inp1'),
|
||||
_InputCutInfo(name='inp2', shape=PartialShape([]))]
|
||||
@ -83,6 +87,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
|
||||
def test_get_shapes_several_inputs_several_shapes3(self):
|
||||
# shapes and value for freezing specified using --input command line parameter
|
||||
argv_input = "inp1[3 1],inp2[3,2,3],inp3[5]"
|
||||
argv_input = parse_inputs(argv_input)
|
||||
inputs = input_to_input_cut_info(argv_input)
|
||||
inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])),
|
||||
_InputCutInfo(name='inp2', shape=PartialShape([3,2,3])),
|
||||
@ -92,6 +97,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
|
||||
def test_get_shapes_several_inputs_several_shapes3_comma_sep(self):
|
||||
# shapes and value for freezing specified using --input command line parameter
|
||||
argv_input = "inp1[3 1],inp2[3 2 3],inp3[5]"
|
||||
argv_input = parse_inputs(argv_input)
|
||||
inputs = input_to_input_cut_info(argv_input)
|
||||
inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])),
|
||||
_InputCutInfo(name='inp2', shape=PartialShape([3,2,3])),
|
||||
@ -101,6 +107,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
|
||||
def test_get_shapes_several_inputs_several_shapes6(self):
|
||||
# 0D value for freezing specified using --input command line parameter without shape
|
||||
argv_input = "inp1[3,1],inp2[3,2,3],inp3"
|
||||
argv_input = parse_inputs(argv_input)
|
||||
inputs_list, result, _ = input_to_input_cut_info(argv_input)
|
||||
inputs = input_to_input_cut_info(argv_input)
|
||||
inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])),
|
||||
@ -111,6 +118,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
|
||||
def test_get_shapes_several_inputs_several_shapes7(self):
|
||||
# 0D shape and value for freezing specified using --input command line parameter
|
||||
argv_input = "inp1[3,1],inp2[3,2,3],inp3[]"
|
||||
argv_input = parse_inputs(argv_input)
|
||||
inputs_list, result, _ = input_to_input_cut_info(argv_input)
|
||||
inputs = input_to_input_cut_info(argv_input)
|
||||
inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])),
|
||||
@ -121,6 +129,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
|
||||
|
||||
def test_get_shapes_and_data_types_shape_only(self):
|
||||
argv_input = "placeholder1[3 1],placeholder2,placeholder3"
|
||||
argv_input = parse_inputs(argv_input)
|
||||
inputs = input_to_input_cut_info(argv_input)
|
||||
inputs_ref = [_InputCutInfo(name='placeholder1', shape=PartialShape([3,1])),
|
||||
_InputCutInfo(name='placeholder2'),
|
||||
@ -129,6 +138,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
|
||||
|
||||
def test_get_shapes_and_data_types_shape_with_ports_only(self):
|
||||
argv_input = "placeholder1:4[3 1],placeholder2,2:placeholder3"
|
||||
argv_input = parse_inputs(argv_input)
|
||||
inputs = input_to_input_cut_info(argv_input)
|
||||
inputs_ref = [_InputCutInfo(name='placeholder1:4', shape=PartialShape([3,1])),
|
||||
_InputCutInfo(name='placeholder2'),
|
||||
@ -137,15 +147,16 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
|
||||
|
||||
def test_wrong_data_types(self):
|
||||
argv_input = "inp1[3 1]->[1.0 2.0 3.0],inp2[3 2 3]{abracadabra},inp3[5]{f32}->[1.0 1.0 2.0 3.0 5.0]"
|
||||
self.assertRaises(Error, input_to_input_cut_info, argv_input)
|
||||
self.assertRaises(Error, parse_inputs, argv_input)
|
||||
|
||||
def test_shape_and_value_shape_mismatch(self):
|
||||
# size of value tensor does not correspond to specified shape for the third node
|
||||
argv_input = "inp1[3 1]->[1.0 2.0 3.0],inp2[3 2 3],inp3[5 3]->[2.0 3.0 5.0]"
|
||||
self.assertRaises(Error, input_to_input_cut_info, argv_input)
|
||||
self.assertRaises(Error, parse_inputs, argv_input)
|
||||
|
||||
def test_get_shapes_no_input_no_shape(self):
|
||||
argv_input = ""
|
||||
argv_input = parse_inputs(argv_input)
|
||||
inputs = input_to_input_cut_info(argv_input)
|
||||
inputs_ref = []
|
||||
self.assertEqual(inputs, inputs_ref)
|
||||
@ -153,18 +164,21 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
|
||||
|
||||
def test_get_shapes_no_input_one_shape2(self):
|
||||
argv_input = "[12,4,1]"
|
||||
argv_input = parse_inputs(argv_input)
|
||||
inputs = input_to_input_cut_info(argv_input)
|
||||
inputs_ref = [_InputCutInfo(shape=PartialShape([12,4,1]))]
|
||||
self.assertEqual(inputs, inputs_ref)
|
||||
|
||||
def test_get_shapes_for_scalar_inputs(self):
|
||||
argv_input = "[]"
|
||||
argv_input = parse_inputs(argv_input)
|
||||
inputs = input_to_input_cut_info(argv_input)
|
||||
inputs_ref = [_InputCutInfo(shape=PartialShape([]))]
|
||||
self.assertEqual(inputs, inputs_ref)
|
||||
|
||||
def test_get_shapes_two_input_shapes_with_scalar(self):
|
||||
argv_input = "[12,4,1],[]"
|
||||
argv_input = parse_inputs(argv_input)
|
||||
inputs = input_to_input_cut_info(argv_input)
|
||||
inputs_ref = [_InputCutInfo(shape=PartialShape([12,4,1])),
|
||||
_InputCutInfo(shape=PartialShape([]))]
|
||||
@ -172,6 +186,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
|
||||
|
||||
def test_get_shapes_two_input_shapes(self):
|
||||
argv_input = "[12,4,1],[10]"
|
||||
argv_input = parse_inputs(argv_input)
|
||||
inputs = input_to_input_cut_info(argv_input)
|
||||
inputs_ref = [_InputCutInfo(shape=PartialShape([12,4,1])),
|
||||
_InputCutInfo(shape=PartialShape([10])),]
|
||||
@ -179,6 +194,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
|
||||
|
||||
def test_get_shapes_one_input_no_shape(self):
|
||||
argv_input = "inp1"
|
||||
argv_input = parse_inputs(argv_input)
|
||||
inputs = input_to_input_cut_info(argv_input)
|
||||
inputs_ref = [_InputCutInfo(name='inp1')]
|
||||
self.assertEqual(inputs, inputs_ref)
|
||||
@ -186,6 +202,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
|
||||
def test_get_shapes_several_inputs_several_partial_shapes2(self):
|
||||
# shapes specified using --input command line parameter and no values
|
||||
argv_input = "inp1[1,?,50..100,123],inp2[-1,45..,..7,1]"
|
||||
argv_input = parse_inputs(argv_input)
|
||||
inputs = input_to_input_cut_info(argv_input)
|
||||
inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape("[1,?,50..100,123]")),
|
||||
_InputCutInfo(name='inp2', shape=PartialShape("[-1,45..,..7,1]"))]
|
||||
@ -194,6 +211,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
|
||||
def test_get_shapes_several_inputs_several_partial_shapes3(self):
|
||||
# shapes and value for freezing specified using --input command line parameter
|
||||
argv_input = "inp1[3,1],inp2[3..,..2,5..10,?,-1],inp3[5]"
|
||||
argv_input = parse_inputs(argv_input)
|
||||
inputs = input_to_input_cut_info(argv_input)
|
||||
inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])),
|
||||
_InputCutInfo(name='inp2', shape=PartialShape("[3..,..2,5..10,?,-1]")),
|
||||
@ -203,6 +221,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
|
||||
def test_get_shapes_several_inputs_several_partial_shapes6(self):
|
||||
# 0D value for freezing specified using --input command line parameter without shape
|
||||
argv_input = "inp1[3 1],inp2[3.. ..2 5..10 ? -1],inp3"
|
||||
argv_input = parse_inputs(argv_input)
|
||||
inputs = input_to_input_cut_info(argv_input)
|
||||
inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])),
|
||||
_InputCutInfo(name='inp2', shape=PartialShape("[3..,..2,5..10,?,-1]")),
|
||||
@ -212,6 +231,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
|
||||
def test_get_shapes_several_inputs_several_partial_shapes7(self):
|
||||
# 0D shape and value for freezing specified using --input command line parameter
|
||||
argv_input = "inp1[3 1],inp2[3.. ..2 5..10 ? -1],inp3[]"
|
||||
argv_input = parse_inputs(argv_input)
|
||||
inputs = input_to_input_cut_info(argv_input)
|
||||
inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])),
|
||||
_InputCutInfo(name='inp2', shape=PartialShape("[3..,..2,5..10,?,-1]")),
|
||||
@ -220,15 +240,16 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
|
||||
|
||||
def test_partial_shapes_freeze_dynamic_negative_case1(self):
|
||||
argv_input = "inp1:1[3 1..10]->[1.0 2.0 3.0]"
|
||||
self.assertRaises(Error, input_to_input_cut_info, argv_input)
|
||||
self.assertRaises(Error, parse_inputs, argv_input)
|
||||
|
||||
def test_partial_shapes_freeze_dynamic_negative_case2(self):
|
||||
argv_input = "inp1:1[1 2 -1]->[1.0 2.0 3.0]"
|
||||
self.assertRaises(Error, input_to_input_cut_info, argv_input)
|
||||
self.assertRaises(Error, parse_inputs, argv_input)
|
||||
|
||||
def test_get_shapes_several_inputs_several_partial_shapes2_comma_separator(self):
|
||||
# shapes specified using --input command line parameter and no values
|
||||
argv_input = "inp1[1,?,50..100,123],inp2[-1,45..,..7,1]"
|
||||
argv_input = parse_inputs(argv_input)
|
||||
inputs = input_to_input_cut_info(argv_input)
|
||||
inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape("[1,?,50..100,123]")),
|
||||
_InputCutInfo(name='inp2', shape=PartialShape("[-1,45..,..7,1]"))]
|
||||
@ -237,6 +258,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
|
||||
def test_get_shapes_several_inputs_several_partial_shapes3_comma_separator(self):
|
||||
# shapes and value for freezing specified using --input command line parameter
|
||||
argv_input = "inp1[3,1],inp2[3..,..2,5..10,?,-1],inp3[5]"
|
||||
argv_input = parse_inputs(argv_input)
|
||||
inputs = input_to_input_cut_info(argv_input)
|
||||
inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])),
|
||||
_InputCutInfo(name='inp2', shape=PartialShape("[3..,..2,5..10,?,-1]")),
|
||||
@ -246,6 +268,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
|
||||
def test_get_shapes_several_inputs_several_partial_shapes6_comma_separator(self):
|
||||
# 0D value for freezing specified using --input command line parameter without shape
|
||||
argv_input = "inp1[3, 1],inp2[3.., ..2, 5..10, ?,-1],inp3"
|
||||
argv_input = parse_inputs(argv_input)
|
||||
inputs = input_to_input_cut_info(argv_input)
|
||||
inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])),
|
||||
_InputCutInfo(name='inp2', shape=PartialShape("[3..,..2,5..10,?,-1]")),
|
||||
@ -255,6 +278,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
|
||||
def test_get_shapes_several_inputs_several_partial_shapes7_comma_separator(self):
|
||||
# 0D shape and value for freezing specified using --input command line parameter
|
||||
argv_input = "inp1[3,1],inp2[3.., ..2,5..10, ?,-1],inp3[]"
|
||||
argv_input = parse_inputs(argv_input)
|
||||
inputs = input_to_input_cut_info(argv_input)
|
||||
inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])),
|
||||
_InputCutInfo(name='inp2', shape=PartialShape("[3..,..2,5..10,?,-1]")),
|
||||
@ -263,24 +287,25 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
|
||||
|
||||
def test_partial_shapes_freeze_dynamic_negative_case1_comma_separator(self):
|
||||
argv_input = "inp1:1[3,1..10]->[1.0 2.0 3.0]"
|
||||
self.assertRaises(Error, input_to_input_cut_info, argv_input)
|
||||
self.assertRaises(Error, parse_inputs, argv_input)
|
||||
|
||||
def test_partial_shapes_freeze_dynamic_negative_case2_comma_separator(self):
|
||||
argv_input = "inp1:1[1,2,-1]->[1.0 2.0 3.0]"
|
||||
self.assertRaises(Error, input_to_input_cut_info, argv_input)
|
||||
self.assertRaises(Error, parse_inputs, argv_input)
|
||||
|
||||
def test_partial_shapes_freeze_dynamic_negative_case3_comma_separator(self):
|
||||
argv_input = "inp1:1[3,1..10]->[1.0 2.0 3.0]"
|
||||
self.assertRaises(Error, input_to_input_cut_info, argv_input)
|
||||
self.assertRaises(Error, parse_inputs, argv_input)
|
||||
|
||||
def test_partial_shapes_freeze_dynamic_negative_case4_comma_separator(self):
|
||||
argv_input = "inp1:1[1, 2, -1]->[1.0 2.0 3.0]"
|
||||
self.assertRaises(Error, input_to_input_cut_info, argv_input)
|
||||
self.assertRaises(Error, parse_inputs, argv_input)
|
||||
|
||||
def test_not_supported_arrow(self):
|
||||
with self.assertRaisesRegex(Exception,
|
||||
"Incorrect format of input."):
|
||||
input_to_input_cut_info("inp1->[1.0]")
|
||||
argv_input = parse_inputs("inp1->[1.0]")
|
||||
input_to_input_cut_info(argv_input)
|
||||
|
||||
|
||||
class PositiveChecker(unittest.TestCase):
|
||||
@ -526,4 +551,3 @@ class TestConvertModelParamsParsing(unittest.TestCase):
|
||||
assert param_name not in cli_parser._option_string_actions
|
||||
else:
|
||||
assert param_name in cli_parser._option_string_actions
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user