diff --git a/tools/mo/openvino/tools/mo/front/extractor.py b/tools/mo/openvino/tools/mo/front/extractor.py index 31649d8b30d..ae384ee5548 100644 --- a/tools/mo/openvino/tools/mo/front/extractor.py +++ b/tools/mo/openvino/tools/mo/front/extractor.py @@ -605,7 +605,7 @@ def input_user_data_repack(graph: Graph, input_user_shapes: [None, list, dict, n if ph_id not in _input_shapes] else: # np.ndarray is a shape. User provided only --input_shape key - assert isinstance(input_user_shapes, np.ndarray) + assert isinstance(input_user_shapes, tuple) if len(placeholders_ids) == 1: # There is only one placeholder in the original network _input_shapes[placeholders_ids[0]].append({'shape': input_user_shapes, 'port': None}) @@ -861,7 +861,7 @@ def add_input_op_output_port_with_data(graph: Graph, node_id: str, input_op, por def add_input_op(graph: Graph, node_id: str, port: int = 0, data: bool = False, - shape=None, data_type=None, is_out_port: bool = False): + shape=None, user_shape=None, data_type=None, is_out_port: bool = False): """ This function adds Input node to node with id==node_id to specified port (in or out defined with is_out_port). :param graph: graph to operate on. @@ -869,6 +869,8 @@ def add_input_op(graph: Graph, node_id: str, port: int = 0, data: bool = False, :param port: number of port of node_id node for adding input node. :param data: flag that define whether data nodes is needed or not. :param shape: shape for new input node. + :param user_shape: shape provided by user which may contain boundaries of dynamic dimension. + :param data_type: data type of input node. :param is_out_port: flag that define whether port is output port or not. :return: id of new Input operation """ @@ -876,7 +878,7 @@ def add_input_op(graph: Graph, node_id: str, port: int = 0, data: bool = False, from openvino.tools.mo.ops.parameter import Parameter if data_type is None: data_type = np.float32 - input_op = Parameter(graph, dict(shape=shape, data_type=data_type, initial_node_name=node_id, + input_op = Parameter(graph, dict(shape=shape, user_shape=user_shape, data_type=data_type, initial_node_name=node_id, name=get_new_placeholder_name(node_id, is_out_port, port))) fw_name = Node(graph, node_id).soft_get('name') @@ -901,7 +903,7 @@ def add_input_op(graph: Graph, node_id: str, port: int = 0, data: bool = False, def add_input_ops_helper_before_infer_input_port(graph: Graph, smart_node: Node, port: int, node_id: str, - shape: np.array, data_type, + shape: np.array, user_shape: tuple, data_type, inputs: list, edges_to_remove: list): n_inputs = len(smart_node.in_nodes()) if n_inputs > 1 and port is None: @@ -912,7 +914,7 @@ def add_input_ops_helper_before_infer_input_port(graph: Graph, smart_node: Node, port = port if port is not None else 0 edges_to_remove.append((smart_node.in_node(port).id, smart_node.id)) inputs.append(add_input_op(graph=graph, node_id=node_id, port=port, data=False, - shape=shape, data_type=data_type)) + shape=shape, user_shape=user_shape, data_type=data_type)) def add_input_ops_helper_after_infer_input_port(graph: Graph, smart_node: Node, port:int, node_id: str, @@ -934,13 +936,14 @@ def add_input_ops_helper_after_infer_input_port(graph: Graph, smart_node: Node, edges_to_remove.append((in_node.id, node_id)) -def add_input_ops_helper_before_infer_output_port(graph: Graph, port:int, node_id: str, - shape: np.array, data_type, inputs: list, edges_to_remove: list): +def add_input_ops_helper_before_infer_output_port(graph: Graph, port: int, node_id: str, + shape: np.array, user_shape: tuple, data_type: tuple, + inputs: list, edges_to_remove: list): for u, v, edge_attrs in graph.out_edges(node_id, data=True): if edge_attrs['out'] == port: edges_to_remove.append((u, v)) # we need to remove all edges from this port inputs.append(add_input_op(graph=graph, node_id=node_id, port=port, data=False, - shape=shape, data_type=data_type, is_out_port=True)) + shape=shape, user_shape=user_shape, data_type=data_type, is_out_port=True)) def add_input_ops_helper_after_infer_output_port(graph: Graph, smart_node: Node, port:int, node_id: str, @@ -987,8 +990,11 @@ def add_input_ops(graph: Graph, user_defined_inputs: dict, before_infer: bool): is_out_port = 'out' in port_and_shape_info # by default we assume input port or input node without port shape = port_and_shape_info['shape'] if 'shape' in port_and_shape_info else None + user_shape = None if shape is not None: - shape = shape_array([dim if dim >= 0 else dynamic_dimension_value for dim in shape]) + user_shape = shape + shape = shape_array( + [dim if type(dim) != tuple and dim >= 0 else dynamic_dimension_value for dim in shape]) data_type = port_and_shape_info['data_type'] if 'data_type' in port_and_shape_info else None smart_node = Node(graph, node_id) @@ -1016,6 +1022,7 @@ def add_input_ops(graph: Graph, user_defined_inputs: dict, before_infer: bool): refer_to_faq_msg(28), node_id, port) if shape is not None: smart_node['shape'] = shape + smart_node['user_shape'] = user_shape if data_type is not None: smart_node['data_type'] = data_type inputs.append(node_id) @@ -1027,10 +1034,11 @@ def add_input_ops(graph: Graph, user_defined_inputs: dict, before_infer: bool): continue # We cut with shapes provided by user and there is no need to wait till infer if is_out_port: - add_input_ops_helper_before_infer_output_port(graph, port, node_id, shape, data_type, inputs, - edges_to_remove) + add_input_ops_helper_before_infer_output_port(graph, port, node_id, shape, user_shape, + data_type, inputs, edges_to_remove) else: - add_input_ops_helper_before_infer_input_port(graph, smart_node, port, node_id, shape, data_type, inputs, + add_input_ops_helper_before_infer_input_port(graph, smart_node, port, node_id, shape, + user_shape, data_type, inputs, edges_to_remove) else: # We cut after infer and we need inferred shapes in nodes diff --git a/tools/mo/openvino/tools/mo/middle/PartialInfer.py b/tools/mo/openvino/tools/mo/middle/PartialInfer.py index 3af6da44959..c53cabeb589 100644 --- a/tools/mo/openvino/tools/mo/middle/PartialInfer.py +++ b/tools/mo/openvino/tools/mo/middle/PartialInfer.py @@ -2,10 +2,12 @@ # SPDX-License-Identifier: Apache-2.0 import logging as log -from openvino.tools.mo.front.common.partial_infer.utils import is_fully_defined, unmask_shape, shape_array, dynamic_dimension_value +from openvino.tools.mo.front.common.partial_infer.utils import is_fully_defined, shape_array, \ + dynamic_dimension_value from openvino.tools.mo.graph.graph import Graph from openvino.tools.mo.middle.passes.infer import partial_infer from openvino.tools.mo.middle.replacement import MiddleReplacementPattern +from openvino.tools.mo.ops.parameter import Parameter class PartialInfer(MiddleReplacementPattern): @@ -25,7 +27,7 @@ class PartialInfer(MiddleReplacementPattern): param_shape = parameter.soft_get('shape', shape_array(dynamic_dimension_value)) if not is_fully_defined(param_shape): parameter_name = parameter.soft_get('name', parameter.id) - dynamic_inputs[parameter_name] = param_shape + dynamic_inputs[parameter_name] = parameter if dynamic_inputs: log.error('The model contains input(s) with partially defined shapes: {}. ' 'Starting from the 2022.1 release the Model Optimizer can generate an IR with partially defined ' @@ -34,7 +36,7 @@ class PartialInfer(MiddleReplacementPattern): 'call "reshape" method in the Inference Engine and specify static input shapes. For optimal ' 'performance, it is still recommended to update input shapes with fixed ones using "--input" or ' '"--input_shape" command-line parameters.' - .format(','.join('name="{}" shape="{}"'.format(name, unmask_shape(shape)) - for name, shape in dynamic_inputs.items())), + .format(','.join('name="{}" shape="{}"'.format(name, Parameter.shape_serialize(parameter)) + for name, parameter in dynamic_inputs.items())), extra={'is_warning': True}) partial_infer(graph) diff --git a/tools/mo/openvino/tools/mo/moc_frontend/extractor.py b/tools/mo/openvino/tools/mo/moc_frontend/extractor.py index 65651af90a9..a29a957aa25 100644 --- a/tools/mo/openvino/tools/mo/moc_frontend/extractor.py +++ b/tools/mo/openvino/tools/mo/moc_frontend/extractor.py @@ -118,7 +118,7 @@ def fe_input_user_data_repack(input_model: InputModel, input_user_shapes: [None, _input_shapes.append({'node': node, 'shape': shape, 'data_type': data_type}) else: _input_shapes.append({'node': node, 'shape': shape}) - elif isinstance(input_user_shapes, np.ndarray): + elif isinstance(input_user_shapes, tuple): model_inputs = input_model.get_inputs() assert len(model_inputs) == 1 _input_shapes.append({'node': model_inputs[0], 'shape': input_user_shapes}) diff --git a/tools/mo/openvino/tools/mo/moc_frontend/pipeline.py b/tools/mo/openvino/tools/mo/moc_frontend/pipeline.py index 11a0f9c04ad..e8773662edc 100644 --- a/tools/mo/openvino/tools/mo/moc_frontend/pipeline.py +++ b/tools/mo/openvino/tools/mo/moc_frontend/pipeline.py @@ -12,6 +12,8 @@ from openvino.runtime import Dimension, PartialShape # pylint: disable=no from openvino.frontend import FrontEnd, Place # pylint: disable=no-name-in-module,import-error from openvino.runtime.utils.types import get_element_type # pylint: disable=no-name-in-module,import-error +import numpy as np + def moc_pipeline(argv: argparse.Namespace, moc_front_end: FrontEnd): """ @@ -66,7 +68,7 @@ def moc_pipeline(argv: argparse.Namespace, moc_front_end: FrontEnd): for user_shape in user_shapes: if user_shape.get('shape') is not None: input_model.set_partial_shape( - user_shape['node'], PartialShape(user_shape['shape'])) + user_shape['node'], partial_shape_from_tuple(user_shape['shape'])) if user_shape.get('data_type') is not None: data_type = get_element_type(user_shape['data_type']) log.debug('Set data type: {}'.format(data_type)) @@ -97,3 +99,16 @@ def moc_pipeline(argv: argparse.Namespace, moc_front_end: FrontEnd): ngraph_function = moc_front_end.convert(input_model) return ngraph_function + + +def partial_shape_from_tuple(shape: tuple): + new_shape = [] + for dim in shape: + if isinstance(dim, tuple): + assert len(dim) == 2, "Incorrect boundaries of dimension {} in shape {}".format(dim, shape) + assert dim[0] >= 0, "Incorrect min value of dimension {} in shape".format(dim, shape) + new_shape.append(Dimension(dim[0], dim[1])) + else: + assert isinstance(dim, np.int64), "Incorrect type of dimension {} in shape".format(dim, shape) + new_shape.append(Dimension(dim)) + return PartialShape(new_shape) diff --git a/tools/mo/openvino/tools/mo/ops/parameter.py b/tools/mo/openvino/tools/mo/ops/parameter.py index 13164271229..abaf36be870 100644 --- a/tools/mo/openvino/tools/mo/ops/parameter.py +++ b/tools/mo/openvino/tools/mo/ops/parameter.py @@ -26,6 +26,7 @@ class Parameter(Op): 'type_infer': self.type_infer, 'out_ports_count': 1, + 'user_shape': None, } if 'data_type' not in attrs: mandatory_props['data_type'] = np.float32 @@ -35,9 +36,25 @@ class Parameter(Op): def type_infer(node): node.out_port(0).set_data_type(node.data_type) + @staticmethod + def shape_serialize(node): + def serialize_dimension(dim: [tuple, np.int64]): + if type(dim) == tuple: + assert len(dim) == 2, "Unable to serialize shape {} in node {}".format(node.soft_get('user_shape'), + node.soft_get('name', node.id)) + min_str = str(dim[0]) if dim[0] > 0 else "" + max_str = str(dim[1]) if dim[1] < np.iinfo(np.int64).max else "" + return min_str + ".." + max_str + return str(dim) + + if not node.has_valid('user_shape'): + return ','.join([str(i) for i in unmask_shape(node.shape)]) + shape = node.soft_get('user_shape') + return ','.join(map(serialize_dimension, shape)) + def supported_attrs(self): return [ - ('shape', lambda node: ','.join([str(i) for i in unmask_shape(node.shape)])), + ('shape', lambda node: self.shape_serialize(node)), ('element_type', lambda node: np_data_type_to_destination_type(node.data_type)), ] diff --git a/tools/mo/openvino/tools/mo/utils/cli_parser.py b/tools/mo/openvino/tools/mo/utils/cli_parser.py index 5e79c75ef9c..6fa94057830 100644 --- a/tools/mo/openvino/tools/mo/utils/cli_parser.py +++ b/tools/mo/openvino/tools/mo/utils/cli_parser.py @@ -257,10 +257,12 @@ def get_common_cli_parser(parser: argparse.ArgumentParser = None): 'models. Model Optimizer performs necessary transformations to convert the shape to ' 'the layout required by Inference Engine (N,C,H,W). The shape could contain ' 'undefined dimensions (-1) and should fit the dimensions defined in the input ' - 'operation of the graph. If there are multiple inputs in the model, --input_shape ' - 'should contain definition of shape for each input separated by a comma, for ' - 'example: [1,3,227,227],[2,4] for a model with two inputs with 4D and 2D shapes. ' - 'Alternatively, specify shapes with the --input option.') + 'operation of the graph. Boundaries of undefined dimension can be specified with ' + 'ellipsis, for example [1,1..10,128,128]. One boundary can be undefined, for ' + 'example [1,..100] or [1,3,1..,1..]. If there are multiple inputs in the model, ' + '--input_shape should contain definition of shape for each input separated by a ' + 'comma, for example: [1,3,227,227],[2,4] for a model with two inputs with 4D and 2D ' + 'shapes. Alternatively, specify shapes with the --input option.') common_group.add_argument('--scale', '-s', type=float, help='All input values coming from original network inputs will be ' + @@ -778,12 +780,12 @@ def remove_shape_from_input_value(input_value: str): :return: string without shape specification """ assert '->' not in input_value, 'The function should not be called for input_value with constant value specified' - return re.sub(r'[(\[]([0-9 -]*)[)\]]', '', input_value) + return re.sub(r'[(\[]([0-9\.? -]*)[)\]]', '', input_value) def get_shape_from_input_value(input_value: str): """ - Returns the numpy array with shape corresponding to the shape specified in the input value string + Returns the list of tuples corresponding to the shape specified in the input value string :param input_value: string passed as input to the --input command line parameter :return: the corresponding shape and None if the shape is not specified in the input value """ @@ -791,11 +793,11 @@ def get_shape_from_input_value(input_value: str): input_value = input_value.split('->')[0] # parse shape - shape = re.findall(r'[(\[]([0-9 -]*)[)\]]', input_value) + shape = re.findall(r'[(\[]([0-9\.\? -]+)[)\]]', input_value) if len(shape) == 0: shape = None elif len(shape) == 1: - shape = np.fromstring(shape[0], dtype=np.int64, sep=' ') + shape = tuple(map(parse_dimension, shape[0].split(' '))) else: raise Error("Wrong syntax to specify shape. Use --input " "\"node_name[shape]->value\"") @@ -854,6 +856,11 @@ def parse_input_value(input_value: str): shape = get_shape_from_input_value(input_value.split('->')[0]) value_size = np.prod(len(value)) if isinstance(value, list) else 1 + if value is not None and shape is not None: + for dim in shape: + if isinstance(dim, tuple) or dim == -1: + raise Error("Cannot freeze input with dynamic shape: {}".format(shape)) + if shape is not None and value is not None and np.prod(shape) != value_size: raise Error("The shape '{}' of the input node '{}' does not correspond to the number of elements '{}' in the " "value: {}".format(shape, node_name, value_size, value)) @@ -1043,6 +1050,33 @@ def get_freeze_placeholder_values(argv_input: str, argv_freeze_placeholder_with_ return placeholder_values, input_node_names +def parse_dimension(dim: str): + if '..' in dim: + numbers_reg = r'^[0-9]+$' + dims = dim.split('..') + match_res0 = re.match(numbers_reg, dims[0]) + match_res1 = re.match(numbers_reg, dims[1]) + if len(dims[0].strip()) > 0 and match_res0 is None: + Error("Incorrect min value of dimension '{}'".format(dims[0])) + if len(dims[1].strip()) > 0 and match_res1 is None: + Error("Incorrect max value of dimension '{}'".format(dims[1])) + + min_val = np.int64(dims[0]) if match_res0 else np.int64(0) + max_val = np.int64(dims[1]) if match_res1 else np.iinfo(np.int64).max + assert min_val >= 0, "Incorrect min value of the dimension {}".format(dim) + + if min_val == np.int64(0) and max_val == np.iinfo(np.int64).max: + return np.int64(-1) + + assert min_val < max_val, "Min value should be less than max value. Got min value: {}, " \ + "max value: {}".format(min_val, max_val) + + return min_val, max_val + if '?' in dim: + return np.int64(-1) + return np.int64(dim) + + def get_placeholder_shapes(argv_input: str, argv_input_shape: str, argv_batch=None): """ Parses input layers names and input shapes from the cli and returns the parsed object. @@ -1055,16 +1089,18 @@ def get_placeholder_shapes(argv_input: str, argv_input_shape: str, argv_batch=No E.g. 'inp1,inp2', 'node_name1[shape1]->value1,node_name2[shape2]->value2' argv_input_shape string with a list of input shapes: either an empty string, or tuples separated with comma. - E.g. '(1,2),(3,4)'. - Only positive integers are accepted except -1, which can be on any position in a shape. + E.g. '[1,2],[3,4]'. + Only positive integers are accepted. + '?' marks dynamic dimension. + Partial shape is specified with ellipsis. E.g. '[1..10,2,3]' argv_batch integer that overrides batch size in input shape Returns ------- - parsed shapes in form of {'name of input':ndarray} if names of inputs are provided with shapes + parsed shapes in form of {'name of input':tuple} if names of inputs are provided with shapes parsed shapes in form of {'name of input':None} if names of inputs are provided without shapes - ndarray if only one shape is provided and no input name + tuple if only one shape is provided and no input name None if neither shape nor input were provided """ if argv_input_shape and argv_batch: @@ -1099,7 +1135,8 @@ def get_placeholder_shapes(argv_input: str, argv_input_shape: str, argv_batch=No inputs = list() placeholder_shapes = None - first_digit_reg = r'([0-9 ]+|-1)' + range_reg = r'([0-9]*\.\.[0-9]*)' + first_digit_reg = r'([0-9 ]+|-1|\?|{})'.format(range_reg) next_digits_reg = r'(,{})*'.format(first_digit_reg) tuple_reg = r'((\({}{}\))|(\[{}{}\]))'.format(first_digit_reg, next_digits_reg, first_digit_reg, next_digits_reg) @@ -1107,7 +1144,7 @@ def get_placeholder_shapes(argv_input: str, argv_input_shape: str, argv_batch=No full_reg = r'^{}(\s*,\s*{})*$|^$'.format(tuple_reg, tuple_reg) if not re.match(full_reg, argv_input_shape): raise Error('Input shape "{}" cannot be parsed. ' + refer_to_faq_msg(57), argv_input_shape) - shapes = re.findall(r'[(\[]([0-9, -]+)[)\]]', argv_input_shape) + shapes = re.findall(r'[(\[]([0-9,\.\? -]+)[)\]]', argv_input_shape) if argv_input: inputs = argv_input.split(',') @@ -1118,14 +1155,22 @@ def get_placeholder_shapes(argv_input: str, argv_input_shape: str, argv_batch=No if len(shapes) > 1: raise Error('Please provide input layer names for input layer shapes. ' + refer_to_faq_msg(58)) else: - placeholder_shapes = np.fromstring(shapes[0], dtype=np.int64, sep=',') + placeholder_shapes = tuple(map(parse_dimension, shapes[0].split(','))) # check if number of shapes does not match number of passed inputs elif argv_input and (len(shapes) == len(inputs) or len(shapes) == 0): # clean inputs from values for freezing - inputs = list(map(lambda x: x.split('->')[0], inputs)) - placeholder_shapes = dict(zip_longest(inputs, - map(lambda x: np.fromstring(x, dtype=np.int64, - sep=',') if x else None, shapes))) + inputs_without_value = list(map(lambda x: x.split('->')[0], inputs)) + placeholder_shapes = dict(zip_longest(inputs_without_value, + map(lambda x: tuple(map(parse_dimension, x.split(','))) if x else None, + shapes))) + for inp in inputs: + if '->' not in inp: + continue + shape = placeholder_shapes[inp.split('->')[0]] + for dim in shape: + if isinstance(dim, tuple) or dim == -1: + raise Error("Cannot freeze input with dynamic shape: {}".format(shape)) + elif argv_input: raise Error('Please provide each input layers with an input layer shape. ' + refer_to_faq_msg(58)) diff --git a/tools/mo/openvino/tools/mo/utils/ir_reader/extenders/parameter_extender.py b/tools/mo/openvino/tools/mo/utils/ir_reader/extenders/parameter_extender.py index bebdc0ba19e..23a04c6e488 100644 --- a/tools/mo/openvino/tools/mo/utils/ir_reader/extenders/parameter_extender.py +++ b/tools/mo/openvino/tools/mo/utils/ir_reader/extenders/parameter_extender.py @@ -18,5 +18,7 @@ class Parameter_extender(Extender): op.shape = int64_array([]) else: Extender.attr_to_list(op, 'shape') - if -1 in op.shape: - op.shape = shape_array([d if d != -1 else dynamic_dimension_value for d in op.shape]) + for i, dim in enumerate(op.shape): + if dim == -1 or (isinstance(dim, str) and ".." in dim): + op.shape[i] = -1 + op.shape = shape_array([d if d != -1 else dynamic_dimension_value for d in op.shape]) diff --git a/tools/mo/unit_tests/mo/front/extractor_test.py b/tools/mo/unit_tests/mo/front/extractor_test.py index bfc0001e223..6a09e7320e2 100644 --- a/tools/mo/unit_tests/mo/front/extractor_test.py +++ b/tools/mo/unit_tests/mo/front/extractor_test.py @@ -537,11 +537,11 @@ class TestUserDataRepack(unittest.TestCase): def test_error(self): graph = build_graph(self.nodes, self.edges) - self.assertRaises(Error, input_user_data_repack, graph, np.array([1, 227, 227, 3]), None) + self.assertRaises(Error, input_user_data_repack, graph, tuple([1, 227, 227, 3]), None) def test_error_2(self): graph = build_graph(self.nodes, self.edges) - self.assertRaises(Error, input_user_data_repack, graph, np.array([1, 227, 227, 3]), None) + self.assertRaises(Error, input_user_data_repack, graph, tuple([1, 227, 227, 3]), None) def test_error_3(self): graph = build_graph(self.nodes, self.edges) @@ -549,7 +549,7 @@ class TestUserDataRepack(unittest.TestCase): def test_input_and_freeze(self): graph = build_graph(self.nodes, self.edges) - shape_1 = np.array([1, 160, 160, 3]) + shape_1 = tuple([1, 160, 160, 3]) input, freeze_placeholder = input_user_data_repack(graph, shape_1, {'Bb': True}) self.assertDictEqual(input, {'A': [{'shape': shape_1, 'port': None}], 'B': [{'shape': None, 'port': None}]}) self.assertDictEqual(freeze_placeholder, {'B': True}) diff --git a/tools/mo/unit_tests/mo/utils/cli_parser_test.py b/tools/mo/unit_tests/mo/utils/cli_parser_test.py index eb6a917e2e2..d9e5b3e1d69 100644 --- a/tools/mo/unit_tests/mo/utils/cli_parser_test.py +++ b/tools/mo/unit_tests/mo/utils/cli_parser_test.py @@ -778,6 +778,153 @@ class TestShapesParsing(unittest.TestCase): input_shapes = "(12,4,1),(4,-6,8)" self.assertRaises(Error, get_placeholder_shapes, argv_input, input_shapes) + def test_get_shapes_several_inputs_several_partial_shapes(self): + argv_input = "inp1,inp2" + input_shapes = "(1,..22,1..100,?), (-1,45..,7,1)" + result, _ = get_placeholder_shapes(argv_input, input_shapes) + exp_res = {'inp1': (1, (0, 22), (1, 100), -1), 'inp2': (-1, (45, np.iinfo(np.int64).max), 7, 1)} + self.assertEqual(list(exp_res.keys()), list(result.keys())) + for i in exp_res.keys(): + npt.assert_array_equal(result[i], exp_res[i]) + + def test_get_shapes_several_inputs_several_partial_shapes2(self): + # shapes specified using --input command line parameter and no values + argv_input = "inp1[1 ? 50..100 123],inp2[-1 45.. ..7 1]" + result, _ = get_placeholder_shapes(argv_input, None) + exp_res = {'inp1': (1, -1, (50, 100), 123), 'inp2': (-1, (45,np.iinfo(np.int64).max), (0, 7), 1)} + self.assertEqual(list(exp_res.keys()), list(result.keys())) + for i in exp_res.keys(): + npt.assert_array_equal(result[i], exp_res[i]) + placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None) + placeholder_values_ref = {} + input_node_names_ref = "inp1,inp2" + self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys())) + for i in placeholder_values_ref.keys(): + npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i]) + + def test_get_shapes_several_inputs_several_partial_shapes3(self): + # shapes and value for freezing specified using --input command line parameter + argv_input = "inp1[3 1]->[1.0 2.0 3.0],inp2[3.. ..2 5..10 ? -1],inp3[5]->[1.0 1.0 2.0 3.0 5.0]" + result, _ = get_placeholder_shapes(argv_input, None) + exp_res = {'inp1': (3, 1), 'inp2': ((3, np.iinfo(np.int64).max), (0, 2), (5, 10), -1, -1), 'inp3': (5)} + self.assertEqual(list(exp_res.keys()), list(result.keys())) + for i in exp_res.keys(): + npt.assert_array_equal(result[i], exp_res[i]) + placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None) + placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'])} + input_node_names_ref = "inp1,inp2,inp3" + self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys())) + for i in placeholder_values_ref.keys(): + npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i]) + + def test_get_shapes_several_inputs_several_partial_shapes4(self): + # shapes specified using --input_shape and values for freezing using --input command line parameter + argv_input = "inp1->[1.0 2.0 3.0],inp2,inp3->[1.0 1.0 2.0 3.0 5.0]" + input_shapes = "(3,1), (3..,..2,5..10,?,-1), (5)" + result, _ = get_placeholder_shapes(argv_input, input_shapes) + exp_res = {'inp1': (3, 1), 'inp2': ((3, np.iinfo(np.int64).max), (0, 2), (5, 10), -1, -1), 'inp3': (5)} + self.assertEqual(list(exp_res.keys()), list(result.keys())) + for i in exp_res.keys(): + npt.assert_array_equal(result[i], exp_res[i]) + placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None) + placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'])} + input_node_names_ref = "inp1,inp2,inp3" + self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys())) + for i in placeholder_values_ref.keys(): + npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i]) + self.assertEqual(input_node_names_ref, input_node_names_res) + + def test_get_shapes_several_inputs_several_partial_shapes5(self): + # some values for freezing specified using --freeze_placeholder_with_value + argv_input = "inp1->[1.0 2.0 3.0],inp2,inp3->[1.0 1.0 2.0 3.0 5.0]" + input_shapes = "(3,1), (3..,..2,5..10,?,-1), (5)" + argv_freeze_placeholder_with_value = "inp2->[5.0 7.0 3.0],inp4->[100.0 200.0]" + + result, _ = get_placeholder_shapes(argv_input, input_shapes) + exp_res = {'inp1': (3, 1), 'inp2': ((3, np.iinfo(np.int64).max), (0, 2), (5, 10), -1, -1), 'inp3': (5)} + self.assertEqual(list(exp_res.keys()), list(result.keys())) + for i in exp_res.keys(): + npt.assert_array_equal(result[i], exp_res[i]) + placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, argv_freeze_placeholder_with_value) + placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'],), + 'inp2': np.array(['5.0', '7.0', '3.0']), 'inp4': np.array(['100.0', '200.0'])} + input_node_names_ref = "inp1,inp2,inp3" + self.assertEqual(sorted(list(placeholder_values_res.keys())), sorted(list(placeholder_values_ref.keys()))) + for i in placeholder_values_ref.keys(): + npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i]) + self.assertEqual(input_node_names_ref, input_node_names_res) + + def test_get_shapes_several_inputs_several_partial_shapes6(self): + # 0D value for freezing specified using --input command line parameter without shape + argv_input = "inp1[3 1]->[1.0 2.0 3.0],inp2[3.. ..2 5..10 ? -1],inp3->False" + result, _ = get_placeholder_shapes(argv_input, None) + exp_res = {'inp1': (3, 1), 'inp2': ((3, np.iinfo(np.int64).max), (0, 2), (5, 10), -1, -1), 'inp3': np.array(False).shape} + self.assertEqual(list(exp_res.keys()), list(result.keys())) + for i in exp_res.keys(): + npt.assert_array_equal(result[i], exp_res[i]) + placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None) + placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': False} + self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys())) + for i in placeholder_values_ref.keys(): + npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i]) + + def test_get_shapes_several_inputs_several_partial_shapes7(self): + # 0D shape and value for freezing specified using --input command line parameter + argv_input = "inp1[3 1]->[1.0 2.0 3.0],inp2[3.. ..2 5..10 ? -1],inp3[]->True" + result, _ = get_placeholder_shapes(argv_input, None) + exp_res = {'inp1': (3, 1), 'inp2': ((3, np.iinfo(np.int64).max), (0, 2), (5, 10), -1, -1), 'inp3': np.array(False).shape} + self.assertEqual(list(exp_res.keys()), list(result.keys())) + for i in exp_res.keys(): + npt.assert_array_equal(result[i], exp_res[i]) + placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None) + placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': True} + self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys())) + for i in placeholder_values_ref.keys(): + npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i]) + + def test_get_shapes_and_data_types_partial_shape_with_input_port(self): + argv_input = "inp1:1[3 1]->[1.0 2.0 3.0],0:inp2[3.. ..2 5..10 ? -1]{i32},inp3:4[5]{f32}->[1.0 1.0 2.0 3.0 5.0]" + result_shapes, result_data_types = get_placeholder_shapes(argv_input, "") + ref_result_shapes = {'inp1:1': np.array([3, 1]), '0:inp2': ((3, np.iinfo(np.int64).max), (0, 2), (5, 10), -1, -1), 'inp3:4': np.array([5])} + ref_result_data_types = {'0:inp2': np.int32, 'inp3:4': np.float32} + self.assertEqual(list(ref_result_shapes.keys()), list(result_shapes.keys())) + for i in ref_result_shapes.keys(): + npt.assert_array_equal(result_shapes[i], ref_result_shapes[i]) + self.assertEqual(list(ref_result_data_types.keys()), list(result_data_types.keys())) + for i in ref_result_data_types.keys(): + np.testing.assert_equal(result_data_types[i], ref_result_data_types[i]) + + def test_get_shapes_and_data_types_partial_shape_with_output_port(self): + argv_input = "inp1:1[3 1]->[1.0 2.0 3.0],inp2:3[3.. ..2 5..10 ? -1]{i32},inp3:4[5]{f32}->[1.0 1.0 2.0 3.0 5.0]" + result_shapes, result_data_types = get_placeholder_shapes(argv_input, "") + ref_result_shapes = {'inp1:1': np.array([3, 1]), 'inp2:3': ((3, np.iinfo(np.int64).max), (0, 2), (5, 10), -1, -1), 'inp3:4': np.array([5])} + ref_result_data_types = {'inp2:3': np.int32, 'inp3:4': np.float32} + self.assertEqual(list(ref_result_shapes.keys()), list(result_shapes.keys())) + for i in ref_result_shapes.keys(): + npt.assert_array_equal(result_shapes[i], ref_result_shapes[i]) + self.assertEqual(list(ref_result_data_types.keys()), list(result_data_types.keys())) + for i in ref_result_data_types.keys(): + np.testing.assert_equal(result_data_types[i], ref_result_data_types[i]) + + def test_partial_shapes_negative_case(self): + argv_input = "inp1" + input_shapes = "[6754fg..23ed]" + self.assertRaises(Error, get_placeholder_shapes, argv_input, input_shapes) + + def test_partial_shapes_freeze_dynamic_negative_case1(self): + argv_input = "inp1:1[3 1..10]->[1.0 2.0 3.0]" + self.assertRaises(Error, get_placeholder_shapes, argv_input, "") + + def test_partial_shapes_freeze_dynamic_negative_case2(self): + argv_input = "inp1:1[1 2 -1]->[1.0 2.0 3.0]" + self.assertRaises(Error, get_placeholder_shapes, argv_input, "") + + def test_partial_shapes_freeze_dynamic_negative_case3(self): + # some values for freezing specified using --freeze_placeholder_with_value + argv_input = "inp1->[1.0 2.0 3.0]" + input_shapes = "[3,1..10]" + self.assertRaises(Error, get_placeholder_shapes, argv_input, input_shapes) + class TestModelNameParsing(unittest.TestCase): def test_model_name_ideal(self):