Partial shape serialization MO (#8927)
* Support of partial shape in MO. * Added partial shapes serialization in MO. * Removed breckets. * Fixed parameter extender in IR reader. * Removed wrong changes. * Added checks, added tests. * Updated help. * Fixed import.
This commit is contained in:
parent
eab49eec8b
commit
3cb27715e2
@ -605,7 +605,7 @@ def input_user_data_repack(graph: Graph, input_user_shapes: [None, list, dict, n
|
||||
if ph_id not in _input_shapes]
|
||||
else:
|
||||
# np.ndarray is a shape. User provided only --input_shape key
|
||||
assert isinstance(input_user_shapes, np.ndarray)
|
||||
assert isinstance(input_user_shapes, tuple)
|
||||
if len(placeholders_ids) == 1:
|
||||
# There is only one placeholder in the original network
|
||||
_input_shapes[placeholders_ids[0]].append({'shape': input_user_shapes, 'port': None})
|
||||
@ -861,7 +861,7 @@ def add_input_op_output_port_with_data(graph: Graph, node_id: str, input_op, por
|
||||
|
||||
|
||||
def add_input_op(graph: Graph, node_id: str, port: int = 0, data: bool = False,
|
||||
shape=None, data_type=None, is_out_port: bool = False):
|
||||
shape=None, user_shape=None, data_type=None, is_out_port: bool = False):
|
||||
"""
|
||||
This function adds Input node to node with id==node_id to specified port (in or out defined with is_out_port).
|
||||
:param graph: graph to operate on.
|
||||
@ -869,6 +869,8 @@ def add_input_op(graph: Graph, node_id: str, port: int = 0, data: bool = False,
|
||||
:param port: number of port of node_id node for adding input node.
|
||||
:param data: flag that define whether data nodes is needed or not.
|
||||
:param shape: shape for new input node.
|
||||
:param user_shape: shape provided by user which may contain boundaries of dynamic dimension.
|
||||
:param data_type: data type of input node.
|
||||
:param is_out_port: flag that define whether port is output port or not.
|
||||
:return: id of new Input operation
|
||||
"""
|
||||
@ -876,7 +878,7 @@ def add_input_op(graph: Graph, node_id: str, port: int = 0, data: bool = False,
|
||||
from openvino.tools.mo.ops.parameter import Parameter
|
||||
if data_type is None:
|
||||
data_type = np.float32
|
||||
input_op = Parameter(graph, dict(shape=shape, data_type=data_type, initial_node_name=node_id,
|
||||
input_op = Parameter(graph, dict(shape=shape, user_shape=user_shape, data_type=data_type, initial_node_name=node_id,
|
||||
name=get_new_placeholder_name(node_id, is_out_port, port)))
|
||||
|
||||
fw_name = Node(graph, node_id).soft_get('name')
|
||||
@ -901,7 +903,7 @@ def add_input_op(graph: Graph, node_id: str, port: int = 0, data: bool = False,
|
||||
|
||||
|
||||
def add_input_ops_helper_before_infer_input_port(graph: Graph, smart_node: Node, port: int, node_id: str,
|
||||
shape: np.array, data_type,
|
||||
shape: np.array, user_shape: tuple, data_type,
|
||||
inputs: list, edges_to_remove: list):
|
||||
n_inputs = len(smart_node.in_nodes())
|
||||
if n_inputs > 1 and port is None:
|
||||
@ -912,7 +914,7 @@ def add_input_ops_helper_before_infer_input_port(graph: Graph, smart_node: Node,
|
||||
port = port if port is not None else 0
|
||||
edges_to_remove.append((smart_node.in_node(port).id, smart_node.id))
|
||||
inputs.append(add_input_op(graph=graph, node_id=node_id, port=port, data=False,
|
||||
shape=shape, data_type=data_type))
|
||||
shape=shape, user_shape=user_shape, data_type=data_type))
|
||||
|
||||
|
||||
def add_input_ops_helper_after_infer_input_port(graph: Graph, smart_node: Node, port:int, node_id: str,
|
||||
@ -934,13 +936,14 @@ def add_input_ops_helper_after_infer_input_port(graph: Graph, smart_node: Node,
|
||||
edges_to_remove.append((in_node.id, node_id))
|
||||
|
||||
|
||||
def add_input_ops_helper_before_infer_output_port(graph: Graph, port:int, node_id: str,
|
||||
shape: np.array, data_type, inputs: list, edges_to_remove: list):
|
||||
def add_input_ops_helper_before_infer_output_port(graph: Graph, port: int, node_id: str,
|
||||
shape: np.array, user_shape: tuple, data_type: tuple,
|
||||
inputs: list, edges_to_remove: list):
|
||||
for u, v, edge_attrs in graph.out_edges(node_id, data=True):
|
||||
if edge_attrs['out'] == port:
|
||||
edges_to_remove.append((u, v)) # we need to remove all edges from this port
|
||||
inputs.append(add_input_op(graph=graph, node_id=node_id, port=port, data=False,
|
||||
shape=shape, data_type=data_type, is_out_port=True))
|
||||
shape=shape, user_shape=user_shape, data_type=data_type, is_out_port=True))
|
||||
|
||||
|
||||
def add_input_ops_helper_after_infer_output_port(graph: Graph, smart_node: Node, port:int, node_id: str,
|
||||
@ -987,8 +990,11 @@ def add_input_ops(graph: Graph, user_defined_inputs: dict, before_infer: bool):
|
||||
|
||||
is_out_port = 'out' in port_and_shape_info # by default we assume input port or input node without port
|
||||
shape = port_and_shape_info['shape'] if 'shape' in port_and_shape_info else None
|
||||
user_shape = None
|
||||
if shape is not None:
|
||||
shape = shape_array([dim if dim >= 0 else dynamic_dimension_value for dim in shape])
|
||||
user_shape = shape
|
||||
shape = shape_array(
|
||||
[dim if type(dim) != tuple and dim >= 0 else dynamic_dimension_value for dim in shape])
|
||||
data_type = port_and_shape_info['data_type'] if 'data_type' in port_and_shape_info else None
|
||||
smart_node = Node(graph, node_id)
|
||||
|
||||
@ -1016,6 +1022,7 @@ def add_input_ops(graph: Graph, user_defined_inputs: dict, before_infer: bool):
|
||||
refer_to_faq_msg(28), node_id, port)
|
||||
if shape is not None:
|
||||
smart_node['shape'] = shape
|
||||
smart_node['user_shape'] = user_shape
|
||||
if data_type is not None:
|
||||
smart_node['data_type'] = data_type
|
||||
inputs.append(node_id)
|
||||
@ -1027,10 +1034,11 @@ def add_input_ops(graph: Graph, user_defined_inputs: dict, before_infer: bool):
|
||||
continue
|
||||
# We cut with shapes provided by user and there is no need to wait till infer
|
||||
if is_out_port:
|
||||
add_input_ops_helper_before_infer_output_port(graph, port, node_id, shape, data_type, inputs,
|
||||
edges_to_remove)
|
||||
add_input_ops_helper_before_infer_output_port(graph, port, node_id, shape, user_shape,
|
||||
data_type, inputs, edges_to_remove)
|
||||
else:
|
||||
add_input_ops_helper_before_infer_input_port(graph, smart_node, port, node_id, shape, data_type, inputs,
|
||||
add_input_ops_helper_before_infer_input_port(graph, smart_node, port, node_id, shape,
|
||||
user_shape, data_type, inputs,
|
||||
edges_to_remove)
|
||||
else:
|
||||
# We cut after infer and we need inferred shapes in nodes
|
||||
|
@ -2,10 +2,12 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import logging as log
|
||||
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import is_fully_defined, unmask_shape, shape_array, dynamic_dimension_value
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import is_fully_defined, shape_array, \
|
||||
dynamic_dimension_value
|
||||
from openvino.tools.mo.graph.graph import Graph
|
||||
from openvino.tools.mo.middle.passes.infer import partial_infer
|
||||
from openvino.tools.mo.middle.replacement import MiddleReplacementPattern
|
||||
from openvino.tools.mo.ops.parameter import Parameter
|
||||
|
||||
|
||||
class PartialInfer(MiddleReplacementPattern):
|
||||
@ -25,7 +27,7 @@ class PartialInfer(MiddleReplacementPattern):
|
||||
param_shape = parameter.soft_get('shape', shape_array(dynamic_dimension_value))
|
||||
if not is_fully_defined(param_shape):
|
||||
parameter_name = parameter.soft_get('name', parameter.id)
|
||||
dynamic_inputs[parameter_name] = param_shape
|
||||
dynamic_inputs[parameter_name] = parameter
|
||||
if dynamic_inputs:
|
||||
log.error('The model contains input(s) with partially defined shapes: {}. '
|
||||
'Starting from the 2022.1 release the Model Optimizer can generate an IR with partially defined '
|
||||
@ -34,7 +36,7 @@ class PartialInfer(MiddleReplacementPattern):
|
||||
'call "reshape" method in the Inference Engine and specify static input shapes. For optimal '
|
||||
'performance, it is still recommended to update input shapes with fixed ones using "--input" or '
|
||||
'"--input_shape" command-line parameters.'
|
||||
.format(','.join('name="{}" shape="{}"'.format(name, unmask_shape(shape))
|
||||
for name, shape in dynamic_inputs.items())),
|
||||
.format(','.join('name="{}" shape="{}"'.format(name, Parameter.shape_serialize(parameter))
|
||||
for name, parameter in dynamic_inputs.items())),
|
||||
extra={'is_warning': True})
|
||||
partial_infer(graph)
|
||||
|
@ -118,7 +118,7 @@ def fe_input_user_data_repack(input_model: InputModel, input_user_shapes: [None,
|
||||
_input_shapes.append({'node': node, 'shape': shape, 'data_type': data_type})
|
||||
else:
|
||||
_input_shapes.append({'node': node, 'shape': shape})
|
||||
elif isinstance(input_user_shapes, np.ndarray):
|
||||
elif isinstance(input_user_shapes, tuple):
|
||||
model_inputs = input_model.get_inputs()
|
||||
assert len(model_inputs) == 1
|
||||
_input_shapes.append({'node': model_inputs[0], 'shape': input_user_shapes})
|
||||
|
@ -12,6 +12,8 @@ from openvino.runtime import Dimension, PartialShape # pylint: disable=no
|
||||
from openvino.frontend import FrontEnd, Place # pylint: disable=no-name-in-module,import-error
|
||||
from openvino.runtime.utils.types import get_element_type # pylint: disable=no-name-in-module,import-error
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def moc_pipeline(argv: argparse.Namespace, moc_front_end: FrontEnd):
|
||||
"""
|
||||
@ -66,7 +68,7 @@ def moc_pipeline(argv: argparse.Namespace, moc_front_end: FrontEnd):
|
||||
for user_shape in user_shapes:
|
||||
if user_shape.get('shape') is not None:
|
||||
input_model.set_partial_shape(
|
||||
user_shape['node'], PartialShape(user_shape['shape']))
|
||||
user_shape['node'], partial_shape_from_tuple(user_shape['shape']))
|
||||
if user_shape.get('data_type') is not None:
|
||||
data_type = get_element_type(user_shape['data_type'])
|
||||
log.debug('Set data type: {}'.format(data_type))
|
||||
@ -97,3 +99,16 @@ def moc_pipeline(argv: argparse.Namespace, moc_front_end: FrontEnd):
|
||||
|
||||
ngraph_function = moc_front_end.convert(input_model)
|
||||
return ngraph_function
|
||||
|
||||
|
||||
def partial_shape_from_tuple(shape: tuple):
|
||||
new_shape = []
|
||||
for dim in shape:
|
||||
if isinstance(dim, tuple):
|
||||
assert len(dim) == 2, "Incorrect boundaries of dimension {} in shape {}".format(dim, shape)
|
||||
assert dim[0] >= 0, "Incorrect min value of dimension {} in shape".format(dim, shape)
|
||||
new_shape.append(Dimension(dim[0], dim[1]))
|
||||
else:
|
||||
assert isinstance(dim, np.int64), "Incorrect type of dimension {} in shape".format(dim, shape)
|
||||
new_shape.append(Dimension(dim))
|
||||
return PartialShape(new_shape)
|
||||
|
@ -26,6 +26,7 @@ class Parameter(Op):
|
||||
'type_infer': self.type_infer,
|
||||
|
||||
'out_ports_count': 1,
|
||||
'user_shape': None,
|
||||
}
|
||||
if 'data_type' not in attrs:
|
||||
mandatory_props['data_type'] = np.float32
|
||||
@ -35,9 +36,25 @@ class Parameter(Op):
|
||||
def type_infer(node):
|
||||
node.out_port(0).set_data_type(node.data_type)
|
||||
|
||||
@staticmethod
|
||||
def shape_serialize(node):
|
||||
def serialize_dimension(dim: [tuple, np.int64]):
|
||||
if type(dim) == tuple:
|
||||
assert len(dim) == 2, "Unable to serialize shape {} in node {}".format(node.soft_get('user_shape'),
|
||||
node.soft_get('name', node.id))
|
||||
min_str = str(dim[0]) if dim[0] > 0 else ""
|
||||
max_str = str(dim[1]) if dim[1] < np.iinfo(np.int64).max else ""
|
||||
return min_str + ".." + max_str
|
||||
return str(dim)
|
||||
|
||||
if not node.has_valid('user_shape'):
|
||||
return ','.join([str(i) for i in unmask_shape(node.shape)])
|
||||
shape = node.soft_get('user_shape')
|
||||
return ','.join(map(serialize_dimension, shape))
|
||||
|
||||
def supported_attrs(self):
|
||||
return [
|
||||
('shape', lambda node: ','.join([str(i) for i in unmask_shape(node.shape)])),
|
||||
('shape', lambda node: self.shape_serialize(node)),
|
||||
('element_type', lambda node: np_data_type_to_destination_type(node.data_type)),
|
||||
]
|
||||
|
||||
|
@ -257,10 +257,12 @@ def get_common_cli_parser(parser: argparse.ArgumentParser = None):
|
||||
'models. Model Optimizer performs necessary transformations to convert the shape to '
|
||||
'the layout required by Inference Engine (N,C,H,W). The shape could contain '
|
||||
'undefined dimensions (-1) and should fit the dimensions defined in the input '
|
||||
'operation of the graph. If there are multiple inputs in the model, --input_shape '
|
||||
'should contain definition of shape for each input separated by a comma, for '
|
||||
'example: [1,3,227,227],[2,4] for a model with two inputs with 4D and 2D shapes. '
|
||||
'Alternatively, specify shapes with the --input option.')
|
||||
'operation of the graph. Boundaries of undefined dimension can be specified with '
|
||||
'ellipsis, for example [1,1..10,128,128]. One boundary can be undefined, for '
|
||||
'example [1,..100] or [1,3,1..,1..]. If there are multiple inputs in the model, '
|
||||
'--input_shape should contain definition of shape for each input separated by a '
|
||||
'comma, for example: [1,3,227,227],[2,4] for a model with two inputs with 4D and 2D '
|
||||
'shapes. Alternatively, specify shapes with the --input option.')
|
||||
common_group.add_argument('--scale', '-s',
|
||||
type=float,
|
||||
help='All input values coming from original network inputs will be ' +
|
||||
@ -778,12 +780,12 @@ def remove_shape_from_input_value(input_value: str):
|
||||
:return: string without shape specification
|
||||
"""
|
||||
assert '->' not in input_value, 'The function should not be called for input_value with constant value specified'
|
||||
return re.sub(r'[(\[]([0-9 -]*)[)\]]', '', input_value)
|
||||
return re.sub(r'[(\[]([0-9\.? -]*)[)\]]', '', input_value)
|
||||
|
||||
|
||||
def get_shape_from_input_value(input_value: str):
|
||||
"""
|
||||
Returns the numpy array with shape corresponding to the shape specified in the input value string
|
||||
Returns the list of tuples corresponding to the shape specified in the input value string
|
||||
:param input_value: string passed as input to the --input command line parameter
|
||||
:return: the corresponding shape and None if the shape is not specified in the input value
|
||||
"""
|
||||
@ -791,11 +793,11 @@ def get_shape_from_input_value(input_value: str):
|
||||
input_value = input_value.split('->')[0]
|
||||
|
||||
# parse shape
|
||||
shape = re.findall(r'[(\[]([0-9 -]*)[)\]]', input_value)
|
||||
shape = re.findall(r'[(\[]([0-9\.\? -]+)[)\]]', input_value)
|
||||
if len(shape) == 0:
|
||||
shape = None
|
||||
elif len(shape) == 1:
|
||||
shape = np.fromstring(shape[0], dtype=np.int64, sep=' ')
|
||||
shape = tuple(map(parse_dimension, shape[0].split(' ')))
|
||||
else:
|
||||
raise Error("Wrong syntax to specify shape. Use --input "
|
||||
"\"node_name[shape]->value\"")
|
||||
@ -854,6 +856,11 @@ def parse_input_value(input_value: str):
|
||||
shape = get_shape_from_input_value(input_value.split('->')[0])
|
||||
value_size = np.prod(len(value)) if isinstance(value, list) else 1
|
||||
|
||||
if value is not None and shape is not None:
|
||||
for dim in shape:
|
||||
if isinstance(dim, tuple) or dim == -1:
|
||||
raise Error("Cannot freeze input with dynamic shape: {}".format(shape))
|
||||
|
||||
if shape is not None and value is not None and np.prod(shape) != value_size:
|
||||
raise Error("The shape '{}' of the input node '{}' does not correspond to the number of elements '{}' in the "
|
||||
"value: {}".format(shape, node_name, value_size, value))
|
||||
@ -1043,6 +1050,33 @@ def get_freeze_placeholder_values(argv_input: str, argv_freeze_placeholder_with_
|
||||
return placeholder_values, input_node_names
|
||||
|
||||
|
||||
def parse_dimension(dim: str):
|
||||
if '..' in dim:
|
||||
numbers_reg = r'^[0-9]+$'
|
||||
dims = dim.split('..')
|
||||
match_res0 = re.match(numbers_reg, dims[0])
|
||||
match_res1 = re.match(numbers_reg, dims[1])
|
||||
if len(dims[0].strip()) > 0 and match_res0 is None:
|
||||
Error("Incorrect min value of dimension '{}'".format(dims[0]))
|
||||
if len(dims[1].strip()) > 0 and match_res1 is None:
|
||||
Error("Incorrect max value of dimension '{}'".format(dims[1]))
|
||||
|
||||
min_val = np.int64(dims[0]) if match_res0 else np.int64(0)
|
||||
max_val = np.int64(dims[1]) if match_res1 else np.iinfo(np.int64).max
|
||||
assert min_val >= 0, "Incorrect min value of the dimension {}".format(dim)
|
||||
|
||||
if min_val == np.int64(0) and max_val == np.iinfo(np.int64).max:
|
||||
return np.int64(-1)
|
||||
|
||||
assert min_val < max_val, "Min value should be less than max value. Got min value: {}, " \
|
||||
"max value: {}".format(min_val, max_val)
|
||||
|
||||
return min_val, max_val
|
||||
if '?' in dim:
|
||||
return np.int64(-1)
|
||||
return np.int64(dim)
|
||||
|
||||
|
||||
def get_placeholder_shapes(argv_input: str, argv_input_shape: str, argv_batch=None):
|
||||
"""
|
||||
Parses input layers names and input shapes from the cli and returns the parsed object.
|
||||
@ -1055,16 +1089,18 @@ def get_placeholder_shapes(argv_input: str, argv_input_shape: str, argv_batch=No
|
||||
E.g. 'inp1,inp2', 'node_name1[shape1]->value1,node_name2[shape2]->value2'
|
||||
argv_input_shape
|
||||
string with a list of input shapes: either an empty string, or tuples separated with comma.
|
||||
E.g. '(1,2),(3,4)'.
|
||||
Only positive integers are accepted except -1, which can be on any position in a shape.
|
||||
E.g. '[1,2],[3,4]'.
|
||||
Only positive integers are accepted.
|
||||
'?' marks dynamic dimension.
|
||||
Partial shape is specified with ellipsis. E.g. '[1..10,2,3]'
|
||||
argv_batch
|
||||
integer that overrides batch size in input shape
|
||||
|
||||
Returns
|
||||
-------
|
||||
parsed shapes in form of {'name of input':ndarray} if names of inputs are provided with shapes
|
||||
parsed shapes in form of {'name of input':tuple} if names of inputs are provided with shapes
|
||||
parsed shapes in form of {'name of input':None} if names of inputs are provided without shapes
|
||||
ndarray if only one shape is provided and no input name
|
||||
tuple if only one shape is provided and no input name
|
||||
None if neither shape nor input were provided
|
||||
"""
|
||||
if argv_input_shape and argv_batch:
|
||||
@ -1099,7 +1135,8 @@ def get_placeholder_shapes(argv_input: str, argv_input_shape: str, argv_batch=No
|
||||
inputs = list()
|
||||
placeholder_shapes = None
|
||||
|
||||
first_digit_reg = r'([0-9 ]+|-1)'
|
||||
range_reg = r'([0-9]*\.\.[0-9]*)'
|
||||
first_digit_reg = r'([0-9 ]+|-1|\?|{})'.format(range_reg)
|
||||
next_digits_reg = r'(,{})*'.format(first_digit_reg)
|
||||
tuple_reg = r'((\({}{}\))|(\[{}{}\]))'.format(first_digit_reg, next_digits_reg,
|
||||
first_digit_reg, next_digits_reg)
|
||||
@ -1107,7 +1144,7 @@ def get_placeholder_shapes(argv_input: str, argv_input_shape: str, argv_batch=No
|
||||
full_reg = r'^{}(\s*,\s*{})*$|^$'.format(tuple_reg, tuple_reg)
|
||||
if not re.match(full_reg, argv_input_shape):
|
||||
raise Error('Input shape "{}" cannot be parsed. ' + refer_to_faq_msg(57), argv_input_shape)
|
||||
shapes = re.findall(r'[(\[]([0-9, -]+)[)\]]', argv_input_shape)
|
||||
shapes = re.findall(r'[(\[]([0-9,\.\? -]+)[)\]]', argv_input_shape)
|
||||
|
||||
if argv_input:
|
||||
inputs = argv_input.split(',')
|
||||
@ -1118,14 +1155,22 @@ def get_placeholder_shapes(argv_input: str, argv_input_shape: str, argv_batch=No
|
||||
if len(shapes) > 1:
|
||||
raise Error('Please provide input layer names for input layer shapes. ' + refer_to_faq_msg(58))
|
||||
else:
|
||||
placeholder_shapes = np.fromstring(shapes[0], dtype=np.int64, sep=',')
|
||||
placeholder_shapes = tuple(map(parse_dimension, shapes[0].split(',')))
|
||||
# check if number of shapes does not match number of passed inputs
|
||||
elif argv_input and (len(shapes) == len(inputs) or len(shapes) == 0):
|
||||
# clean inputs from values for freezing
|
||||
inputs = list(map(lambda x: x.split('->')[0], inputs))
|
||||
placeholder_shapes = dict(zip_longest(inputs,
|
||||
map(lambda x: np.fromstring(x, dtype=np.int64,
|
||||
sep=',') if x else None, shapes)))
|
||||
inputs_without_value = list(map(lambda x: x.split('->')[0], inputs))
|
||||
placeholder_shapes = dict(zip_longest(inputs_without_value,
|
||||
map(lambda x: tuple(map(parse_dimension, x.split(','))) if x else None,
|
||||
shapes)))
|
||||
for inp in inputs:
|
||||
if '->' not in inp:
|
||||
continue
|
||||
shape = placeholder_shapes[inp.split('->')[0]]
|
||||
for dim in shape:
|
||||
if isinstance(dim, tuple) or dim == -1:
|
||||
raise Error("Cannot freeze input with dynamic shape: {}".format(shape))
|
||||
|
||||
elif argv_input:
|
||||
raise Error('Please provide each input layers with an input layer shape. ' + refer_to_faq_msg(58))
|
||||
|
||||
|
@ -18,5 +18,7 @@ class Parameter_extender(Extender):
|
||||
op.shape = int64_array([])
|
||||
else:
|
||||
Extender.attr_to_list(op, 'shape')
|
||||
if -1 in op.shape:
|
||||
op.shape = shape_array([d if d != -1 else dynamic_dimension_value for d in op.shape])
|
||||
for i, dim in enumerate(op.shape):
|
||||
if dim == -1 or (isinstance(dim, str) and ".." in dim):
|
||||
op.shape[i] = -1
|
||||
op.shape = shape_array([d if d != -1 else dynamic_dimension_value for d in op.shape])
|
||||
|
@ -537,11 +537,11 @@ class TestUserDataRepack(unittest.TestCase):
|
||||
|
||||
def test_error(self):
|
||||
graph = build_graph(self.nodes, self.edges)
|
||||
self.assertRaises(Error, input_user_data_repack, graph, np.array([1, 227, 227, 3]), None)
|
||||
self.assertRaises(Error, input_user_data_repack, graph, tuple([1, 227, 227, 3]), None)
|
||||
|
||||
def test_error_2(self):
|
||||
graph = build_graph(self.nodes, self.edges)
|
||||
self.assertRaises(Error, input_user_data_repack, graph, np.array([1, 227, 227, 3]), None)
|
||||
self.assertRaises(Error, input_user_data_repack, graph, tuple([1, 227, 227, 3]), None)
|
||||
|
||||
def test_error_3(self):
|
||||
graph = build_graph(self.nodes, self.edges)
|
||||
@ -549,7 +549,7 @@ class TestUserDataRepack(unittest.TestCase):
|
||||
|
||||
def test_input_and_freeze(self):
|
||||
graph = build_graph(self.nodes, self.edges)
|
||||
shape_1 = np.array([1, 160, 160, 3])
|
||||
shape_1 = tuple([1, 160, 160, 3])
|
||||
input, freeze_placeholder = input_user_data_repack(graph, shape_1, {'Bb': True})
|
||||
self.assertDictEqual(input, {'A': [{'shape': shape_1, 'port': None}], 'B': [{'shape': None, 'port': None}]})
|
||||
self.assertDictEqual(freeze_placeholder, {'B': True})
|
||||
|
@ -778,6 +778,153 @@ class TestShapesParsing(unittest.TestCase):
|
||||
input_shapes = "(12,4,1),(4,-6,8)"
|
||||
self.assertRaises(Error, get_placeholder_shapes, argv_input, input_shapes)
|
||||
|
||||
def test_get_shapes_several_inputs_several_partial_shapes(self):
|
||||
argv_input = "inp1,inp2"
|
||||
input_shapes = "(1,..22,1..100,?), (-1,45..,7,1)"
|
||||
result, _ = get_placeholder_shapes(argv_input, input_shapes)
|
||||
exp_res = {'inp1': (1, (0, 22), (1, 100), -1), 'inp2': (-1, (45, np.iinfo(np.int64).max), 7, 1)}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_get_shapes_several_inputs_several_partial_shapes2(self):
|
||||
# shapes specified using --input command line parameter and no values
|
||||
argv_input = "inp1[1 ? 50..100 123],inp2[-1 45.. ..7 1]"
|
||||
result, _ = get_placeholder_shapes(argv_input, None)
|
||||
exp_res = {'inp1': (1, -1, (50, 100), 123), 'inp2': (-1, (45,np.iinfo(np.int64).max), (0, 7), 1)}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None)
|
||||
placeholder_values_ref = {}
|
||||
input_node_names_ref = "inp1,inp2"
|
||||
self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
|
||||
for i in placeholder_values_ref.keys():
|
||||
npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i])
|
||||
|
||||
def test_get_shapes_several_inputs_several_partial_shapes3(self):
|
||||
# shapes and value for freezing specified using --input command line parameter
|
||||
argv_input = "inp1[3 1]->[1.0 2.0 3.0],inp2[3.. ..2 5..10 ? -1],inp3[5]->[1.0 1.0 2.0 3.0 5.0]"
|
||||
result, _ = get_placeholder_shapes(argv_input, None)
|
||||
exp_res = {'inp1': (3, 1), 'inp2': ((3, np.iinfo(np.int64).max), (0, 2), (5, 10), -1, -1), 'inp3': (5)}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None)
|
||||
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'])}
|
||||
input_node_names_ref = "inp1,inp2,inp3"
|
||||
self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
|
||||
for i in placeholder_values_ref.keys():
|
||||
npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i])
|
||||
|
||||
def test_get_shapes_several_inputs_several_partial_shapes4(self):
|
||||
# shapes specified using --input_shape and values for freezing using --input command line parameter
|
||||
argv_input = "inp1->[1.0 2.0 3.0],inp2,inp3->[1.0 1.0 2.0 3.0 5.0]"
|
||||
input_shapes = "(3,1), (3..,..2,5..10,?,-1), (5)"
|
||||
result, _ = get_placeholder_shapes(argv_input, input_shapes)
|
||||
exp_res = {'inp1': (3, 1), 'inp2': ((3, np.iinfo(np.int64).max), (0, 2), (5, 10), -1, -1), 'inp3': (5)}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None)
|
||||
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'])}
|
||||
input_node_names_ref = "inp1,inp2,inp3"
|
||||
self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
|
||||
for i in placeholder_values_ref.keys():
|
||||
npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i])
|
||||
self.assertEqual(input_node_names_ref, input_node_names_res)
|
||||
|
||||
def test_get_shapes_several_inputs_several_partial_shapes5(self):
|
||||
# some values for freezing specified using --freeze_placeholder_with_value
|
||||
argv_input = "inp1->[1.0 2.0 3.0],inp2,inp3->[1.0 1.0 2.0 3.0 5.0]"
|
||||
input_shapes = "(3,1), (3..,..2,5..10,?,-1), (5)"
|
||||
argv_freeze_placeholder_with_value = "inp2->[5.0 7.0 3.0],inp4->[100.0 200.0]"
|
||||
|
||||
result, _ = get_placeholder_shapes(argv_input, input_shapes)
|
||||
exp_res = {'inp1': (3, 1), 'inp2': ((3, np.iinfo(np.int64).max), (0, 2), (5, 10), -1, -1), 'inp3': (5)}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, argv_freeze_placeholder_with_value)
|
||||
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'],),
|
||||
'inp2': np.array(['5.0', '7.0', '3.0']), 'inp4': np.array(['100.0', '200.0'])}
|
||||
input_node_names_ref = "inp1,inp2,inp3"
|
||||
self.assertEqual(sorted(list(placeholder_values_res.keys())), sorted(list(placeholder_values_ref.keys())))
|
||||
for i in placeholder_values_ref.keys():
|
||||
npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i])
|
||||
self.assertEqual(input_node_names_ref, input_node_names_res)
|
||||
|
||||
def test_get_shapes_several_inputs_several_partial_shapes6(self):
|
||||
# 0D value for freezing specified using --input command line parameter without shape
|
||||
argv_input = "inp1[3 1]->[1.0 2.0 3.0],inp2[3.. ..2 5..10 ? -1],inp3->False"
|
||||
result, _ = get_placeholder_shapes(argv_input, None)
|
||||
exp_res = {'inp1': (3, 1), 'inp2': ((3, np.iinfo(np.int64).max), (0, 2), (5, 10), -1, -1), 'inp3': np.array(False).shape}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None)
|
||||
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': False}
|
||||
self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
|
||||
for i in placeholder_values_ref.keys():
|
||||
npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i])
|
||||
|
||||
def test_get_shapes_several_inputs_several_partial_shapes7(self):
|
||||
# 0D shape and value for freezing specified using --input command line parameter
|
||||
argv_input = "inp1[3 1]->[1.0 2.0 3.0],inp2[3.. ..2 5..10 ? -1],inp3[]->True"
|
||||
result, _ = get_placeholder_shapes(argv_input, None)
|
||||
exp_res = {'inp1': (3, 1), 'inp2': ((3, np.iinfo(np.int64).max), (0, 2), (5, 10), -1, -1), 'inp3': np.array(False).shape}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None)
|
||||
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': True}
|
||||
self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
|
||||
for i in placeholder_values_ref.keys():
|
||||
npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i])
|
||||
|
||||
def test_get_shapes_and_data_types_partial_shape_with_input_port(self):
|
||||
argv_input = "inp1:1[3 1]->[1.0 2.0 3.0],0:inp2[3.. ..2 5..10 ? -1]{i32},inp3:4[5]{f32}->[1.0 1.0 2.0 3.0 5.0]"
|
||||
result_shapes, result_data_types = get_placeholder_shapes(argv_input, "")
|
||||
ref_result_shapes = {'inp1:1': np.array([3, 1]), '0:inp2': ((3, np.iinfo(np.int64).max), (0, 2), (5, 10), -1, -1), 'inp3:4': np.array([5])}
|
||||
ref_result_data_types = {'0:inp2': np.int32, 'inp3:4': np.float32}
|
||||
self.assertEqual(list(ref_result_shapes.keys()), list(result_shapes.keys()))
|
||||
for i in ref_result_shapes.keys():
|
||||
npt.assert_array_equal(result_shapes[i], ref_result_shapes[i])
|
||||
self.assertEqual(list(ref_result_data_types.keys()), list(result_data_types.keys()))
|
||||
for i in ref_result_data_types.keys():
|
||||
np.testing.assert_equal(result_data_types[i], ref_result_data_types[i])
|
||||
|
||||
def test_get_shapes_and_data_types_partial_shape_with_output_port(self):
|
||||
argv_input = "inp1:1[3 1]->[1.0 2.0 3.0],inp2:3[3.. ..2 5..10 ? -1]{i32},inp3:4[5]{f32}->[1.0 1.0 2.0 3.0 5.0]"
|
||||
result_shapes, result_data_types = get_placeholder_shapes(argv_input, "")
|
||||
ref_result_shapes = {'inp1:1': np.array([3, 1]), 'inp2:3': ((3, np.iinfo(np.int64).max), (0, 2), (5, 10), -1, -1), 'inp3:4': np.array([5])}
|
||||
ref_result_data_types = {'inp2:3': np.int32, 'inp3:4': np.float32}
|
||||
self.assertEqual(list(ref_result_shapes.keys()), list(result_shapes.keys()))
|
||||
for i in ref_result_shapes.keys():
|
||||
npt.assert_array_equal(result_shapes[i], ref_result_shapes[i])
|
||||
self.assertEqual(list(ref_result_data_types.keys()), list(result_data_types.keys()))
|
||||
for i in ref_result_data_types.keys():
|
||||
np.testing.assert_equal(result_data_types[i], ref_result_data_types[i])
|
||||
|
||||
def test_partial_shapes_negative_case(self):
|
||||
argv_input = "inp1"
|
||||
input_shapes = "[6754fg..23ed]"
|
||||
self.assertRaises(Error, get_placeholder_shapes, argv_input, input_shapes)
|
||||
|
||||
def test_partial_shapes_freeze_dynamic_negative_case1(self):
|
||||
argv_input = "inp1:1[3 1..10]->[1.0 2.0 3.0]"
|
||||
self.assertRaises(Error, get_placeholder_shapes, argv_input, "")
|
||||
|
||||
def test_partial_shapes_freeze_dynamic_negative_case2(self):
|
||||
argv_input = "inp1:1[1 2 -1]->[1.0 2.0 3.0]"
|
||||
self.assertRaises(Error, get_placeholder_shapes, argv_input, "")
|
||||
|
||||
def test_partial_shapes_freeze_dynamic_negative_case3(self):
|
||||
# some values for freezing specified using --freeze_placeholder_with_value
|
||||
argv_input = "inp1->[1.0 2.0 3.0]"
|
||||
input_shapes = "[3,1..10]"
|
||||
self.assertRaises(Error, get_placeholder_shapes, argv_input, input_shapes)
|
||||
|
||||
|
||||
class TestModelNameParsing(unittest.TestCase):
|
||||
def test_model_name_ideal(self):
|
||||
|
Loading…
Reference in New Issue
Block a user