Improvements and fixes in OVC convert_model (#19184)
* Added support of tuple in input, removed type syntax from OVC tool. * Removed type syntax tests. * Apply suggestions from code review * Method annotation corrected. * Type annotation corrected. --------- Co-authored-by: Sergey Lyalin <sergey.lyalin@intel.com>
This commit is contained in:
parent
b656feee57
commit
3e23908983
@ -91,6 +91,18 @@ class TestComplexParams(CommonMOConvertTest):
|
||||
{'params_test': {'input': [ov.Type.f32, ov.Type.f32]},
|
||||
'params_ref': {'input': 'Input1{f32},Input2{f32}'}},
|
||||
{'params_test': {'input': [([1, 3, -1, -1], ov.Type.i32), ov.Type.i32, ov.Type.i32]},
|
||||
'params_ref': {'input': 'Input1[1,3,?,?]{i32},Input2{i32},Input3{i32}'}},
|
||||
{'params_test': {'input': (PartialShape([2, 3, 4]), [2, 3, 4], [Dimension(2), Dimension(3), Dimension(4)])},
|
||||
'params_ref': {'input_shape': "[2,3,4],[2,3,4],[2,3,4]", 'input': 'Input1,Input2,Input3'}},
|
||||
{'params_test': {'input': (PartialShape([1, 3, -1, -1]), [1, 3, -1, -1])},
|
||||
'params_ref': {'input_shape': "[1,3,?,?],[1,3,?,?]", 'input': 'Input1,Input2'}},
|
||||
{'params_test': {'input': ((2, 3, 4), [2, 3, 4], (Dimension(2), Dimension(3), Dimension(4)))},
|
||||
'params_ref': {'input_shape': "[2,3,4],[2,3,4],[2,3,4]", 'input': 'Input1,Input2,Input3'}},
|
||||
{'params_test': {'input': (np.int32, Type(np.int32), np.int32)},
|
||||
'params_ref': {'input': 'Input1{i32},Input2{i32},Input3{i32}'}},
|
||||
{'params_test': {'input': (ov.Type.f32, ov.Type.f32)},
|
||||
'params_ref': {'input': 'Input1{f32},Input2{f32}'}},
|
||||
{'params_test': {'input': (([1, 3, -1, -1], ov.Type.i32), ov.Type.i32, ov.Type.i32)},
|
||||
'params_ref': {'input': 'Input1[1,3,?,?]{i32},Input2{i32},Input3{i32}'}}
|
||||
]
|
||||
|
||||
|
@ -176,7 +176,7 @@ def create_pytorch_nn_module_with_scalar_input(tmp_dir):
|
||||
sample_input2 = torch.zeros(1, 3, 10, 10)
|
||||
sample_input = sample_input1, sample_input2
|
||||
|
||||
return pt_model, ref_model, {'input': ["[]", PartialShape([-1, 3, -1, -1])],
|
||||
return pt_model, ref_model, {'input': [PartialShape("[]"), PartialShape([-1, 3, -1, -1])],
|
||||
'example_input': sample_input}
|
||||
|
||||
|
||||
|
@ -45,15 +45,14 @@ def single_input_to_input_cut_info(input: [str, tuple, list, PartialShape, Type,
|
||||
"""
|
||||
if isinstance(input, str):
|
||||
# Parse params from string
|
||||
node_name, shape, data_type = parse_input_value(input)
|
||||
node_name, shape = parse_input_value(input)
|
||||
# pylint: disable=no-member
|
||||
return _InputCutInfo(node_name,
|
||||
PartialShape(shape) if shape is not None else None,
|
||||
data_type)
|
||||
PartialShape(shape) if shape is not None else None)
|
||||
if isinstance(input, (tuple, list)) or is_shape_type(input):
|
||||
# If input represents list with shape, wrap it to list. Single PartialShape also goes to this condition.
|
||||
# Check of all dimensions will be in is_shape_type(val) method below
|
||||
if len(input) > 0 and isinstance(input[0], (int, Dimension)):
|
||||
if len(input) > 0 and isinstance(input[0], (int, Dimension)) or isinstance(input, PartialShape):
|
||||
input = [input]
|
||||
|
||||
# Check values of tuple or list and collect to InputCutInfo
|
||||
@ -90,6 +89,32 @@ def single_input_to_input_cut_info(input: [str, tuple, list, PartialShape, Type,
|
||||
|
||||
raise Exception("Unexpected object provided for input. Expected tuple, Shape, PartialShape, Type or str. Got {}".format(type(input)))
|
||||
|
||||
def is_single_input(input: [tuple, list]):
|
||||
"""
|
||||
Checks if input has parameters for single input.
|
||||
:param input: list or tuple of input parameters or input shape or input name.
|
||||
:return: True if input has parameters for single input, otherwise False.
|
||||
"""
|
||||
name = None
|
||||
inp_type = None
|
||||
shape = None
|
||||
for val in input:
|
||||
if isinstance(val, str):
|
||||
if name is not None:
|
||||
return False
|
||||
name = val
|
||||
elif isinstance(val, (type, Type)):
|
||||
if inp_type is not None:
|
||||
return False
|
||||
inp_type = val
|
||||
elif is_shape_type(val):
|
||||
if shape is not None:
|
||||
return False
|
||||
shape = PartialShape(val)
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def input_to_input_cut_info(input: [str, tuple, list]):
|
||||
"""
|
||||
@ -105,27 +130,24 @@ def input_to_input_cut_info(input: [str, tuple, list]):
|
||||
for input_value in split_inputs(input):
|
||||
|
||||
# Parse string with parameters for single input
|
||||
node_name, shape, data_type = parse_input_value(input_value)
|
||||
node_name, shape = parse_input_value(input_value)
|
||||
# pylint: disable=no-member
|
||||
inputs.append(_InputCutInfo(node_name,
|
||||
PartialShape(shape) if shape is not None else None,
|
||||
data_type))
|
||||
PartialShape(shape) if shape is not None else None))
|
||||
return inputs
|
||||
if isinstance(input, tuple):
|
||||
if isinstance(input, (tuple, list)):
|
||||
# Case when input is single shape set in tuple
|
||||
if len(input) > 0 and isinstance(input[0], (int, Dimension)):
|
||||
input = [input]
|
||||
# Case when input is set as tuple. Expected that it is always single input.
|
||||
return [single_input_to_input_cut_info(input)]
|
||||
if isinstance(input, list):
|
||||
# Case when input is single shape set in list
|
||||
if len(input) > 0 and isinstance(input[0], (int, Dimension)):
|
||||
input = [input]
|
||||
|
||||
if is_single_input(input):
|
||||
return [single_input_to_input_cut_info(input)]
|
||||
|
||||
inputs = []
|
||||
# Case when input is set as list. Expected that it is list of params for different inputs.
|
||||
for inp in input:
|
||||
inputs.append(single_input_to_input_cut_info(inp))
|
||||
return inputs
|
||||
|
||||
if isinstance(input, dict):
|
||||
res_list = []
|
||||
for name, value in input.items():
|
||||
@ -480,7 +502,7 @@ def input_model_details(model):
|
||||
|
||||
def get_common_cli_options(argv, is_python_api_used):
|
||||
d = OrderedDict()
|
||||
d['input_model'] = ['- Path to the Input Model', input_model_details]
|
||||
d['input_model'] = ['- Input Model', input_model_details]
|
||||
if not is_python_api_used:
|
||||
model_name = get_model_name_from_args(argv)
|
||||
d['output_model'] = ['- IR output name', lambda _: model_name]
|
||||
@ -509,25 +531,6 @@ def get_all_cli_parser():
|
||||
return parser
|
||||
|
||||
|
||||
def remove_data_type_from_input_value(input_value: str):
|
||||
"""
|
||||
Removes the type specification from the input string. The type specification is a string enclosed with curly braces.
|
||||
:param input_value: string passed as input to the "input" command line parameter
|
||||
:return: string without type specification
|
||||
"""
|
||||
return re.sub(r'\{.*\}', '', input_value)
|
||||
|
||||
|
||||
def get_data_type_from_input_value(input_value: str):
|
||||
"""
|
||||
Returns the numpy data type corresponding to the data type specified in the input value string
|
||||
:param input_value: string passed as input to the "input" command line parameter
|
||||
:return: the corresponding numpy data type and None if the data type is not specified in the input value
|
||||
"""
|
||||
data_type_match = re.match(r'.*\{(.*)\}.*', input_value)
|
||||
return destination_type_to_np_data_type(data_type_match.group(1)) if data_type_match is not None else None
|
||||
|
||||
|
||||
def remove_shape_from_input_value(input_value: str):
|
||||
"""
|
||||
Removes the shape specification from the input string. The shape specification is a string enclosed with square
|
||||
@ -570,7 +573,7 @@ def get_node_name_with_port_from_input_value(input_value: str):
|
||||
:param input_value: string passed as input to the "input" command line parameter
|
||||
:return: the corresponding node name with input/output port
|
||||
"""
|
||||
return remove_shape_from_input_value(remove_data_type_from_input_value(input_value))
|
||||
return remove_shape_from_input_value(input_value)
|
||||
|
||||
|
||||
def partial_shape_prod(shape: [PartialShape, tuple]):
|
||||
@ -591,199 +594,18 @@ def parse_input_value(input_value: str):
|
||||
Parameters
|
||||
----------
|
||||
input_value
|
||||
string with a specified node name, shape, value and data_type.
|
||||
E.g. 'node_name:0[4]{fp32}'
|
||||
string with a specified node name and shape.
|
||||
E.g. 'node_name:0[4]'
|
||||
|
||||
Returns
|
||||
-------
|
||||
Node name, shape, value, data type
|
||||
E.g. 'node_name:0', '4', [1.0 2.0 3.0 4.0], np.float32
|
||||
"""
|
||||
data_type = get_data_type_from_input_value(input_value)
|
||||
node_name = get_node_name_with_port_from_input_value(input_value)
|
||||
shape = get_shape_from_input_value(input_value)
|
||||
|
||||
return node_name if node_name else None, shape, data_type
|
||||
|
||||
|
||||
def split_str_avoiding_square_brackets(s: str) -> list:
|
||||
"""
|
||||
Splits a string by comma, but skips commas inside square brackets.
|
||||
:param s: string to split
|
||||
:return: list of strings split by comma
|
||||
"""
|
||||
res = list()
|
||||
skipping = 0
|
||||
last_idx = 0
|
||||
for i, c in enumerate(s):
|
||||
if c == '[':
|
||||
skipping += 1
|
||||
elif c == ']':
|
||||
skipping -= 1
|
||||
elif c == ',' and skipping == 0:
|
||||
res.append(s[last_idx:i])
|
||||
last_idx = i + 1
|
||||
res.append(s[last_idx:])
|
||||
return res
|
||||
|
||||
|
||||
def split_layouts_by_arrow(s: str) -> tuple:
|
||||
"""
|
||||
Splits a layout string by first arrow (->).
|
||||
:param s: string to split
|
||||
:return: tuple containing source and target layouts
|
||||
"""
|
||||
arrow = s.find('->')
|
||||
if arrow != -1:
|
||||
source_layout = s[:arrow]
|
||||
target_layout = s[arrow + 2:]
|
||||
if source_layout == '':
|
||||
source_layout = None
|
||||
if target_layout == '':
|
||||
target_layout = None
|
||||
return source_layout, target_layout
|
||||
else:
|
||||
return s, None
|
||||
|
||||
|
||||
def validate_layout(layout: str):
|
||||
"""
|
||||
Checks if layout is of valid format.
|
||||
:param layout: string containing layout
|
||||
:raises: if layout is incorrect
|
||||
"""
|
||||
error_msg = 'Invalid layout parsed: {}'.format(layout)
|
||||
if layout:
|
||||
incorrect_brackets = xor(layout[0] == '[', layout[-1] == ']')
|
||||
if incorrect_brackets or layout[-1] == '-':
|
||||
error_msg += ', did you forget quotes?'
|
||||
else:
|
||||
valid_layout_re = re.compile(r'\[?[^\[\]\(\)\-\s]*\]?')
|
||||
if valid_layout_re.fullmatch(layout):
|
||||
return
|
||||
raise Error(error_msg)
|
||||
|
||||
|
||||
def write_found_layout(name: str, found_layout: str, parsed: dict, dest: str = None):
|
||||
"""
|
||||
Writes found layout data to the 'parsed' dict.
|
||||
:param name: name of the node to add layout
|
||||
:param found_layout: string containing layout for the node
|
||||
:param parsed: dict where result will be stored
|
||||
:param dest: type of the command line:
|
||||
* 'source' is "source_layout"
|
||||
* 'target' is "target_layout"
|
||||
* None is "layout"
|
||||
"""
|
||||
s_layout = None
|
||||
t_layout = None
|
||||
if name in parsed:
|
||||
s_layout = parsed[name]['source_layout']
|
||||
t_layout = parsed[name]['target_layout']
|
||||
if dest == 'source':
|
||||
s_layout = found_layout
|
||||
elif dest == 'target':
|
||||
t_layout = found_layout
|
||||
else:
|
||||
s_layout, t_layout = split_layouts_by_arrow(found_layout)
|
||||
validate_layout(s_layout)
|
||||
validate_layout(t_layout)
|
||||
parsed[name] = {'source_layout': s_layout, 'target_layout': t_layout}
|
||||
|
||||
|
||||
def write_found_layout_list(idx: int, found_layout: str, parsed: list, dest: str = None):
|
||||
"""
|
||||
Writes found layout data to the 'parsed' dict.
|
||||
:param idx: idx of of the node to add layout
|
||||
:param found_layout: string containing layout for the node
|
||||
:param parsed: list where result will be stored
|
||||
:param dest: type of the command line:
|
||||
* 'source' is "source_layout"
|
||||
* 'target' is "target_layout"
|
||||
* None is "layout"
|
||||
"""
|
||||
s_layout = None
|
||||
t_layout = None
|
||||
if idx < len(parsed):
|
||||
s_layout = parsed[idx]['source_layout']
|
||||
t_layout = parsed[idx]['target_layout']
|
||||
if dest == 'source':
|
||||
s_layout = found_layout
|
||||
elif dest == 'target':
|
||||
t_layout = found_layout
|
||||
else:
|
||||
s_layout, t_layout = split_layouts_by_arrow(found_layout)
|
||||
validate_layout(s_layout)
|
||||
validate_layout(t_layout)
|
||||
|
||||
if idx < len(parsed):
|
||||
parsed[idx] = {'source_layout': s_layout, 'target_layout': t_layout}
|
||||
else:
|
||||
parsed.append({'source_layout': s_layout, 'target_layout': t_layout})
|
||||
|
||||
|
||||
def parse_layouts_by_destination(s: str, parsed: dict, parsed_list: list, dest: str = None) -> None:
|
||||
"""
|
||||
Parses layout command line to get all names and layouts from it. Adds all found data in the 'parsed' dict.
|
||||
:param s: string to parse
|
||||
:param parsed: dict where result will be stored
|
||||
:param dest: type of the command line:
|
||||
* 'source' is "source_layout"
|
||||
* 'target' is "target_layout"
|
||||
* None is "layout"
|
||||
"""
|
||||
list_s = split_str_avoiding_square_brackets(s)
|
||||
if len(list_s) == 1 and (list_s[0][-1] not in ')]' or (list_s[0][0] == '[' and list_s[0][-1] == ']')):
|
||||
# single layout case
|
||||
write_found_layout('', list_s[0], parsed, dest)
|
||||
else:
|
||||
for idx, layout_str in enumerate(list_s):
|
||||
# case for: "name1(nhwc->[n,c,h,w])"
|
||||
p1 = re.compile(r'([^\[\]\(\)]*)\((\S+)\)')
|
||||
m1 = p1.match(layout_str)
|
||||
# case for: "name1[n,h,w,c]->[n,c,h,w]"
|
||||
p2 = re.compile(r'([^\[\]\(\)]*)(\[\S*\])')
|
||||
m2 = p2.match(layout_str)
|
||||
if m1:
|
||||
found_g = m1.groups()
|
||||
elif m2:
|
||||
found_g = m2.groups()
|
||||
else:
|
||||
# case for layout without name
|
||||
write_found_layout_list(idx, layout_str, parsed_list, dest)
|
||||
continue
|
||||
if len(found_g[0]) > 0:
|
||||
write_found_layout(found_g[0], found_g[1], parsed, dest)
|
||||
else:
|
||||
write_found_layout_list(idx, found_g[1], parsed_list, dest)
|
||||
|
||||
|
||||
def get_layout_values(argv_layout: str = '', argv_source_layout: str = '', argv_target_layout: str = ''):
|
||||
"""
|
||||
Parses layout string.
|
||||
:param argv_layout: string with a list of layouts passed as a "layout".
|
||||
:param argv_source_layout: string with a list of layouts passed as a "source_layout".
|
||||
:param argv_target_layout: string with a list of layouts passed as a "target_layout".
|
||||
:return: dict with names and layouts associated
|
||||
"""
|
||||
if argv_layout and (argv_source_layout or argv_target_layout):
|
||||
raise Error("\"layout\" is used as well as \"source_layout\" and/or \"target_layout\" which is not allowed, please "
|
||||
"use one of them.")
|
||||
res = {}
|
||||
res_list = []
|
||||
if argv_layout:
|
||||
parse_layouts_by_destination(argv_layout, res, res_list)
|
||||
if argv_source_layout:
|
||||
parse_layouts_by_destination(argv_source_layout, res, res_list, 'source')
|
||||
if argv_target_layout:
|
||||
parse_layouts_by_destination(argv_target_layout, res, res_list, 'target')
|
||||
if len(res) > 0 and len(res_list) > 0:
|
||||
raise Error("Some layout values are provided with names, and some without names. "
|
||||
"Please provide ether all layouts with names or all layouts without names.")
|
||||
if len(res) > 0:
|
||||
return res
|
||||
else:
|
||||
return res_list
|
||||
return node_name if node_name else None, shape
|
||||
|
||||
|
||||
def split_inputs(input_str):
|
||||
|
@ -17,15 +17,15 @@ def get_convert_model_help_specifics():
|
||||
{'description':
|
||||
'Information of model input required for model conversion. '
|
||||
'This is a comma separated list with optional '
|
||||
'input names, shapes and data types. The order of inputs '
|
||||
'input names and shapes. The order of inputs '
|
||||
'in converted model will match the order of '
|
||||
'specified inputs. The shape is specified as comma-separated list. '
|
||||
'The data type of input node is specified in braces and can have one of '
|
||||
'the values: f64, f32, f16, i64, i32, u8, boolean. If data type is not '
|
||||
'specified explicitly then data type is taken from the '
|
||||
'original node data type. Example, to set `input_1` input '
|
||||
'with shape [1,100] and float32 type, and `sequence_len` input '
|
||||
'with int32 type \"input_1[1,100]{f32},sequence_len{i32}\".'},
|
||||
'Example, to set `input_1` input with shape [1,100] and `sequence_len` input '
|
||||
'with shape [1,?]: \"input_1[1,100],sequence_len[1,?]\", where "?" is a dynamic dimension, '
|
||||
'which means that such a dimension can be specified later in the runtime. '
|
||||
'If the dimension is set as an integer (like 100 in [1,100]), such a dimension is not supposed '
|
||||
'to be changed later, during a model conversion it is treated as a static value. '
|
||||
'Example with unnamed inputs: \"[1,100],[1,?]\".'},
|
||||
'extension':
|
||||
{'description':
|
||||
'Paths or a comma-separated list of paths to libraries '
|
||||
|
@ -86,37 +86,6 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
|
||||
_InputCutInfo(name='inp3', shape=PartialShape([]))]
|
||||
self.assertEqual(inputs, inputs_ref)
|
||||
|
||||
def test_get_shapes_and_data_types1(self):
|
||||
argv_input = "inp1[3 1],inp2[3 2 3]{i32},inp3[5]{f32}"
|
||||
inputs = input_to_input_cut_info(argv_input)
|
||||
inputs_ref = [_InputCutInfo(name='inp1', shape=PartialShape([3,1])),
|
||||
_InputCutInfo(name='inp2', shape=PartialShape([3,2,3]), type=np.int32),
|
||||
_InputCutInfo(name='inp3', shape=PartialShape([5]), type=np.float32)]
|
||||
self.assertEqual(inputs, inputs_ref)
|
||||
|
||||
def test_get_shapes_and_data_types_with_input_ports(self):
|
||||
argv_input = "1:inp1[3 1],inp2[3 2 3]{i32},0:inp3[5]{f32}"
|
||||
inputs = input_to_input_cut_info(argv_input)
|
||||
inputs_ref = [_InputCutInfo(name='1:inp1', shape=PartialShape([3,1])),
|
||||
_InputCutInfo(name='inp2', shape=PartialShape([3,2,3]), type=np.int32),
|
||||
_InputCutInfo(name='0:inp3', shape=PartialShape([5]), type=np.float32)]
|
||||
self.assertEqual(inputs, inputs_ref)
|
||||
|
||||
def test_get_shapes_and_data_types_with_output_ports(self):
|
||||
argv_input = "inp1:1[3 1],inp2[3 2 3]{i32},inp3:4[5]{f32}"
|
||||
inputs = input_to_input_cut_info(argv_input)
|
||||
inputs_ref = [_InputCutInfo(name='inp1:1', shape=PartialShape([3,1])),
|
||||
_InputCutInfo(name='inp2', shape=PartialShape([3,2,3]), type=np.int32),
|
||||
_InputCutInfo(name='inp3:4', shape=PartialShape([5]), type=np.float32)]
|
||||
self.assertEqual(inputs, inputs_ref)
|
||||
|
||||
def test_get_shapes_and_data_types_with_output_ports_comma_sep(self):
|
||||
argv_input = "inp1:1[3,1],inp2[3,2, 3]{i32},inp3:4[5]{f32}"
|
||||
inputs = input_to_input_cut_info(argv_input)
|
||||
inputs_ref = [_InputCutInfo(name='inp1:1', shape=PartialShape([3,1])),
|
||||
_InputCutInfo(name='inp2', shape=PartialShape([3,2,3]), type=np.int32),
|
||||
_InputCutInfo(name='inp3:4', shape=PartialShape([5]), type=np.float32)]
|
||||
self.assertEqual(inputs, inputs_ref)
|
||||
|
||||
def test_get_shapes_and_data_types_shape_only(self):
|
||||
argv_input = "placeholder1[3 1],placeholder2,placeholder3"
|
||||
@ -134,14 +103,6 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
|
||||
_InputCutInfo(name='2:placeholder3')]
|
||||
self.assertEqual(inputs, inputs_ref)
|
||||
|
||||
def test_get_shapes_and_data_types_when_no_freeze_value(self):
|
||||
argv_input = "placeholder1{i32}[3 1],placeholder2,placeholder3{i32}"
|
||||
inputs = input_to_input_cut_info(argv_input)
|
||||
inputs_ref = [_InputCutInfo(name='placeholder1', shape=PartialShape([3,1]), type=np.int32),
|
||||
_InputCutInfo(name='placeholder2'),
|
||||
_InputCutInfo(name='placeholder3', type=np.int32)]
|
||||
self.assertEqual(inputs, inputs_ref)
|
||||
|
||||
def test_wrong_data_types(self):
|
||||
argv_input = "inp1[3 1]->[1.0 2.0 3.0],inp2[3 2 3]{abracadabra},inp3[5]{f32}->[1.0 1.0 2.0 3.0 5.0]"
|
||||
self.assertRaises(Error, input_to_input_cut_info, argv_input)
|
||||
@ -225,22 +186,6 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
|
||||
_InputCutInfo(name='inp3', shape=PartialShape([]))]
|
||||
self.assertEqual(inputs, inputs_ref)
|
||||
|
||||
def test_get_shapes_and_data_types_partial_shape_with_input_port(self):
|
||||
argv_input = "inp1:1[3 1],0:inp2[3.. ..2 5..10 ? -1]{i32},inp3:4[5]{f32}"
|
||||
inputs = input_to_input_cut_info(argv_input)
|
||||
inputs_ref = [_InputCutInfo(name='inp1:1', shape=PartialShape([3,1])),
|
||||
_InputCutInfo(name='0:inp2', shape=PartialShape("[3..,..2,5..10,?,-1]"), type=np.int32),
|
||||
_InputCutInfo(name='inp3:4', shape=PartialShape([5]), type=np.float32)]
|
||||
self.assertEqual(inputs, inputs_ref)
|
||||
|
||||
def test_get_shapes_and_data_types_partial_shape_with_output_port(self):
|
||||
argv_input = "inp1:1[3 1],inp2:3[3.. ..2 5..10 ? -1]{i32},inp3:4[5]{f32}"
|
||||
inputs = input_to_input_cut_info(argv_input)
|
||||
inputs_ref = [_InputCutInfo(name='inp1:1', shape=PartialShape([3,1])),
|
||||
_InputCutInfo(name='inp2:3', shape=PartialShape("[3..,..2,5..10,?,-1]"), type=np.int32),
|
||||
_InputCutInfo(name='inp3:4', shape=PartialShape([5]), type=np.float32)]
|
||||
self.assertEqual(inputs, inputs_ref)
|
||||
|
||||
def test_partial_shapes_freeze_dynamic_negative_case1(self):
|
||||
argv_input = "inp1:1[3 1..10]->[1.0 2.0 3.0]"
|
||||
self.assertRaises(Error, input_to_input_cut_info, argv_input)
|
||||
@ -284,22 +229,6 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
|
||||
_InputCutInfo(name='inp3', shape=PartialShape([]))]
|
||||
self.assertEqual(inputs, inputs_ref)
|
||||
|
||||
def test_get_shapes_and_data_types_partial_shape_with_input_port_comma_separator(self):
|
||||
argv_input = "inp1:1[3,1],0:inp2[ 3.. ,..2, 5..10, ?,-1]{i32},inp3:4[5]{f32}"
|
||||
inputs = input_to_input_cut_info(argv_input)
|
||||
inputs_ref = [_InputCutInfo(name='inp1:1', shape=PartialShape([3,1])),
|
||||
_InputCutInfo(name='0:inp2', shape=PartialShape("[3..,..2,5..10,?,-1]"), type=np.int32),
|
||||
_InputCutInfo(name='inp3:4', shape=PartialShape([5]), type=np.float32)]
|
||||
self.assertEqual(inputs, inputs_ref)
|
||||
|
||||
def test_get_shapes_and_data_types_partial_shape_with_output_port_comma_separator(self):
|
||||
argv_input = "inp1:1[3,1],inp2:3[3..,..2,5..10,?,-1]{i32},inp3:4[5]{f32}"
|
||||
inputs = input_to_input_cut_info(argv_input)
|
||||
inputs_ref = [_InputCutInfo(name='inp1:1', shape=PartialShape([3,1])),
|
||||
_InputCutInfo(name='inp2:3', shape=PartialShape("[3..,..2,5..10,?,-1]"), type=np.int32),
|
||||
_InputCutInfo(name='inp3:4', shape=PartialShape([5]), type=np.float32)]
|
||||
self.assertEqual(inputs, inputs_ref)
|
||||
|
||||
def test_partial_shapes_freeze_dynamic_negative_case1_comma_separator(self):
|
||||
argv_input = "inp1:1[3,1..10]->[1.0 2.0 3.0]"
|
||||
self.assertRaises(Error, input_to_input_cut_info, argv_input)
|
||||
|
Loading…
Reference in New Issue
Block a user