Partial shape serialization MO (#8927)
* Support of partial shape in MO. * Added partial shapes serialization in MO. * Removed breckets. * Fixed parameter extender in IR reader. * Removed wrong changes. * Added checks, added tests. * Updated help. * Fixed import.
This commit is contained in:
parent
eab49eec8b
commit
3cb27715e2
@ -605,7 +605,7 @@ def input_user_data_repack(graph: Graph, input_user_shapes: [None, list, dict, n
|
|||||||
if ph_id not in _input_shapes]
|
if ph_id not in _input_shapes]
|
||||||
else:
|
else:
|
||||||
# np.ndarray is a shape. User provided only --input_shape key
|
# np.ndarray is a shape. User provided only --input_shape key
|
||||||
assert isinstance(input_user_shapes, np.ndarray)
|
assert isinstance(input_user_shapes, tuple)
|
||||||
if len(placeholders_ids) == 1:
|
if len(placeholders_ids) == 1:
|
||||||
# There is only one placeholder in the original network
|
# There is only one placeholder in the original network
|
||||||
_input_shapes[placeholders_ids[0]].append({'shape': input_user_shapes, 'port': None})
|
_input_shapes[placeholders_ids[0]].append({'shape': input_user_shapes, 'port': None})
|
||||||
@ -861,7 +861,7 @@ def add_input_op_output_port_with_data(graph: Graph, node_id: str, input_op, por
|
|||||||
|
|
||||||
|
|
||||||
def add_input_op(graph: Graph, node_id: str, port: int = 0, data: bool = False,
|
def add_input_op(graph: Graph, node_id: str, port: int = 0, data: bool = False,
|
||||||
shape=None, data_type=None, is_out_port: bool = False):
|
shape=None, user_shape=None, data_type=None, is_out_port: bool = False):
|
||||||
"""
|
"""
|
||||||
This function adds Input node to node with id==node_id to specified port (in or out defined with is_out_port).
|
This function adds Input node to node with id==node_id to specified port (in or out defined with is_out_port).
|
||||||
:param graph: graph to operate on.
|
:param graph: graph to operate on.
|
||||||
@ -869,6 +869,8 @@ def add_input_op(graph: Graph, node_id: str, port: int = 0, data: bool = False,
|
|||||||
:param port: number of port of node_id node for adding input node.
|
:param port: number of port of node_id node for adding input node.
|
||||||
:param data: flag that define whether data nodes is needed or not.
|
:param data: flag that define whether data nodes is needed or not.
|
||||||
:param shape: shape for new input node.
|
:param shape: shape for new input node.
|
||||||
|
:param user_shape: shape provided by user which may contain boundaries of dynamic dimension.
|
||||||
|
:param data_type: data type of input node.
|
||||||
:param is_out_port: flag that define whether port is output port or not.
|
:param is_out_port: flag that define whether port is output port or not.
|
||||||
:return: id of new Input operation
|
:return: id of new Input operation
|
||||||
"""
|
"""
|
||||||
@ -876,7 +878,7 @@ def add_input_op(graph: Graph, node_id: str, port: int = 0, data: bool = False,
|
|||||||
from openvino.tools.mo.ops.parameter import Parameter
|
from openvino.tools.mo.ops.parameter import Parameter
|
||||||
if data_type is None:
|
if data_type is None:
|
||||||
data_type = np.float32
|
data_type = np.float32
|
||||||
input_op = Parameter(graph, dict(shape=shape, data_type=data_type, initial_node_name=node_id,
|
input_op = Parameter(graph, dict(shape=shape, user_shape=user_shape, data_type=data_type, initial_node_name=node_id,
|
||||||
name=get_new_placeholder_name(node_id, is_out_port, port)))
|
name=get_new_placeholder_name(node_id, is_out_port, port)))
|
||||||
|
|
||||||
fw_name = Node(graph, node_id).soft_get('name')
|
fw_name = Node(graph, node_id).soft_get('name')
|
||||||
@ -901,7 +903,7 @@ def add_input_op(graph: Graph, node_id: str, port: int = 0, data: bool = False,
|
|||||||
|
|
||||||
|
|
||||||
def add_input_ops_helper_before_infer_input_port(graph: Graph, smart_node: Node, port: int, node_id: str,
|
def add_input_ops_helper_before_infer_input_port(graph: Graph, smart_node: Node, port: int, node_id: str,
|
||||||
shape: np.array, data_type,
|
shape: np.array, user_shape: tuple, data_type,
|
||||||
inputs: list, edges_to_remove: list):
|
inputs: list, edges_to_remove: list):
|
||||||
n_inputs = len(smart_node.in_nodes())
|
n_inputs = len(smart_node.in_nodes())
|
||||||
if n_inputs > 1 and port is None:
|
if n_inputs > 1 and port is None:
|
||||||
@ -912,7 +914,7 @@ def add_input_ops_helper_before_infer_input_port(graph: Graph, smart_node: Node,
|
|||||||
port = port if port is not None else 0
|
port = port if port is not None else 0
|
||||||
edges_to_remove.append((smart_node.in_node(port).id, smart_node.id))
|
edges_to_remove.append((smart_node.in_node(port).id, smart_node.id))
|
||||||
inputs.append(add_input_op(graph=graph, node_id=node_id, port=port, data=False,
|
inputs.append(add_input_op(graph=graph, node_id=node_id, port=port, data=False,
|
||||||
shape=shape, data_type=data_type))
|
shape=shape, user_shape=user_shape, data_type=data_type))
|
||||||
|
|
||||||
|
|
||||||
def add_input_ops_helper_after_infer_input_port(graph: Graph, smart_node: Node, port:int, node_id: str,
|
def add_input_ops_helper_after_infer_input_port(graph: Graph, smart_node: Node, port:int, node_id: str,
|
||||||
@ -934,13 +936,14 @@ def add_input_ops_helper_after_infer_input_port(graph: Graph, smart_node: Node,
|
|||||||
edges_to_remove.append((in_node.id, node_id))
|
edges_to_remove.append((in_node.id, node_id))
|
||||||
|
|
||||||
|
|
||||||
def add_input_ops_helper_before_infer_output_port(graph: Graph, port:int, node_id: str,
|
def add_input_ops_helper_before_infer_output_port(graph: Graph, port: int, node_id: str,
|
||||||
shape: np.array, data_type, inputs: list, edges_to_remove: list):
|
shape: np.array, user_shape: tuple, data_type: tuple,
|
||||||
|
inputs: list, edges_to_remove: list):
|
||||||
for u, v, edge_attrs in graph.out_edges(node_id, data=True):
|
for u, v, edge_attrs in graph.out_edges(node_id, data=True):
|
||||||
if edge_attrs['out'] == port:
|
if edge_attrs['out'] == port:
|
||||||
edges_to_remove.append((u, v)) # we need to remove all edges from this port
|
edges_to_remove.append((u, v)) # we need to remove all edges from this port
|
||||||
inputs.append(add_input_op(graph=graph, node_id=node_id, port=port, data=False,
|
inputs.append(add_input_op(graph=graph, node_id=node_id, port=port, data=False,
|
||||||
shape=shape, data_type=data_type, is_out_port=True))
|
shape=shape, user_shape=user_shape, data_type=data_type, is_out_port=True))
|
||||||
|
|
||||||
|
|
||||||
def add_input_ops_helper_after_infer_output_port(graph: Graph, smart_node: Node, port:int, node_id: str,
|
def add_input_ops_helper_after_infer_output_port(graph: Graph, smart_node: Node, port:int, node_id: str,
|
||||||
@ -987,8 +990,11 @@ def add_input_ops(graph: Graph, user_defined_inputs: dict, before_infer: bool):
|
|||||||
|
|
||||||
is_out_port = 'out' in port_and_shape_info # by default we assume input port or input node without port
|
is_out_port = 'out' in port_and_shape_info # by default we assume input port or input node without port
|
||||||
shape = port_and_shape_info['shape'] if 'shape' in port_and_shape_info else None
|
shape = port_and_shape_info['shape'] if 'shape' in port_and_shape_info else None
|
||||||
|
user_shape = None
|
||||||
if shape is not None:
|
if shape is not None:
|
||||||
shape = shape_array([dim if dim >= 0 else dynamic_dimension_value for dim in shape])
|
user_shape = shape
|
||||||
|
shape = shape_array(
|
||||||
|
[dim if type(dim) != tuple and dim >= 0 else dynamic_dimension_value for dim in shape])
|
||||||
data_type = port_and_shape_info['data_type'] if 'data_type' in port_and_shape_info else None
|
data_type = port_and_shape_info['data_type'] if 'data_type' in port_and_shape_info else None
|
||||||
smart_node = Node(graph, node_id)
|
smart_node = Node(graph, node_id)
|
||||||
|
|
||||||
@ -1016,6 +1022,7 @@ def add_input_ops(graph: Graph, user_defined_inputs: dict, before_infer: bool):
|
|||||||
refer_to_faq_msg(28), node_id, port)
|
refer_to_faq_msg(28), node_id, port)
|
||||||
if shape is not None:
|
if shape is not None:
|
||||||
smart_node['shape'] = shape
|
smart_node['shape'] = shape
|
||||||
|
smart_node['user_shape'] = user_shape
|
||||||
if data_type is not None:
|
if data_type is not None:
|
||||||
smart_node['data_type'] = data_type
|
smart_node['data_type'] = data_type
|
||||||
inputs.append(node_id)
|
inputs.append(node_id)
|
||||||
@ -1027,10 +1034,11 @@ def add_input_ops(graph: Graph, user_defined_inputs: dict, before_infer: bool):
|
|||||||
continue
|
continue
|
||||||
# We cut with shapes provided by user and there is no need to wait till infer
|
# We cut with shapes provided by user and there is no need to wait till infer
|
||||||
if is_out_port:
|
if is_out_port:
|
||||||
add_input_ops_helper_before_infer_output_port(graph, port, node_id, shape, data_type, inputs,
|
add_input_ops_helper_before_infer_output_port(graph, port, node_id, shape, user_shape,
|
||||||
edges_to_remove)
|
data_type, inputs, edges_to_remove)
|
||||||
else:
|
else:
|
||||||
add_input_ops_helper_before_infer_input_port(graph, smart_node, port, node_id, shape, data_type, inputs,
|
add_input_ops_helper_before_infer_input_port(graph, smart_node, port, node_id, shape,
|
||||||
|
user_shape, data_type, inputs,
|
||||||
edges_to_remove)
|
edges_to_remove)
|
||||||
else:
|
else:
|
||||||
# We cut after infer and we need inferred shapes in nodes
|
# We cut after infer and we need inferred shapes in nodes
|
||||||
|
@ -2,10 +2,12 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
import logging as log
|
import logging as log
|
||||||
|
|
||||||
from openvino.tools.mo.front.common.partial_infer.utils import is_fully_defined, unmask_shape, shape_array, dynamic_dimension_value
|
from openvino.tools.mo.front.common.partial_infer.utils import is_fully_defined, shape_array, \
|
||||||
|
dynamic_dimension_value
|
||||||
from openvino.tools.mo.graph.graph import Graph
|
from openvino.tools.mo.graph.graph import Graph
|
||||||
from openvino.tools.mo.middle.passes.infer import partial_infer
|
from openvino.tools.mo.middle.passes.infer import partial_infer
|
||||||
from openvino.tools.mo.middle.replacement import MiddleReplacementPattern
|
from openvino.tools.mo.middle.replacement import MiddleReplacementPattern
|
||||||
|
from openvino.tools.mo.ops.parameter import Parameter
|
||||||
|
|
||||||
|
|
||||||
class PartialInfer(MiddleReplacementPattern):
|
class PartialInfer(MiddleReplacementPattern):
|
||||||
@ -25,7 +27,7 @@ class PartialInfer(MiddleReplacementPattern):
|
|||||||
param_shape = parameter.soft_get('shape', shape_array(dynamic_dimension_value))
|
param_shape = parameter.soft_get('shape', shape_array(dynamic_dimension_value))
|
||||||
if not is_fully_defined(param_shape):
|
if not is_fully_defined(param_shape):
|
||||||
parameter_name = parameter.soft_get('name', parameter.id)
|
parameter_name = parameter.soft_get('name', parameter.id)
|
||||||
dynamic_inputs[parameter_name] = param_shape
|
dynamic_inputs[parameter_name] = parameter
|
||||||
if dynamic_inputs:
|
if dynamic_inputs:
|
||||||
log.error('The model contains input(s) with partially defined shapes: {}. '
|
log.error('The model contains input(s) with partially defined shapes: {}. '
|
||||||
'Starting from the 2022.1 release the Model Optimizer can generate an IR with partially defined '
|
'Starting from the 2022.1 release the Model Optimizer can generate an IR with partially defined '
|
||||||
@ -34,7 +36,7 @@ class PartialInfer(MiddleReplacementPattern):
|
|||||||
'call "reshape" method in the Inference Engine and specify static input shapes. For optimal '
|
'call "reshape" method in the Inference Engine and specify static input shapes. For optimal '
|
||||||
'performance, it is still recommended to update input shapes with fixed ones using "--input" or '
|
'performance, it is still recommended to update input shapes with fixed ones using "--input" or '
|
||||||
'"--input_shape" command-line parameters.'
|
'"--input_shape" command-line parameters.'
|
||||||
.format(','.join('name="{}" shape="{}"'.format(name, unmask_shape(shape))
|
.format(','.join('name="{}" shape="{}"'.format(name, Parameter.shape_serialize(parameter))
|
||||||
for name, shape in dynamic_inputs.items())),
|
for name, parameter in dynamic_inputs.items())),
|
||||||
extra={'is_warning': True})
|
extra={'is_warning': True})
|
||||||
partial_infer(graph)
|
partial_infer(graph)
|
||||||
|
@ -118,7 +118,7 @@ def fe_input_user_data_repack(input_model: InputModel, input_user_shapes: [None,
|
|||||||
_input_shapes.append({'node': node, 'shape': shape, 'data_type': data_type})
|
_input_shapes.append({'node': node, 'shape': shape, 'data_type': data_type})
|
||||||
else:
|
else:
|
||||||
_input_shapes.append({'node': node, 'shape': shape})
|
_input_shapes.append({'node': node, 'shape': shape})
|
||||||
elif isinstance(input_user_shapes, np.ndarray):
|
elif isinstance(input_user_shapes, tuple):
|
||||||
model_inputs = input_model.get_inputs()
|
model_inputs = input_model.get_inputs()
|
||||||
assert len(model_inputs) == 1
|
assert len(model_inputs) == 1
|
||||||
_input_shapes.append({'node': model_inputs[0], 'shape': input_user_shapes})
|
_input_shapes.append({'node': model_inputs[0], 'shape': input_user_shapes})
|
||||||
|
@ -12,6 +12,8 @@ from openvino.runtime import Dimension, PartialShape # pylint: disable=no
|
|||||||
from openvino.frontend import FrontEnd, Place # pylint: disable=no-name-in-module,import-error
|
from openvino.frontend import FrontEnd, Place # pylint: disable=no-name-in-module,import-error
|
||||||
from openvino.runtime.utils.types import get_element_type # pylint: disable=no-name-in-module,import-error
|
from openvino.runtime.utils.types import get_element_type # pylint: disable=no-name-in-module,import-error
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def moc_pipeline(argv: argparse.Namespace, moc_front_end: FrontEnd):
|
def moc_pipeline(argv: argparse.Namespace, moc_front_end: FrontEnd):
|
||||||
"""
|
"""
|
||||||
@ -66,7 +68,7 @@ def moc_pipeline(argv: argparse.Namespace, moc_front_end: FrontEnd):
|
|||||||
for user_shape in user_shapes:
|
for user_shape in user_shapes:
|
||||||
if user_shape.get('shape') is not None:
|
if user_shape.get('shape') is not None:
|
||||||
input_model.set_partial_shape(
|
input_model.set_partial_shape(
|
||||||
user_shape['node'], PartialShape(user_shape['shape']))
|
user_shape['node'], partial_shape_from_tuple(user_shape['shape']))
|
||||||
if user_shape.get('data_type') is not None:
|
if user_shape.get('data_type') is not None:
|
||||||
data_type = get_element_type(user_shape['data_type'])
|
data_type = get_element_type(user_shape['data_type'])
|
||||||
log.debug('Set data type: {}'.format(data_type))
|
log.debug('Set data type: {}'.format(data_type))
|
||||||
@ -97,3 +99,16 @@ def moc_pipeline(argv: argparse.Namespace, moc_front_end: FrontEnd):
|
|||||||
|
|
||||||
ngraph_function = moc_front_end.convert(input_model)
|
ngraph_function = moc_front_end.convert(input_model)
|
||||||
return ngraph_function
|
return ngraph_function
|
||||||
|
|
||||||
|
|
||||||
|
def partial_shape_from_tuple(shape: tuple):
|
||||||
|
new_shape = []
|
||||||
|
for dim in shape:
|
||||||
|
if isinstance(dim, tuple):
|
||||||
|
assert len(dim) == 2, "Incorrect boundaries of dimension {} in shape {}".format(dim, shape)
|
||||||
|
assert dim[0] >= 0, "Incorrect min value of dimension {} in shape".format(dim, shape)
|
||||||
|
new_shape.append(Dimension(dim[0], dim[1]))
|
||||||
|
else:
|
||||||
|
assert isinstance(dim, np.int64), "Incorrect type of dimension {} in shape".format(dim, shape)
|
||||||
|
new_shape.append(Dimension(dim))
|
||||||
|
return PartialShape(new_shape)
|
||||||
|
@ -26,6 +26,7 @@ class Parameter(Op):
|
|||||||
'type_infer': self.type_infer,
|
'type_infer': self.type_infer,
|
||||||
|
|
||||||
'out_ports_count': 1,
|
'out_ports_count': 1,
|
||||||
|
'user_shape': None,
|
||||||
}
|
}
|
||||||
if 'data_type' not in attrs:
|
if 'data_type' not in attrs:
|
||||||
mandatory_props['data_type'] = np.float32
|
mandatory_props['data_type'] = np.float32
|
||||||
@ -35,9 +36,25 @@ class Parameter(Op):
|
|||||||
def type_infer(node):
|
def type_infer(node):
|
||||||
node.out_port(0).set_data_type(node.data_type)
|
node.out_port(0).set_data_type(node.data_type)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def shape_serialize(node):
|
||||||
|
def serialize_dimension(dim: [tuple, np.int64]):
|
||||||
|
if type(dim) == tuple:
|
||||||
|
assert len(dim) == 2, "Unable to serialize shape {} in node {}".format(node.soft_get('user_shape'),
|
||||||
|
node.soft_get('name', node.id))
|
||||||
|
min_str = str(dim[0]) if dim[0] > 0 else ""
|
||||||
|
max_str = str(dim[1]) if dim[1] < np.iinfo(np.int64).max else ""
|
||||||
|
return min_str + ".." + max_str
|
||||||
|
return str(dim)
|
||||||
|
|
||||||
|
if not node.has_valid('user_shape'):
|
||||||
|
return ','.join([str(i) for i in unmask_shape(node.shape)])
|
||||||
|
shape = node.soft_get('user_shape')
|
||||||
|
return ','.join(map(serialize_dimension, shape))
|
||||||
|
|
||||||
def supported_attrs(self):
|
def supported_attrs(self):
|
||||||
return [
|
return [
|
||||||
('shape', lambda node: ','.join([str(i) for i in unmask_shape(node.shape)])),
|
('shape', lambda node: self.shape_serialize(node)),
|
||||||
('element_type', lambda node: np_data_type_to_destination_type(node.data_type)),
|
('element_type', lambda node: np_data_type_to_destination_type(node.data_type)),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -257,10 +257,12 @@ def get_common_cli_parser(parser: argparse.ArgumentParser = None):
|
|||||||
'models. Model Optimizer performs necessary transformations to convert the shape to '
|
'models. Model Optimizer performs necessary transformations to convert the shape to '
|
||||||
'the layout required by Inference Engine (N,C,H,W). The shape could contain '
|
'the layout required by Inference Engine (N,C,H,W). The shape could contain '
|
||||||
'undefined dimensions (-1) and should fit the dimensions defined in the input '
|
'undefined dimensions (-1) and should fit the dimensions defined in the input '
|
||||||
'operation of the graph. If there are multiple inputs in the model, --input_shape '
|
'operation of the graph. Boundaries of undefined dimension can be specified with '
|
||||||
'should contain definition of shape for each input separated by a comma, for '
|
'ellipsis, for example [1,1..10,128,128]. One boundary can be undefined, for '
|
||||||
'example: [1,3,227,227],[2,4] for a model with two inputs with 4D and 2D shapes. '
|
'example [1,..100] or [1,3,1..,1..]. If there are multiple inputs in the model, '
|
||||||
'Alternatively, specify shapes with the --input option.')
|
'--input_shape should contain definition of shape for each input separated by a '
|
||||||
|
'comma, for example: [1,3,227,227],[2,4] for a model with two inputs with 4D and 2D '
|
||||||
|
'shapes. Alternatively, specify shapes with the --input option.')
|
||||||
common_group.add_argument('--scale', '-s',
|
common_group.add_argument('--scale', '-s',
|
||||||
type=float,
|
type=float,
|
||||||
help='All input values coming from original network inputs will be ' +
|
help='All input values coming from original network inputs will be ' +
|
||||||
@ -778,12 +780,12 @@ def remove_shape_from_input_value(input_value: str):
|
|||||||
:return: string without shape specification
|
:return: string without shape specification
|
||||||
"""
|
"""
|
||||||
assert '->' not in input_value, 'The function should not be called for input_value with constant value specified'
|
assert '->' not in input_value, 'The function should not be called for input_value with constant value specified'
|
||||||
return re.sub(r'[(\[]([0-9 -]*)[)\]]', '', input_value)
|
return re.sub(r'[(\[]([0-9\.? -]*)[)\]]', '', input_value)
|
||||||
|
|
||||||
|
|
||||||
def get_shape_from_input_value(input_value: str):
|
def get_shape_from_input_value(input_value: str):
|
||||||
"""
|
"""
|
||||||
Returns the numpy array with shape corresponding to the shape specified in the input value string
|
Returns the list of tuples corresponding to the shape specified in the input value string
|
||||||
:param input_value: string passed as input to the --input command line parameter
|
:param input_value: string passed as input to the --input command line parameter
|
||||||
:return: the corresponding shape and None if the shape is not specified in the input value
|
:return: the corresponding shape and None if the shape is not specified in the input value
|
||||||
"""
|
"""
|
||||||
@ -791,11 +793,11 @@ def get_shape_from_input_value(input_value: str):
|
|||||||
input_value = input_value.split('->')[0]
|
input_value = input_value.split('->')[0]
|
||||||
|
|
||||||
# parse shape
|
# parse shape
|
||||||
shape = re.findall(r'[(\[]([0-9 -]*)[)\]]', input_value)
|
shape = re.findall(r'[(\[]([0-9\.\? -]+)[)\]]', input_value)
|
||||||
if len(shape) == 0:
|
if len(shape) == 0:
|
||||||
shape = None
|
shape = None
|
||||||
elif len(shape) == 1:
|
elif len(shape) == 1:
|
||||||
shape = np.fromstring(shape[0], dtype=np.int64, sep=' ')
|
shape = tuple(map(parse_dimension, shape[0].split(' ')))
|
||||||
else:
|
else:
|
||||||
raise Error("Wrong syntax to specify shape. Use --input "
|
raise Error("Wrong syntax to specify shape. Use --input "
|
||||||
"\"node_name[shape]->value\"")
|
"\"node_name[shape]->value\"")
|
||||||
@ -854,6 +856,11 @@ def parse_input_value(input_value: str):
|
|||||||
shape = get_shape_from_input_value(input_value.split('->')[0])
|
shape = get_shape_from_input_value(input_value.split('->')[0])
|
||||||
value_size = np.prod(len(value)) if isinstance(value, list) else 1
|
value_size = np.prod(len(value)) if isinstance(value, list) else 1
|
||||||
|
|
||||||
|
if value is not None and shape is not None:
|
||||||
|
for dim in shape:
|
||||||
|
if isinstance(dim, tuple) or dim == -1:
|
||||||
|
raise Error("Cannot freeze input with dynamic shape: {}".format(shape))
|
||||||
|
|
||||||
if shape is not None and value is not None and np.prod(shape) != value_size:
|
if shape is not None and value is not None and np.prod(shape) != value_size:
|
||||||
raise Error("The shape '{}' of the input node '{}' does not correspond to the number of elements '{}' in the "
|
raise Error("The shape '{}' of the input node '{}' does not correspond to the number of elements '{}' in the "
|
||||||
"value: {}".format(shape, node_name, value_size, value))
|
"value: {}".format(shape, node_name, value_size, value))
|
||||||
@ -1043,6 +1050,33 @@ def get_freeze_placeholder_values(argv_input: str, argv_freeze_placeholder_with_
|
|||||||
return placeholder_values, input_node_names
|
return placeholder_values, input_node_names
|
||||||
|
|
||||||
|
|
||||||
|
def parse_dimension(dim: str):
|
||||||
|
if '..' in dim:
|
||||||
|
numbers_reg = r'^[0-9]+$'
|
||||||
|
dims = dim.split('..')
|
||||||
|
match_res0 = re.match(numbers_reg, dims[0])
|
||||||
|
match_res1 = re.match(numbers_reg, dims[1])
|
||||||
|
if len(dims[0].strip()) > 0 and match_res0 is None:
|
||||||
|
Error("Incorrect min value of dimension '{}'".format(dims[0]))
|
||||||
|
if len(dims[1].strip()) > 0 and match_res1 is None:
|
||||||
|
Error("Incorrect max value of dimension '{}'".format(dims[1]))
|
||||||
|
|
||||||
|
min_val = np.int64(dims[0]) if match_res0 else np.int64(0)
|
||||||
|
max_val = np.int64(dims[1]) if match_res1 else np.iinfo(np.int64).max
|
||||||
|
assert min_val >= 0, "Incorrect min value of the dimension {}".format(dim)
|
||||||
|
|
||||||
|
if min_val == np.int64(0) and max_val == np.iinfo(np.int64).max:
|
||||||
|
return np.int64(-1)
|
||||||
|
|
||||||
|
assert min_val < max_val, "Min value should be less than max value. Got min value: {}, " \
|
||||||
|
"max value: {}".format(min_val, max_val)
|
||||||
|
|
||||||
|
return min_val, max_val
|
||||||
|
if '?' in dim:
|
||||||
|
return np.int64(-1)
|
||||||
|
return np.int64(dim)
|
||||||
|
|
||||||
|
|
||||||
def get_placeholder_shapes(argv_input: str, argv_input_shape: str, argv_batch=None):
|
def get_placeholder_shapes(argv_input: str, argv_input_shape: str, argv_batch=None):
|
||||||
"""
|
"""
|
||||||
Parses input layers names and input shapes from the cli and returns the parsed object.
|
Parses input layers names and input shapes from the cli and returns the parsed object.
|
||||||
@ -1055,16 +1089,18 @@ def get_placeholder_shapes(argv_input: str, argv_input_shape: str, argv_batch=No
|
|||||||
E.g. 'inp1,inp2', 'node_name1[shape1]->value1,node_name2[shape2]->value2'
|
E.g. 'inp1,inp2', 'node_name1[shape1]->value1,node_name2[shape2]->value2'
|
||||||
argv_input_shape
|
argv_input_shape
|
||||||
string with a list of input shapes: either an empty string, or tuples separated with comma.
|
string with a list of input shapes: either an empty string, or tuples separated with comma.
|
||||||
E.g. '(1,2),(3,4)'.
|
E.g. '[1,2],[3,4]'.
|
||||||
Only positive integers are accepted except -1, which can be on any position in a shape.
|
Only positive integers are accepted.
|
||||||
|
'?' marks dynamic dimension.
|
||||||
|
Partial shape is specified with ellipsis. E.g. '[1..10,2,3]'
|
||||||
argv_batch
|
argv_batch
|
||||||
integer that overrides batch size in input shape
|
integer that overrides batch size in input shape
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
parsed shapes in form of {'name of input':ndarray} if names of inputs are provided with shapes
|
parsed shapes in form of {'name of input':tuple} if names of inputs are provided with shapes
|
||||||
parsed shapes in form of {'name of input':None} if names of inputs are provided without shapes
|
parsed shapes in form of {'name of input':None} if names of inputs are provided without shapes
|
||||||
ndarray if only one shape is provided and no input name
|
tuple if only one shape is provided and no input name
|
||||||
None if neither shape nor input were provided
|
None if neither shape nor input were provided
|
||||||
"""
|
"""
|
||||||
if argv_input_shape and argv_batch:
|
if argv_input_shape and argv_batch:
|
||||||
@ -1099,7 +1135,8 @@ def get_placeholder_shapes(argv_input: str, argv_input_shape: str, argv_batch=No
|
|||||||
inputs = list()
|
inputs = list()
|
||||||
placeholder_shapes = None
|
placeholder_shapes = None
|
||||||
|
|
||||||
first_digit_reg = r'([0-9 ]+|-1)'
|
range_reg = r'([0-9]*\.\.[0-9]*)'
|
||||||
|
first_digit_reg = r'([0-9 ]+|-1|\?|{})'.format(range_reg)
|
||||||
next_digits_reg = r'(,{})*'.format(first_digit_reg)
|
next_digits_reg = r'(,{})*'.format(first_digit_reg)
|
||||||
tuple_reg = r'((\({}{}\))|(\[{}{}\]))'.format(first_digit_reg, next_digits_reg,
|
tuple_reg = r'((\({}{}\))|(\[{}{}\]))'.format(first_digit_reg, next_digits_reg,
|
||||||
first_digit_reg, next_digits_reg)
|
first_digit_reg, next_digits_reg)
|
||||||
@ -1107,7 +1144,7 @@ def get_placeholder_shapes(argv_input: str, argv_input_shape: str, argv_batch=No
|
|||||||
full_reg = r'^{}(\s*,\s*{})*$|^$'.format(tuple_reg, tuple_reg)
|
full_reg = r'^{}(\s*,\s*{})*$|^$'.format(tuple_reg, tuple_reg)
|
||||||
if not re.match(full_reg, argv_input_shape):
|
if not re.match(full_reg, argv_input_shape):
|
||||||
raise Error('Input shape "{}" cannot be parsed. ' + refer_to_faq_msg(57), argv_input_shape)
|
raise Error('Input shape "{}" cannot be parsed. ' + refer_to_faq_msg(57), argv_input_shape)
|
||||||
shapes = re.findall(r'[(\[]([0-9, -]+)[)\]]', argv_input_shape)
|
shapes = re.findall(r'[(\[]([0-9,\.\? -]+)[)\]]', argv_input_shape)
|
||||||
|
|
||||||
if argv_input:
|
if argv_input:
|
||||||
inputs = argv_input.split(',')
|
inputs = argv_input.split(',')
|
||||||
@ -1118,14 +1155,22 @@ def get_placeholder_shapes(argv_input: str, argv_input_shape: str, argv_batch=No
|
|||||||
if len(shapes) > 1:
|
if len(shapes) > 1:
|
||||||
raise Error('Please provide input layer names for input layer shapes. ' + refer_to_faq_msg(58))
|
raise Error('Please provide input layer names for input layer shapes. ' + refer_to_faq_msg(58))
|
||||||
else:
|
else:
|
||||||
placeholder_shapes = np.fromstring(shapes[0], dtype=np.int64, sep=',')
|
placeholder_shapes = tuple(map(parse_dimension, shapes[0].split(',')))
|
||||||
# check if number of shapes does not match number of passed inputs
|
# check if number of shapes does not match number of passed inputs
|
||||||
elif argv_input and (len(shapes) == len(inputs) or len(shapes) == 0):
|
elif argv_input and (len(shapes) == len(inputs) or len(shapes) == 0):
|
||||||
# clean inputs from values for freezing
|
# clean inputs from values for freezing
|
||||||
inputs = list(map(lambda x: x.split('->')[0], inputs))
|
inputs_without_value = list(map(lambda x: x.split('->')[0], inputs))
|
||||||
placeholder_shapes = dict(zip_longest(inputs,
|
placeholder_shapes = dict(zip_longest(inputs_without_value,
|
||||||
map(lambda x: np.fromstring(x, dtype=np.int64,
|
map(lambda x: tuple(map(parse_dimension, x.split(','))) if x else None,
|
||||||
sep=',') if x else None, shapes)))
|
shapes)))
|
||||||
|
for inp in inputs:
|
||||||
|
if '->' not in inp:
|
||||||
|
continue
|
||||||
|
shape = placeholder_shapes[inp.split('->')[0]]
|
||||||
|
for dim in shape:
|
||||||
|
if isinstance(dim, tuple) or dim == -1:
|
||||||
|
raise Error("Cannot freeze input with dynamic shape: {}".format(shape))
|
||||||
|
|
||||||
elif argv_input:
|
elif argv_input:
|
||||||
raise Error('Please provide each input layers with an input layer shape. ' + refer_to_faq_msg(58))
|
raise Error('Please provide each input layers with an input layer shape. ' + refer_to_faq_msg(58))
|
||||||
|
|
||||||
|
@ -18,5 +18,7 @@ class Parameter_extender(Extender):
|
|||||||
op.shape = int64_array([])
|
op.shape = int64_array([])
|
||||||
else:
|
else:
|
||||||
Extender.attr_to_list(op, 'shape')
|
Extender.attr_to_list(op, 'shape')
|
||||||
if -1 in op.shape:
|
for i, dim in enumerate(op.shape):
|
||||||
op.shape = shape_array([d if d != -1 else dynamic_dimension_value for d in op.shape])
|
if dim == -1 or (isinstance(dim, str) and ".." in dim):
|
||||||
|
op.shape[i] = -1
|
||||||
|
op.shape = shape_array([d if d != -1 else dynamic_dimension_value for d in op.shape])
|
||||||
|
@ -537,11 +537,11 @@ class TestUserDataRepack(unittest.TestCase):
|
|||||||
|
|
||||||
def test_error(self):
|
def test_error(self):
|
||||||
graph = build_graph(self.nodes, self.edges)
|
graph = build_graph(self.nodes, self.edges)
|
||||||
self.assertRaises(Error, input_user_data_repack, graph, np.array([1, 227, 227, 3]), None)
|
self.assertRaises(Error, input_user_data_repack, graph, tuple([1, 227, 227, 3]), None)
|
||||||
|
|
||||||
def test_error_2(self):
|
def test_error_2(self):
|
||||||
graph = build_graph(self.nodes, self.edges)
|
graph = build_graph(self.nodes, self.edges)
|
||||||
self.assertRaises(Error, input_user_data_repack, graph, np.array([1, 227, 227, 3]), None)
|
self.assertRaises(Error, input_user_data_repack, graph, tuple([1, 227, 227, 3]), None)
|
||||||
|
|
||||||
def test_error_3(self):
|
def test_error_3(self):
|
||||||
graph = build_graph(self.nodes, self.edges)
|
graph = build_graph(self.nodes, self.edges)
|
||||||
@ -549,7 +549,7 @@ class TestUserDataRepack(unittest.TestCase):
|
|||||||
|
|
||||||
def test_input_and_freeze(self):
|
def test_input_and_freeze(self):
|
||||||
graph = build_graph(self.nodes, self.edges)
|
graph = build_graph(self.nodes, self.edges)
|
||||||
shape_1 = np.array([1, 160, 160, 3])
|
shape_1 = tuple([1, 160, 160, 3])
|
||||||
input, freeze_placeholder = input_user_data_repack(graph, shape_1, {'Bb': True})
|
input, freeze_placeholder = input_user_data_repack(graph, shape_1, {'Bb': True})
|
||||||
self.assertDictEqual(input, {'A': [{'shape': shape_1, 'port': None}], 'B': [{'shape': None, 'port': None}]})
|
self.assertDictEqual(input, {'A': [{'shape': shape_1, 'port': None}], 'B': [{'shape': None, 'port': None}]})
|
||||||
self.assertDictEqual(freeze_placeholder, {'B': True})
|
self.assertDictEqual(freeze_placeholder, {'B': True})
|
||||||
|
@ -778,6 +778,153 @@ class TestShapesParsing(unittest.TestCase):
|
|||||||
input_shapes = "(12,4,1),(4,-6,8)"
|
input_shapes = "(12,4,1),(4,-6,8)"
|
||||||
self.assertRaises(Error, get_placeholder_shapes, argv_input, input_shapes)
|
self.assertRaises(Error, get_placeholder_shapes, argv_input, input_shapes)
|
||||||
|
|
||||||
|
def test_get_shapes_several_inputs_several_partial_shapes(self):
|
||||||
|
argv_input = "inp1,inp2"
|
||||||
|
input_shapes = "(1,..22,1..100,?), (-1,45..,7,1)"
|
||||||
|
result, _ = get_placeholder_shapes(argv_input, input_shapes)
|
||||||
|
exp_res = {'inp1': (1, (0, 22), (1, 100), -1), 'inp2': (-1, (45, np.iinfo(np.int64).max), 7, 1)}
|
||||||
|
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||||
|
for i in exp_res.keys():
|
||||||
|
npt.assert_array_equal(result[i], exp_res[i])
|
||||||
|
|
||||||
|
def test_get_shapes_several_inputs_several_partial_shapes2(self):
|
||||||
|
# shapes specified using --input command line parameter and no values
|
||||||
|
argv_input = "inp1[1 ? 50..100 123],inp2[-1 45.. ..7 1]"
|
||||||
|
result, _ = get_placeholder_shapes(argv_input, None)
|
||||||
|
exp_res = {'inp1': (1, -1, (50, 100), 123), 'inp2': (-1, (45,np.iinfo(np.int64).max), (0, 7), 1)}
|
||||||
|
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||||
|
for i in exp_res.keys():
|
||||||
|
npt.assert_array_equal(result[i], exp_res[i])
|
||||||
|
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None)
|
||||||
|
placeholder_values_ref = {}
|
||||||
|
input_node_names_ref = "inp1,inp2"
|
||||||
|
self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
|
||||||
|
for i in placeholder_values_ref.keys():
|
||||||
|
npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i])
|
||||||
|
|
||||||
|
def test_get_shapes_several_inputs_several_partial_shapes3(self):
|
||||||
|
# shapes and value for freezing specified using --input command line parameter
|
||||||
|
argv_input = "inp1[3 1]->[1.0 2.0 3.0],inp2[3.. ..2 5..10 ? -1],inp3[5]->[1.0 1.0 2.0 3.0 5.0]"
|
||||||
|
result, _ = get_placeholder_shapes(argv_input, None)
|
||||||
|
exp_res = {'inp1': (3, 1), 'inp2': ((3, np.iinfo(np.int64).max), (0, 2), (5, 10), -1, -1), 'inp3': (5)}
|
||||||
|
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||||
|
for i in exp_res.keys():
|
||||||
|
npt.assert_array_equal(result[i], exp_res[i])
|
||||||
|
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None)
|
||||||
|
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'])}
|
||||||
|
input_node_names_ref = "inp1,inp2,inp3"
|
||||||
|
self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
|
||||||
|
for i in placeholder_values_ref.keys():
|
||||||
|
npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i])
|
||||||
|
|
||||||
|
def test_get_shapes_several_inputs_several_partial_shapes4(self):
|
||||||
|
# shapes specified using --input_shape and values for freezing using --input command line parameter
|
||||||
|
argv_input = "inp1->[1.0 2.0 3.0],inp2,inp3->[1.0 1.0 2.0 3.0 5.0]"
|
||||||
|
input_shapes = "(3,1), (3..,..2,5..10,?,-1), (5)"
|
||||||
|
result, _ = get_placeholder_shapes(argv_input, input_shapes)
|
||||||
|
exp_res = {'inp1': (3, 1), 'inp2': ((3, np.iinfo(np.int64).max), (0, 2), (5, 10), -1, -1), 'inp3': (5)}
|
||||||
|
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||||
|
for i in exp_res.keys():
|
||||||
|
npt.assert_array_equal(result[i], exp_res[i])
|
||||||
|
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None)
|
||||||
|
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'])}
|
||||||
|
input_node_names_ref = "inp1,inp2,inp3"
|
||||||
|
self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
|
||||||
|
for i in placeholder_values_ref.keys():
|
||||||
|
npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i])
|
||||||
|
self.assertEqual(input_node_names_ref, input_node_names_res)
|
||||||
|
|
||||||
|
def test_get_shapes_several_inputs_several_partial_shapes5(self):
|
||||||
|
# some values for freezing specified using --freeze_placeholder_with_value
|
||||||
|
argv_input = "inp1->[1.0 2.0 3.0],inp2,inp3->[1.0 1.0 2.0 3.0 5.0]"
|
||||||
|
input_shapes = "(3,1), (3..,..2,5..10,?,-1), (5)"
|
||||||
|
argv_freeze_placeholder_with_value = "inp2->[5.0 7.0 3.0],inp4->[100.0 200.0]"
|
||||||
|
|
||||||
|
result, _ = get_placeholder_shapes(argv_input, input_shapes)
|
||||||
|
exp_res = {'inp1': (3, 1), 'inp2': ((3, np.iinfo(np.int64).max), (0, 2), (5, 10), -1, -1), 'inp3': (5)}
|
||||||
|
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||||
|
for i in exp_res.keys():
|
||||||
|
npt.assert_array_equal(result[i], exp_res[i])
|
||||||
|
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, argv_freeze_placeholder_with_value)
|
||||||
|
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'],),
|
||||||
|
'inp2': np.array(['5.0', '7.0', '3.0']), 'inp4': np.array(['100.0', '200.0'])}
|
||||||
|
input_node_names_ref = "inp1,inp2,inp3"
|
||||||
|
self.assertEqual(sorted(list(placeholder_values_res.keys())), sorted(list(placeholder_values_ref.keys())))
|
||||||
|
for i in placeholder_values_ref.keys():
|
||||||
|
npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i])
|
||||||
|
self.assertEqual(input_node_names_ref, input_node_names_res)
|
||||||
|
|
||||||
|
def test_get_shapes_several_inputs_several_partial_shapes6(self):
|
||||||
|
# 0D value for freezing specified using --input command line parameter without shape
|
||||||
|
argv_input = "inp1[3 1]->[1.0 2.0 3.0],inp2[3.. ..2 5..10 ? -1],inp3->False"
|
||||||
|
result, _ = get_placeholder_shapes(argv_input, None)
|
||||||
|
exp_res = {'inp1': (3, 1), 'inp2': ((3, np.iinfo(np.int64).max), (0, 2), (5, 10), -1, -1), 'inp3': np.array(False).shape}
|
||||||
|
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||||
|
for i in exp_res.keys():
|
||||||
|
npt.assert_array_equal(result[i], exp_res[i])
|
||||||
|
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None)
|
||||||
|
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': False}
|
||||||
|
self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
|
||||||
|
for i in placeholder_values_ref.keys():
|
||||||
|
npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i])
|
||||||
|
|
||||||
|
def test_get_shapes_several_inputs_several_partial_shapes7(self):
|
||||||
|
# 0D shape and value for freezing specified using --input command line parameter
|
||||||
|
argv_input = "inp1[3 1]->[1.0 2.0 3.0],inp2[3.. ..2 5..10 ? -1],inp3[]->True"
|
||||||
|
result, _ = get_placeholder_shapes(argv_input, None)
|
||||||
|
exp_res = {'inp1': (3, 1), 'inp2': ((3, np.iinfo(np.int64).max), (0, 2), (5, 10), -1, -1), 'inp3': np.array(False).shape}
|
||||||
|
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||||
|
for i in exp_res.keys():
|
||||||
|
npt.assert_array_equal(result[i], exp_res[i])
|
||||||
|
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None)
|
||||||
|
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': True}
|
||||||
|
self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
|
||||||
|
for i in placeholder_values_ref.keys():
|
||||||
|
npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i])
|
||||||
|
|
||||||
|
def test_get_shapes_and_data_types_partial_shape_with_input_port(self):
|
||||||
|
argv_input = "inp1:1[3 1]->[1.0 2.0 3.0],0:inp2[3.. ..2 5..10 ? -1]{i32},inp3:4[5]{f32}->[1.0 1.0 2.0 3.0 5.0]"
|
||||||
|
result_shapes, result_data_types = get_placeholder_shapes(argv_input, "")
|
||||||
|
ref_result_shapes = {'inp1:1': np.array([3, 1]), '0:inp2': ((3, np.iinfo(np.int64).max), (0, 2), (5, 10), -1, -1), 'inp3:4': np.array([5])}
|
||||||
|
ref_result_data_types = {'0:inp2': np.int32, 'inp3:4': np.float32}
|
||||||
|
self.assertEqual(list(ref_result_shapes.keys()), list(result_shapes.keys()))
|
||||||
|
for i in ref_result_shapes.keys():
|
||||||
|
npt.assert_array_equal(result_shapes[i], ref_result_shapes[i])
|
||||||
|
self.assertEqual(list(ref_result_data_types.keys()), list(result_data_types.keys()))
|
||||||
|
for i in ref_result_data_types.keys():
|
||||||
|
np.testing.assert_equal(result_data_types[i], ref_result_data_types[i])
|
||||||
|
|
||||||
|
def test_get_shapes_and_data_types_partial_shape_with_output_port(self):
|
||||||
|
argv_input = "inp1:1[3 1]->[1.0 2.0 3.0],inp2:3[3.. ..2 5..10 ? -1]{i32},inp3:4[5]{f32}->[1.0 1.0 2.0 3.0 5.0]"
|
||||||
|
result_shapes, result_data_types = get_placeholder_shapes(argv_input, "")
|
||||||
|
ref_result_shapes = {'inp1:1': np.array([3, 1]), 'inp2:3': ((3, np.iinfo(np.int64).max), (0, 2), (5, 10), -1, -1), 'inp3:4': np.array([5])}
|
||||||
|
ref_result_data_types = {'inp2:3': np.int32, 'inp3:4': np.float32}
|
||||||
|
self.assertEqual(list(ref_result_shapes.keys()), list(result_shapes.keys()))
|
||||||
|
for i in ref_result_shapes.keys():
|
||||||
|
npt.assert_array_equal(result_shapes[i], ref_result_shapes[i])
|
||||||
|
self.assertEqual(list(ref_result_data_types.keys()), list(result_data_types.keys()))
|
||||||
|
for i in ref_result_data_types.keys():
|
||||||
|
np.testing.assert_equal(result_data_types[i], ref_result_data_types[i])
|
||||||
|
|
||||||
|
def test_partial_shapes_negative_case(self):
|
||||||
|
argv_input = "inp1"
|
||||||
|
input_shapes = "[6754fg..23ed]"
|
||||||
|
self.assertRaises(Error, get_placeholder_shapes, argv_input, input_shapes)
|
||||||
|
|
||||||
|
def test_partial_shapes_freeze_dynamic_negative_case1(self):
|
||||||
|
argv_input = "inp1:1[3 1..10]->[1.0 2.0 3.0]"
|
||||||
|
self.assertRaises(Error, get_placeholder_shapes, argv_input, "")
|
||||||
|
|
||||||
|
def test_partial_shapes_freeze_dynamic_negative_case2(self):
|
||||||
|
argv_input = "inp1:1[1 2 -1]->[1.0 2.0 3.0]"
|
||||||
|
self.assertRaises(Error, get_placeholder_shapes, argv_input, "")
|
||||||
|
|
||||||
|
def test_partial_shapes_freeze_dynamic_negative_case3(self):
|
||||||
|
# some values for freezing specified using --freeze_placeholder_with_value
|
||||||
|
argv_input = "inp1->[1.0 2.0 3.0]"
|
||||||
|
input_shapes = "[3,1..10]"
|
||||||
|
self.assertRaises(Error, get_placeholder_shapes, argv_input, input_shapes)
|
||||||
|
|
||||||
|
|
||||||
class TestModelNameParsing(unittest.TestCase):
|
class TestModelNameParsing(unittest.TestCase):
|
||||||
def test_model_name_ideal(self):
|
def test_model_name_ideal(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user