Partial shape serialization MO (#8927)

* Support of partial shape in MO.

* Added partial shapes serialization in MO.

* Removed breckets.

* Fixed parameter extender in IR reader.

* Removed wrong changes.

* Added checks, added tests.

* Updated help.

* Fixed import.
This commit is contained in:
Anastasia Popova 2021-12-14 17:59:19 +03:00 committed by GitHub
parent eab49eec8b
commit 3cb27715e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 279 additions and 43 deletions

View File

@ -605,7 +605,7 @@ def input_user_data_repack(graph: Graph, input_user_shapes: [None, list, dict, n
if ph_id not in _input_shapes]
else:
# np.ndarray is a shape. User provided only --input_shape key
assert isinstance(input_user_shapes, np.ndarray)
assert isinstance(input_user_shapes, tuple)
if len(placeholders_ids) == 1:
# There is only one placeholder in the original network
_input_shapes[placeholders_ids[0]].append({'shape': input_user_shapes, 'port': None})
@ -861,7 +861,7 @@ def add_input_op_output_port_with_data(graph: Graph, node_id: str, input_op, por
def add_input_op(graph: Graph, node_id: str, port: int = 0, data: bool = False,
shape=None, data_type=None, is_out_port: bool = False):
shape=None, user_shape=None, data_type=None, is_out_port: bool = False):
"""
This function adds Input node to node with id==node_id to specified port (in or out defined with is_out_port).
:param graph: graph to operate on.
@ -869,6 +869,8 @@ def add_input_op(graph: Graph, node_id: str, port: int = 0, data: bool = False,
:param port: number of port of node_id node for adding input node.
:param data: flag that define whether data nodes is needed or not.
:param shape: shape for new input node.
:param user_shape: shape provided by user which may contain boundaries of dynamic dimension.
:param data_type: data type of input node.
:param is_out_port: flag that define whether port is output port or not.
:return: id of new Input operation
"""
@ -876,7 +878,7 @@ def add_input_op(graph: Graph, node_id: str, port: int = 0, data: bool = False,
from openvino.tools.mo.ops.parameter import Parameter
if data_type is None:
data_type = np.float32
input_op = Parameter(graph, dict(shape=shape, data_type=data_type, initial_node_name=node_id,
input_op = Parameter(graph, dict(shape=shape, user_shape=user_shape, data_type=data_type, initial_node_name=node_id,
name=get_new_placeholder_name(node_id, is_out_port, port)))
fw_name = Node(graph, node_id).soft_get('name')
@ -901,7 +903,7 @@ def add_input_op(graph: Graph, node_id: str, port: int = 0, data: bool = False,
def add_input_ops_helper_before_infer_input_port(graph: Graph, smart_node: Node, port: int, node_id: str,
shape: np.array, data_type,
shape: np.array, user_shape: tuple, data_type,
inputs: list, edges_to_remove: list):
n_inputs = len(smart_node.in_nodes())
if n_inputs > 1 and port is None:
@ -912,7 +914,7 @@ def add_input_ops_helper_before_infer_input_port(graph: Graph, smart_node: Node,
port = port if port is not None else 0
edges_to_remove.append((smart_node.in_node(port).id, smart_node.id))
inputs.append(add_input_op(graph=graph, node_id=node_id, port=port, data=False,
shape=shape, data_type=data_type))
shape=shape, user_shape=user_shape, data_type=data_type))
def add_input_ops_helper_after_infer_input_port(graph: Graph, smart_node: Node, port:int, node_id: str,
@ -934,13 +936,14 @@ def add_input_ops_helper_after_infer_input_port(graph: Graph, smart_node: Node,
edges_to_remove.append((in_node.id, node_id))
def add_input_ops_helper_before_infer_output_port(graph: Graph, port:int, node_id: str,
shape: np.array, data_type, inputs: list, edges_to_remove: list):
def add_input_ops_helper_before_infer_output_port(graph: Graph, port: int, node_id: str,
shape: np.array, user_shape: tuple, data_type: tuple,
inputs: list, edges_to_remove: list):
for u, v, edge_attrs in graph.out_edges(node_id, data=True):
if edge_attrs['out'] == port:
edges_to_remove.append((u, v)) # we need to remove all edges from this port
inputs.append(add_input_op(graph=graph, node_id=node_id, port=port, data=False,
shape=shape, data_type=data_type, is_out_port=True))
shape=shape, user_shape=user_shape, data_type=data_type, is_out_port=True))
def add_input_ops_helper_after_infer_output_port(graph: Graph, smart_node: Node, port:int, node_id: str,
@ -987,8 +990,11 @@ def add_input_ops(graph: Graph, user_defined_inputs: dict, before_infer: bool):
is_out_port = 'out' in port_and_shape_info # by default we assume input port or input node without port
shape = port_and_shape_info['shape'] if 'shape' in port_and_shape_info else None
user_shape = None
if shape is not None:
shape = shape_array([dim if dim >= 0 else dynamic_dimension_value for dim in shape])
user_shape = shape
shape = shape_array(
[dim if type(dim) != tuple and dim >= 0 else dynamic_dimension_value for dim in shape])
data_type = port_and_shape_info['data_type'] if 'data_type' in port_and_shape_info else None
smart_node = Node(graph, node_id)
@ -1016,6 +1022,7 @@ def add_input_ops(graph: Graph, user_defined_inputs: dict, before_infer: bool):
refer_to_faq_msg(28), node_id, port)
if shape is not None:
smart_node['shape'] = shape
smart_node['user_shape'] = user_shape
if data_type is not None:
smart_node['data_type'] = data_type
inputs.append(node_id)
@ -1027,10 +1034,11 @@ def add_input_ops(graph: Graph, user_defined_inputs: dict, before_infer: bool):
continue
# We cut with shapes provided by user and there is no need to wait till infer
if is_out_port:
add_input_ops_helper_before_infer_output_port(graph, port, node_id, shape, data_type, inputs,
edges_to_remove)
add_input_ops_helper_before_infer_output_port(graph, port, node_id, shape, user_shape,
data_type, inputs, edges_to_remove)
else:
add_input_ops_helper_before_infer_input_port(graph, smart_node, port, node_id, shape, data_type, inputs,
add_input_ops_helper_before_infer_input_port(graph, smart_node, port, node_id, shape,
user_shape, data_type, inputs,
edges_to_remove)
else:
# We cut after infer and we need inferred shapes in nodes

View File

@ -2,10 +2,12 @@
# SPDX-License-Identifier: Apache-2.0
import logging as log
from openvino.tools.mo.front.common.partial_infer.utils import is_fully_defined, unmask_shape, shape_array, dynamic_dimension_value
from openvino.tools.mo.front.common.partial_infer.utils import is_fully_defined, shape_array, \
dynamic_dimension_value
from openvino.tools.mo.graph.graph import Graph
from openvino.tools.mo.middle.passes.infer import partial_infer
from openvino.tools.mo.middle.replacement import MiddleReplacementPattern
from openvino.tools.mo.ops.parameter import Parameter
class PartialInfer(MiddleReplacementPattern):
@ -25,7 +27,7 @@ class PartialInfer(MiddleReplacementPattern):
param_shape = parameter.soft_get('shape', shape_array(dynamic_dimension_value))
if not is_fully_defined(param_shape):
parameter_name = parameter.soft_get('name', parameter.id)
dynamic_inputs[parameter_name] = param_shape
dynamic_inputs[parameter_name] = parameter
if dynamic_inputs:
log.error('The model contains input(s) with partially defined shapes: {}. '
'Starting from the 2022.1 release the Model Optimizer can generate an IR with partially defined '
@ -34,7 +36,7 @@ class PartialInfer(MiddleReplacementPattern):
'call "reshape" method in the Inference Engine and specify static input shapes. For optimal '
'performance, it is still recommended to update input shapes with fixed ones using "--input" or '
'"--input_shape" command-line parameters.'
.format(','.join('name="{}" shape="{}"'.format(name, unmask_shape(shape))
for name, shape in dynamic_inputs.items())),
.format(','.join('name="{}" shape="{}"'.format(name, Parameter.shape_serialize(parameter))
for name, parameter in dynamic_inputs.items())),
extra={'is_warning': True})
partial_infer(graph)

View File

@ -118,7 +118,7 @@ def fe_input_user_data_repack(input_model: InputModel, input_user_shapes: [None,
_input_shapes.append({'node': node, 'shape': shape, 'data_type': data_type})
else:
_input_shapes.append({'node': node, 'shape': shape})
elif isinstance(input_user_shapes, np.ndarray):
elif isinstance(input_user_shapes, tuple):
model_inputs = input_model.get_inputs()
assert len(model_inputs) == 1
_input_shapes.append({'node': model_inputs[0], 'shape': input_user_shapes})

View File

@ -12,6 +12,8 @@ from openvino.runtime import Dimension, PartialShape # pylint: disable=no
from openvino.frontend import FrontEnd, Place # pylint: disable=no-name-in-module,import-error
from openvino.runtime.utils.types import get_element_type # pylint: disable=no-name-in-module,import-error
import numpy as np
def moc_pipeline(argv: argparse.Namespace, moc_front_end: FrontEnd):
"""
@ -66,7 +68,7 @@ def moc_pipeline(argv: argparse.Namespace, moc_front_end: FrontEnd):
for user_shape in user_shapes:
if user_shape.get('shape') is not None:
input_model.set_partial_shape(
user_shape['node'], PartialShape(user_shape['shape']))
user_shape['node'], partial_shape_from_tuple(user_shape['shape']))
if user_shape.get('data_type') is not None:
data_type = get_element_type(user_shape['data_type'])
log.debug('Set data type: {}'.format(data_type))
@ -97,3 +99,16 @@ def moc_pipeline(argv: argparse.Namespace, moc_front_end: FrontEnd):
ngraph_function = moc_front_end.convert(input_model)
return ngraph_function
def partial_shape_from_tuple(shape: tuple):
new_shape = []
for dim in shape:
if isinstance(dim, tuple):
assert len(dim) == 2, "Incorrect boundaries of dimension {} in shape {}".format(dim, shape)
assert dim[0] >= 0, "Incorrect min value of dimension {} in shape".format(dim, shape)
new_shape.append(Dimension(dim[0], dim[1]))
else:
assert isinstance(dim, np.int64), "Incorrect type of dimension {} in shape".format(dim, shape)
new_shape.append(Dimension(dim))
return PartialShape(new_shape)

View File

@ -26,6 +26,7 @@ class Parameter(Op):
'type_infer': self.type_infer,
'out_ports_count': 1,
'user_shape': None,
}
if 'data_type' not in attrs:
mandatory_props['data_type'] = np.float32
@ -35,9 +36,25 @@ class Parameter(Op):
def type_infer(node):
node.out_port(0).set_data_type(node.data_type)
@staticmethod
def shape_serialize(node):
def serialize_dimension(dim: [tuple, np.int64]):
if type(dim) == tuple:
assert len(dim) == 2, "Unable to serialize shape {} in node {}".format(node.soft_get('user_shape'),
node.soft_get('name', node.id))
min_str = str(dim[0]) if dim[0] > 0 else ""
max_str = str(dim[1]) if dim[1] < np.iinfo(np.int64).max else ""
return min_str + ".." + max_str
return str(dim)
if not node.has_valid('user_shape'):
return ','.join([str(i) for i in unmask_shape(node.shape)])
shape = node.soft_get('user_shape')
return ','.join(map(serialize_dimension, shape))
def supported_attrs(self):
return [
('shape', lambda node: ','.join([str(i) for i in unmask_shape(node.shape)])),
('shape', lambda node: self.shape_serialize(node)),
('element_type', lambda node: np_data_type_to_destination_type(node.data_type)),
]

View File

@ -257,10 +257,12 @@ def get_common_cli_parser(parser: argparse.ArgumentParser = None):
'models. Model Optimizer performs necessary transformations to convert the shape to '
'the layout required by Inference Engine (N,C,H,W). The shape could contain '
'undefined dimensions (-1) and should fit the dimensions defined in the input '
'operation of the graph. If there are multiple inputs in the model, --input_shape '
'should contain definition of shape for each input separated by a comma, for '
'example: [1,3,227,227],[2,4] for a model with two inputs with 4D and 2D shapes. '
'Alternatively, specify shapes with the --input option.')
'operation of the graph. Boundaries of undefined dimension can be specified with '
'ellipsis, for example [1,1..10,128,128]. One boundary can be undefined, for '
'example [1,..100] or [1,3,1..,1..]. If there are multiple inputs in the model, '
'--input_shape should contain definition of shape for each input separated by a '
'comma, for example: [1,3,227,227],[2,4] for a model with two inputs with 4D and 2D '
'shapes. Alternatively, specify shapes with the --input option.')
common_group.add_argument('--scale', '-s',
type=float,
help='All input values coming from original network inputs will be ' +
@ -778,12 +780,12 @@ def remove_shape_from_input_value(input_value: str):
:return: string without shape specification
"""
assert '->' not in input_value, 'The function should not be called for input_value with constant value specified'
return re.sub(r'[(\[]([0-9 -]*)[)\]]', '', input_value)
return re.sub(r'[(\[]([0-9\.? -]*)[)\]]', '', input_value)
def get_shape_from_input_value(input_value: str):
"""
Returns the numpy array with shape corresponding to the shape specified in the input value string
Returns the list of tuples corresponding to the shape specified in the input value string
:param input_value: string passed as input to the --input command line parameter
:return: the corresponding shape and None if the shape is not specified in the input value
"""
@ -791,11 +793,11 @@ def get_shape_from_input_value(input_value: str):
input_value = input_value.split('->')[0]
# parse shape
shape = re.findall(r'[(\[]([0-9 -]*)[)\]]', input_value)
shape = re.findall(r'[(\[]([0-9\.\? -]+)[)\]]', input_value)
if len(shape) == 0:
shape = None
elif len(shape) == 1:
shape = np.fromstring(shape[0], dtype=np.int64, sep=' ')
shape = tuple(map(parse_dimension, shape[0].split(' ')))
else:
raise Error("Wrong syntax to specify shape. Use --input "
"\"node_name[shape]->value\"")
@ -854,6 +856,11 @@ def parse_input_value(input_value: str):
shape = get_shape_from_input_value(input_value.split('->')[0])
value_size = np.prod(len(value)) if isinstance(value, list) else 1
if value is not None and shape is not None:
for dim in shape:
if isinstance(dim, tuple) or dim == -1:
raise Error("Cannot freeze input with dynamic shape: {}".format(shape))
if shape is not None and value is not None and np.prod(shape) != value_size:
raise Error("The shape '{}' of the input node '{}' does not correspond to the number of elements '{}' in the "
"value: {}".format(shape, node_name, value_size, value))
@ -1043,6 +1050,33 @@ def get_freeze_placeholder_values(argv_input: str, argv_freeze_placeholder_with_
return placeholder_values, input_node_names
def parse_dimension(dim: str):
if '..' in dim:
numbers_reg = r'^[0-9]+$'
dims = dim.split('..')
match_res0 = re.match(numbers_reg, dims[0])
match_res1 = re.match(numbers_reg, dims[1])
if len(dims[0].strip()) > 0 and match_res0 is None:
Error("Incorrect min value of dimension '{}'".format(dims[0]))
if len(dims[1].strip()) > 0 and match_res1 is None:
Error("Incorrect max value of dimension '{}'".format(dims[1]))
min_val = np.int64(dims[0]) if match_res0 else np.int64(0)
max_val = np.int64(dims[1]) if match_res1 else np.iinfo(np.int64).max
assert min_val >= 0, "Incorrect min value of the dimension {}".format(dim)
if min_val == np.int64(0) and max_val == np.iinfo(np.int64).max:
return np.int64(-1)
assert min_val < max_val, "Min value should be less than max value. Got min value: {}, " \
"max value: {}".format(min_val, max_val)
return min_val, max_val
if '?' in dim:
return np.int64(-1)
return np.int64(dim)
def get_placeholder_shapes(argv_input: str, argv_input_shape: str, argv_batch=None):
"""
Parses input layers names and input shapes from the cli and returns the parsed object.
@ -1055,16 +1089,18 @@ def get_placeholder_shapes(argv_input: str, argv_input_shape: str, argv_batch=No
E.g. 'inp1,inp2', 'node_name1[shape1]->value1,node_name2[shape2]->value2'
argv_input_shape
string with a list of input shapes: either an empty string, or tuples separated with comma.
E.g. '(1,2),(3,4)'.
Only positive integers are accepted except -1, which can be on any position in a shape.
E.g. '[1,2],[3,4]'.
Only positive integers are accepted.
'?' marks dynamic dimension.
Partial shape is specified with ellipsis. E.g. '[1..10,2,3]'
argv_batch
integer that overrides batch size in input shape
Returns
-------
parsed shapes in form of {'name of input':ndarray} if names of inputs are provided with shapes
parsed shapes in form of {'name of input':tuple} if names of inputs are provided with shapes
parsed shapes in form of {'name of input':None} if names of inputs are provided without shapes
ndarray if only one shape is provided and no input name
tuple if only one shape is provided and no input name
None if neither shape nor input were provided
"""
if argv_input_shape and argv_batch:
@ -1099,7 +1135,8 @@ def get_placeholder_shapes(argv_input: str, argv_input_shape: str, argv_batch=No
inputs = list()
placeholder_shapes = None
first_digit_reg = r'([0-9 ]+|-1)'
range_reg = r'([0-9]*\.\.[0-9]*)'
first_digit_reg = r'([0-9 ]+|-1|\?|{})'.format(range_reg)
next_digits_reg = r'(,{})*'.format(first_digit_reg)
tuple_reg = r'((\({}{}\))|(\[{}{}\]))'.format(first_digit_reg, next_digits_reg,
first_digit_reg, next_digits_reg)
@ -1107,7 +1144,7 @@ def get_placeholder_shapes(argv_input: str, argv_input_shape: str, argv_batch=No
full_reg = r'^{}(\s*,\s*{})*$|^$'.format(tuple_reg, tuple_reg)
if not re.match(full_reg, argv_input_shape):
raise Error('Input shape "{}" cannot be parsed. ' + refer_to_faq_msg(57), argv_input_shape)
shapes = re.findall(r'[(\[]([0-9, -]+)[)\]]', argv_input_shape)
shapes = re.findall(r'[(\[]([0-9,\.\? -]+)[)\]]', argv_input_shape)
if argv_input:
inputs = argv_input.split(',')
@ -1118,14 +1155,22 @@ def get_placeholder_shapes(argv_input: str, argv_input_shape: str, argv_batch=No
if len(shapes) > 1:
raise Error('Please provide input layer names for input layer shapes. ' + refer_to_faq_msg(58))
else:
placeholder_shapes = np.fromstring(shapes[0], dtype=np.int64, sep=',')
placeholder_shapes = tuple(map(parse_dimension, shapes[0].split(',')))
# check if number of shapes does not match number of passed inputs
elif argv_input and (len(shapes) == len(inputs) or len(shapes) == 0):
# clean inputs from values for freezing
inputs = list(map(lambda x: x.split('->')[0], inputs))
placeholder_shapes = dict(zip_longest(inputs,
map(lambda x: np.fromstring(x, dtype=np.int64,
sep=',') if x else None, shapes)))
inputs_without_value = list(map(lambda x: x.split('->')[0], inputs))
placeholder_shapes = dict(zip_longest(inputs_without_value,
map(lambda x: tuple(map(parse_dimension, x.split(','))) if x else None,
shapes)))
for inp in inputs:
if '->' not in inp:
continue
shape = placeholder_shapes[inp.split('->')[0]]
for dim in shape:
if isinstance(dim, tuple) or dim == -1:
raise Error("Cannot freeze input with dynamic shape: {}".format(shape))
elif argv_input:
raise Error('Please provide each input layers with an input layer shape. ' + refer_to_faq_msg(58))

View File

@ -18,5 +18,7 @@ class Parameter_extender(Extender):
op.shape = int64_array([])
else:
Extender.attr_to_list(op, 'shape')
if -1 in op.shape:
op.shape = shape_array([d if d != -1 else dynamic_dimension_value for d in op.shape])
for i, dim in enumerate(op.shape):
if dim == -1 or (isinstance(dim, str) and ".." in dim):
op.shape[i] = -1
op.shape = shape_array([d if d != -1 else dynamic_dimension_value for d in op.shape])

View File

@ -537,11 +537,11 @@ class TestUserDataRepack(unittest.TestCase):
def test_error(self):
graph = build_graph(self.nodes, self.edges)
self.assertRaises(Error, input_user_data_repack, graph, np.array([1, 227, 227, 3]), None)
self.assertRaises(Error, input_user_data_repack, graph, tuple([1, 227, 227, 3]), None)
def test_error_2(self):
graph = build_graph(self.nodes, self.edges)
self.assertRaises(Error, input_user_data_repack, graph, np.array([1, 227, 227, 3]), None)
self.assertRaises(Error, input_user_data_repack, graph, tuple([1, 227, 227, 3]), None)
def test_error_3(self):
graph = build_graph(self.nodes, self.edges)
@ -549,7 +549,7 @@ class TestUserDataRepack(unittest.TestCase):
def test_input_and_freeze(self):
graph = build_graph(self.nodes, self.edges)
shape_1 = np.array([1, 160, 160, 3])
shape_1 = tuple([1, 160, 160, 3])
input, freeze_placeholder = input_user_data_repack(graph, shape_1, {'Bb': True})
self.assertDictEqual(input, {'A': [{'shape': shape_1, 'port': None}], 'B': [{'shape': None, 'port': None}]})
self.assertDictEqual(freeze_placeholder, {'B': True})

View File

@ -778,6 +778,153 @@ class TestShapesParsing(unittest.TestCase):
input_shapes = "(12,4,1),(4,-6,8)"
self.assertRaises(Error, get_placeholder_shapes, argv_input, input_shapes)
def test_get_shapes_several_inputs_several_partial_shapes(self):
argv_input = "inp1,inp2"
input_shapes = "(1,..22,1..100,?), (-1,45..,7,1)"
result, _ = get_placeholder_shapes(argv_input, input_shapes)
exp_res = {'inp1': (1, (0, 22), (1, 100), -1), 'inp2': (-1, (45, np.iinfo(np.int64).max), 7, 1)}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
def test_get_shapes_several_inputs_several_partial_shapes2(self):
# shapes specified using --input command line parameter and no values
argv_input = "inp1[1 ? 50..100 123],inp2[-1 45.. ..7 1]"
result, _ = get_placeholder_shapes(argv_input, None)
exp_res = {'inp1': (1, -1, (50, 100), 123), 'inp2': (-1, (45,np.iinfo(np.int64).max), (0, 7), 1)}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None)
placeholder_values_ref = {}
input_node_names_ref = "inp1,inp2"
self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
for i in placeholder_values_ref.keys():
npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i])
def test_get_shapes_several_inputs_several_partial_shapes3(self):
# shapes and value for freezing specified using --input command line parameter
argv_input = "inp1[3 1]->[1.0 2.0 3.0],inp2[3.. ..2 5..10 ? -1],inp3[5]->[1.0 1.0 2.0 3.0 5.0]"
result, _ = get_placeholder_shapes(argv_input, None)
exp_res = {'inp1': (3, 1), 'inp2': ((3, np.iinfo(np.int64).max), (0, 2), (5, 10), -1, -1), 'inp3': (5)}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None)
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'])}
input_node_names_ref = "inp1,inp2,inp3"
self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
for i in placeholder_values_ref.keys():
npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i])
def test_get_shapes_several_inputs_several_partial_shapes4(self):
# shapes specified using --input_shape and values for freezing using --input command line parameter
argv_input = "inp1->[1.0 2.0 3.0],inp2,inp3->[1.0 1.0 2.0 3.0 5.0]"
input_shapes = "(3,1), (3..,..2,5..10,?,-1), (5)"
result, _ = get_placeholder_shapes(argv_input, input_shapes)
exp_res = {'inp1': (3, 1), 'inp2': ((3, np.iinfo(np.int64).max), (0, 2), (5, 10), -1, -1), 'inp3': (5)}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None)
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'])}
input_node_names_ref = "inp1,inp2,inp3"
self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
for i in placeholder_values_ref.keys():
npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i])
self.assertEqual(input_node_names_ref, input_node_names_res)
def test_get_shapes_several_inputs_several_partial_shapes5(self):
# some values for freezing specified using --freeze_placeholder_with_value
argv_input = "inp1->[1.0 2.0 3.0],inp2,inp3->[1.0 1.0 2.0 3.0 5.0]"
input_shapes = "(3,1), (3..,..2,5..10,?,-1), (5)"
argv_freeze_placeholder_with_value = "inp2->[5.0 7.0 3.0],inp4->[100.0 200.0]"
result, _ = get_placeholder_shapes(argv_input, input_shapes)
exp_res = {'inp1': (3, 1), 'inp2': ((3, np.iinfo(np.int64).max), (0, 2), (5, 10), -1, -1), 'inp3': (5)}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, argv_freeze_placeholder_with_value)
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'],),
'inp2': np.array(['5.0', '7.0', '3.0']), 'inp4': np.array(['100.0', '200.0'])}
input_node_names_ref = "inp1,inp2,inp3"
self.assertEqual(sorted(list(placeholder_values_res.keys())), sorted(list(placeholder_values_ref.keys())))
for i in placeholder_values_ref.keys():
npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i])
self.assertEqual(input_node_names_ref, input_node_names_res)
def test_get_shapes_several_inputs_several_partial_shapes6(self):
# 0D value for freezing specified using --input command line parameter without shape
argv_input = "inp1[3 1]->[1.0 2.0 3.0],inp2[3.. ..2 5..10 ? -1],inp3->False"
result, _ = get_placeholder_shapes(argv_input, None)
exp_res = {'inp1': (3, 1), 'inp2': ((3, np.iinfo(np.int64).max), (0, 2), (5, 10), -1, -1), 'inp3': np.array(False).shape}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None)
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': False}
self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
for i in placeholder_values_ref.keys():
npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i])
def test_get_shapes_several_inputs_several_partial_shapes7(self):
# 0D shape and value for freezing specified using --input command line parameter
argv_input = "inp1[3 1]->[1.0 2.0 3.0],inp2[3.. ..2 5..10 ? -1],inp3[]->True"
result, _ = get_placeholder_shapes(argv_input, None)
exp_res = {'inp1': (3, 1), 'inp2': ((3, np.iinfo(np.int64).max), (0, 2), (5, 10), -1, -1), 'inp3': np.array(False).shape}
self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys():
npt.assert_array_equal(result[i], exp_res[i])
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None)
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': True}
self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
for i in placeholder_values_ref.keys():
npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i])
def test_get_shapes_and_data_types_partial_shape_with_input_port(self):
argv_input = "inp1:1[3 1]->[1.0 2.0 3.0],0:inp2[3.. ..2 5..10 ? -1]{i32},inp3:4[5]{f32}->[1.0 1.0 2.0 3.0 5.0]"
result_shapes, result_data_types = get_placeholder_shapes(argv_input, "")
ref_result_shapes = {'inp1:1': np.array([3, 1]), '0:inp2': ((3, np.iinfo(np.int64).max), (0, 2), (5, 10), -1, -1), 'inp3:4': np.array([5])}
ref_result_data_types = {'0:inp2': np.int32, 'inp3:4': np.float32}
self.assertEqual(list(ref_result_shapes.keys()), list(result_shapes.keys()))
for i in ref_result_shapes.keys():
npt.assert_array_equal(result_shapes[i], ref_result_shapes[i])
self.assertEqual(list(ref_result_data_types.keys()), list(result_data_types.keys()))
for i in ref_result_data_types.keys():
np.testing.assert_equal(result_data_types[i], ref_result_data_types[i])
def test_get_shapes_and_data_types_partial_shape_with_output_port(self):
argv_input = "inp1:1[3 1]->[1.0 2.0 3.0],inp2:3[3.. ..2 5..10 ? -1]{i32},inp3:4[5]{f32}->[1.0 1.0 2.0 3.0 5.0]"
result_shapes, result_data_types = get_placeholder_shapes(argv_input, "")
ref_result_shapes = {'inp1:1': np.array([3, 1]), 'inp2:3': ((3, np.iinfo(np.int64).max), (0, 2), (5, 10), -1, -1), 'inp3:4': np.array([5])}
ref_result_data_types = {'inp2:3': np.int32, 'inp3:4': np.float32}
self.assertEqual(list(ref_result_shapes.keys()), list(result_shapes.keys()))
for i in ref_result_shapes.keys():
npt.assert_array_equal(result_shapes[i], ref_result_shapes[i])
self.assertEqual(list(ref_result_data_types.keys()), list(result_data_types.keys()))
for i in ref_result_data_types.keys():
np.testing.assert_equal(result_data_types[i], ref_result_data_types[i])
def test_partial_shapes_negative_case(self):
argv_input = "inp1"
input_shapes = "[6754fg..23ed]"
self.assertRaises(Error, get_placeholder_shapes, argv_input, input_shapes)
def test_partial_shapes_freeze_dynamic_negative_case1(self):
argv_input = "inp1:1[3 1..10]->[1.0 2.0 3.0]"
self.assertRaises(Error, get_placeholder_shapes, argv_input, "")
def test_partial_shapes_freeze_dynamic_negative_case2(self):
argv_input = "inp1:1[1 2 -1]->[1.0 2.0 3.0]"
self.assertRaises(Error, get_placeholder_shapes, argv_input, "")
def test_partial_shapes_freeze_dynamic_negative_case3(self):
# some values for freezing specified using --freeze_placeholder_with_value
argv_input = "inp1->[1.0 2.0 3.0]"
input_shapes = "[3,1..10]"
self.assertRaises(Error, get_placeholder_shapes, argv_input, input_shapes)
class TestModelNameParsing(unittest.TestCase):
def test_model_name_ideal(self):