Preserve input/output order for ONNX. (#9352)
* Preserve input/output indices for ONNX.
* Fixed checks.
* Fixed the case of multiple outputs of a node before Result.
* Added a test for multiple tensor names before Result.
* Fix for multiple tensor names before Result.
* Added order alignment for user input/output.
* Extended to the case of input names in the Parameter tensor list.
* Fixed unit tests.
* Corrected help.
* Small correction.
* Code refactoring.
* Temporarily reverted refactoring.
* Fixed wrong changes.
* Returned reverted refactoring.
* Removed inputs_list from serializing.

parent 413fee2a86, commit ce533fc287
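The order being preserved is the one declared by the ONNX model itself. As a minimal sketch of where that order comes from (assuming a hypothetical model.onnx, not part of this change):

import onnx

# Hypothetical model path, for illustration only.
model = onnx.load("model.onnx")

# graph.input and graph.output are repeated protobuf fields, so iterating them
# yields names in the order the model declares them. This is the order the
# loader change below records into graph.inputs_order / graph.outputs_order.
input_names = [inp.name for inp in model.graph.input]
output_names = [out.name for out in model.graph.output]
print(input_names)
print(output_names)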
@@ -108,7 +108,8 @@ Framework-agnostic parameters:
  --log_level {CRITICAL,ERROR,WARN,WARNING,INFO,DEBUG,NOTSET}
                        Logger level
  --input INPUT         Quoted list of comma-separated input nodes names with shapes,
                        data types, and values for freezing. The shape and value are
                        data types, and values for freezing. The order of inputs in converted
                        model is the same as order of specified operation names. The shape and value are
                        specified as space-separated lists. The data type of input
                        node is specified in braces and can have one of the values:
                        f64 (float64), f32 (float32), f16 (float16), i64 (int64),
@@ -127,6 +128,8 @@ Framework-agnostic parameters:
                        "0:node_name1[3 4],node_name2:1[2]{i32}->[20 15]".
  --output OUTPUT       The name of the output operation of the model. For
                        TensorFlow*, do not add :0 to this name.
                        The order of outputs in converted model is the same as order of
                        specified operation names.
  --mean_values MEAN_VALUES, -ms MEAN_VALUES
                        Mean values to be used for the input image per
                        channel. Values to be provided in the (R,G,B) or
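Both help strings describe the same guarantee: the converted model keeps its inputs and outputs in the order their names were given on the command line. A minimal sketch of how that can be checked on the generated IR, assuming a hypothetical model.xml and placeholder node names (this mirrors the test added below):

import xml.etree.ElementTree as ET

# Hypothetical IR file and node names, for illustration only.
expected_inputs = ["inp2", "inp1"]
expected_outputs = ["out_b", "out_a"]

tree = ET.parse("model.xml")
# findall() returns elements in document order, which matches the order
# in which Parameter and Result layers were serialized into the IR.
param_names = [layer.attrib["name"] for layer in tree.findall('.//layer[@type="Parameter"]')]
result_names = [layer.attrib["name"].split("/sink_port_")[0]
                for layer in tree.findall('.//layer[@type="Result"]')]

assert param_names == expected_inputs
assert result_names == expected_outputs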
@@ -2,15 +2,16 @@
# SPDX-License-Identifier: Apache-2.0

import itertools
import numpy as np
import os
import re
import warnings
import xml.etree.ElementTree as ET
from openvino.tools.mo.utils.ir_engine.ir_engine import IREngine
from pathlib import Path

import numpy as np
from common.constants import test_device, test_precision
from common.layer_utils import IEInfer
from openvino.tools.mo.utils.ir_engine.ir_engine import IREngine

from common.utils.common_utils import generate_ir
from common.utils.parsers import mapping_parser

@@ -113,6 +114,31 @@ class CommonLayerTest:
                                mapping_dict=mapping_dict, framework_eps=fw_eps), \
            "Comparing with Framework failed: ie_res={}; framework_res={}.".format(infer_res, fw_res)

        if len(inputs_dict.keys()) > 1 or len(infer_res.keys()) > 1:
            tree = ET.parse(path_to_xml)
            # findall returns elements in document order; this order should be the same as
            # the order of inputs/outputs in the original model
            inputs_ie = [child for child in tree.findall('.//layer[@type="Parameter"]')]
            outputs_ie = [child for child in tree.findall('.//layer[@type="Result"]')]

            if 'input_names' in kwargs:
                input_names = kwargs['input_names']
                for i, input_name in enumerate(input_names):
                    assert inputs_ie[i].attrib['name'] == input_name, \
                        'Input order does not match framework order. Input with index {} is {}, ' \
                        'but expected {}'.format(i, inputs_ie[i].attrib['name'], input_name)

            if 'output_names' in kwargs:
                output_names = kwargs['output_names']
                for i, output_name in enumerate(output_names):
                    output_name_ie = outputs_ie[i].attrib['name']
                    output_without_sink_port = re.sub(r'\/sink_port_.', '', output_name_ie)

                    assert output_without_sink_port == output_name, \
                        'Output order does not match framework order. Output with index {} is {}, ' \
                        'but expected {}'.format(i, output_without_sink_port, output_name)

    # Feed dict for each input is filled with random numbers.
    # It is possible to redefine this function and generate your own input.
    def _prepare_input(self, inputs_dict):
@@ -124,10 +150,13 @@ class CommonLayerTest:
        is_ok = True
        from common.utils.common_utils import allclose
        for framework_out_name in framework_res:
            if framework_out_name not in mapping_dict:
                raise RuntimeError("Output {} not found in mapping file!".format(framework_out_name))

            ie_out_name = mapping_dict[framework_out_name]
            if framework_out_name not in list(infer_res.keys()):
                if framework_out_name not in mapping_dict:
                    raise RuntimeError("Output {} not found in mapping file!".format(framework_out_name))
                ie_out_name = mapping_dict[framework_out_name]
            else:
                ie_out_name = framework_out_name

            if not allclose(infer_res[ie_out_name], framework_res[framework_out_name], atol=framework_eps,
                            rtol=framework_eps):
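The flattened hunk above shows both the old lookup (always through mapping_dict) and the new one. The net effect of the new code is roughly the following, written out as a standalone sketch with hypothetical dictionaries:

def resolve_ie_name(framework_out_name, infer_res, mapping_dict):
    # If the IE results are not already keyed by the framework name,
    # the mapping file must provide the translation.
    if framework_out_name not in infer_res:
        if framework_out_name not in mapping_dict:
            raise RuntimeError("Output {} not found in mapping file!".format(framework_out_name))
        return mapping_dict[framework_out_name]
    # Otherwise the IE output already uses the framework name directly.
    return framework_out_name

# Hypothetical data, for illustration only.
infer_res = {"relu_out": 1.0}
mapping_dict = {"softmax_out": "softmax_out/sink_port_0"}
assert resolve_ie_name("relu_out", infer_res, mapping_dict) == "relu_out"
assert resolve_ie_name("softmax_out", infer_res, mapping_dict) == "softmax_out/sink_port_0"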
@@ -124,6 +124,48 @@ class TestConcat(Caffe2OnnxLayerTest):

        return onnx_net, ref_net

    def create_concat_net(self, input_shape, output_shape, axis, input_names, ir_version):
        """
            ONNX net                                  IR net

            Input1----->Concat------>Output    =>    Input1--->Concat------>Output
            Input2-----'                             Input2---'
            Input3-----'                             Input3---'
            ...                                      ...
        """
        #
        #   Create ONNX model
        #

        import onnx
        from onnx import helper
        from onnx import TensorProto
        import numpy as np

        shape = input_shape
        inputs_list = []
        for input_name in input_names:
            inputs_list.append(helper.make_tensor_value_info(input_name, TensorProto.FLOAT, shape))

        output = helper.make_tensor_value_info('output', TensorProto.FLOAT, output_shape)

        node = onnx.helper.make_node('Concat', inputs=input_names, outputs=['output'], axis=axis)

        # Create the graph (GraphProto)
        graph_def = helper.make_graph(
            [node],
            'concat_model',
            inputs_list,
            [output],
        )

        # Create the model (ModelProto)
        onnx_net = helper.make_model(graph_def, producer_name='test_concat_model')

        ref_net = None

        return onnx_net, ref_net

    test_data_3D = [
        dict(input_shape=[1, 50, 50],
             output_shape=[2, 50, 50],
@@ -194,6 +236,21 @@ class TestConcat(Caffe2OnnxLayerTest):
             axis=4),
    ]

    test_concat_inputs_order_params = [
        dict(input_shape=[6],
             output_shape=[30],
             axis=0,
             input_names=['a', 't', 'm', 'p', 'e']),
        dict(input_shape=[5, 2],
             output_shape=[5, 8],
             axis=1,
             input_names=['inp2', 'inp1', 'inp5', 'inp4']),
        dict(input_shape=[6, 2, 5, 3],
             output_shape=[6, 2, 20, 3],
             axis=2,
             input_names=['n', 's', 'c', 'x']),
    ]

    @pytest.mark.parametrize("params", test_data_3D)
    @pytest.mark.nightly
    def test_concat_3D_const(self, params, ie_device, precision, ir_version, temp_dir):
@@ -223,3 +280,9 @@ class TestConcat(Caffe2OnnxLayerTest):
    def test_concat_5D_const(self, params, ie_device, precision, ir_version, temp_dir):
        self._test(*self.create_concat_net_const(**params, ir_version=ir_version), ie_device, precision, ir_version,
                   temp_dir=temp_dir)

    @pytest.mark.parametrize("params", test_concat_inputs_order_params)
    @pytest.mark.nightly
    def test_concat_inputs_order(self, params, ie_device, precision, ir_version, temp_dir):
        self._test(*self.create_concat_net(**params, ir_version=ir_version), ie_device=ie_device, precision=precision,
                   ir_version=ir_version, temp_dir=temp_dir, input_names=params['input_names'])
@@ -30,6 +30,68 @@ test_data_5D = [
                                                        [1, 50, 10, 80, 60]], axis=2),
    dict(input_shape=[1, 50, 50, 80, 60], output_shapes=[[1, 25, 50, 80, 60], [1, 25, 50, 80, 60]], axis=1)]

test_multiple_out = [
    dict(input_shape=[3, 10, 10],
         output_shapes=[[1, 10, 10],
                        [1, 10, 10],
                        [1, 10, 10]],
         axis=0,
         output_names=['h', 'b', 'l']),
    dict(input_shape=[1, 50, 50, 80, 60],
         output_shapes=[[1, 50, 10, 80, 60],
                        [1, 50, 10, 80, 60],
                        [1, 50, 10, 80, 60],
                        [1, 50, 10, 80, 60],
                        [1, 50, 10, 80, 60]],
         axis=2,
         output_names=['k', 'p', 'a', 'r', 's']),
    dict(input_shape=[1, 4, 3],
         output_shapes=[[1, 1, 3],
                        [1, 1, 3],
                        [1, 1, 3],
                        [1, 1, 3],
                        [1, 1, 3]],
         axis=1,
         output_names=['inp4', 'inp1', 'inp3', 'inp2'])
]

test_multiple_out_with_add = [
    dict(input_shape=[3, 10, 10],
         output_shapes=[[1, 10, 10],
                        [1, 10, 10],
                        [1, 10, 10]],
         axis=0,
         output_names=['h', 'b', 'l', 'c', 'p']
         ),
    dict(input_shape=[1, 50, 50, 80, 60],
         output_shapes=[[1, 50, 10, 80, 60],
                        [1, 50, 10, 80, 60],
                        [1, 50, 10, 80, 60],
                        [1, 50, 10, 80, 60],
                        [1, 50, 10, 80, 60]],
         axis=2,
         output_names=['k', 'p', 'a', 'r', 's', 'l', 'w']),
    dict(input_shape=[1, 4, 3],
         output_shapes=[[1, 1, 3],
                        [1, 1, 3],
                        [1, 1, 3],
                        [1, 1, 3],
                        [1, 1, 3]],
         axis=1,
         output_names=['inp4', 'inp1', 'inp5', 'inp2', 'inp3', 'inp33'])
]

test_multiple_out_with_identity = [
    dict(input_shape=[3, 10, 10],
         output_shapes=[[1, 10, 10],
                        [1, 10, 10],
                        [1, 10, 10]],
         axis=0,
         split_out_names=['h', 'b', 'l'],
         identity_names=['i1', 'i2', 'i3'],
         output_names=['h', 'b', 'l', 'i3'],
         ),
]

class TestSplitConcat(Caffe2OnnxLayerTest):
    # TODO Add test with default values (axis=0)
@@ -288,6 +350,169 @@ class TestSplit(Caffe2OnnxLayerTest):

        return onnx_net, ref_net

    def create_split_net_ordered_outputs(self, input_shape, output_shapes, axis, output_names, ir_version):
        """
            ONNX net                              IR net

            Input->Split->Output1       =>       Input->Split->Output1
                        ->Output2       =>                   ->Output2
                        ->Output3       =>                   ->Output3

        """
        #
        #   Create ONNX model
        #
        import onnx
        from onnx import helper
        from onnx import TensorProto

        shape = input_shape

        input = helper.make_tensor_value_info('input', TensorProto.FLOAT, shape)

        output_list = []
        for i, output_name in enumerate(output_names):
            output_list.append(helper.make_tensor_value_info(output_name, TensorProto.FLOAT, output_shapes[i]))

        node = onnx.helper.make_node('Split', inputs=['input'], outputs=output_names, axis=axis)

        # Create the graph (GraphProto)
        graph_def = helper.make_graph(
            [node],
            'split_model',
            [input],
            output_list,
        )

        # Create the model (ModelProto)
        onnx_net = helper.make_model(graph_def, producer_name='test_split_model_outputs_order')

        ref_net = None

        return onnx_net, ref_net

    def create_split_net_ordered_outputs_with_add(self, input_shape, output_shapes, axis, output_names, ir_version):
        """
            This test checks the case when the graph has a node that is connected to a Result and to some other
            operation from a single output port.

               ONNX net                                 IR net

                 Input                                   Input
                   |                                       |
                 Split                                   Split
             |     |  ...  |                         |     |  ....   |
           Output1 Output2 OutputN                   |     |      Result_N
              \    /      /                           \   / \
               Add       /                             Add   \
                     Result_0                           |   Result_1
                                                    Result_N+1

        """
        #
        #   Create ONNX model
        #
        import onnx
        from onnx import helper
        from onnx import TensorProto

        shape = input_shape

        input = helper.make_tensor_value_info('input', TensorProto.FLOAT, shape)

        add_output_name1 = output_names[len(output_names) - 2]
        add_output_name2 = output_names[len(output_names) - 1]
        outputs_without_add = output_names[:len(output_names) - 2]

        output_list = []
        for i, output_name in enumerate(outputs_without_add):
            output_list.append(helper.make_tensor_value_info(output_name, TensorProto.FLOAT, output_shapes[i]))

        node = onnx.helper.make_node('Split', inputs=['input'], outputs=outputs_without_add, axis=axis)
        node_add1 = helper.make_node(
            'Add',
            inputs=[outputs_without_add[1], outputs_without_add[2]],
            outputs=[add_output_name1]
        )
        node_add2 = helper.make_node(
            'Add',
            inputs=[add_output_name1, outputs_without_add[2]],
            outputs=[add_output_name2]
        )

        output_list = output_list \
            + [helper.make_tensor_value_info(add_output_name1, TensorProto.FLOAT, output_shapes[0])] \
            + [helper.make_tensor_value_info(add_output_name2, TensorProto.FLOAT, output_shapes[0])]

        # Create the graph (GraphProto)
        graph_def = helper.make_graph(
            [node, node_add1, node_add2],
            'split_model',
            [input],
            output_list,
        )

        # Create the model (ModelProto)
        onnx_net = helper.make_model(graph_def, producer_name='test_split_model_outputs_order')

        ref_net = None

        return onnx_net, ref_net

    def create_split_net_ordered_outputs_multiple_tensor_names(self, input_shape, output_shapes, axis,
                                                               split_out_names, identity_names, output_names,
                                                               ir_version):
        """
            This test checks the case of multiple tensor names on the connection incoming to a Result. In this case
            the Result name is equal to one of the tensor names from the list.

            ONNX net

            Input->Split->Identity1->Identity2->Identity3 -> Output1
                        ->Output2
                        ->Output3

            IR net

            Input->Split->Result1 - this connection has tensor names from the Split, Identity1, Identity2, Identity3 ops
                        ->Result2
                        ->Result3

        """
        #
        #   Create ONNX model
        #
        import onnx
        from onnx import helper
        from onnx import TensorProto

        shape = input_shape

        input = helper.make_tensor_value_info('input', TensorProto.FLOAT, shape)

        output_list = []
        for i, output_name in enumerate(split_out_names):
            output_list.append(helper.make_tensor_value_info(output_name, TensorProto.FLOAT, output_shapes[i]))
        output_list.append(helper.make_tensor_value_info(identity_names[2], TensorProto.FLOAT, output_shapes[i]))

        node = onnx.helper.make_node('Split', inputs=['input'], outputs=split_out_names, axis=axis)
        identity1 = onnx.helper.make_node('Identity', inputs=[split_out_names[0]], outputs=[identity_names[0]])
        identity2 = onnx.helper.make_node('Identity', inputs=[identity_names[0]], outputs=[identity_names[1]])
        identity3 = onnx.helper.make_node('Identity', inputs=[identity_names[1]], outputs=[identity_names[2]])

        # Create the graph (GraphProto)
        graph_def = helper.make_graph(
            [node, identity1, identity2, identity3],
            'split_model',
            [input],
            output_list,
        )

        # Create the model (ModelProto)
        onnx_net = helper.make_model(graph_def, producer_name='test_split_model_outputs_order')

        ref_net = None

        return onnx_net, ref_net

    @pytest.mark.parametrize("params", test_data_3D)
    @pytest.mark.nightly
    def test_split_3D(self, params, ie_device, precision, ir_version, temp_dir):
@@ -305,3 +530,28 @@ class TestSplit(Caffe2OnnxLayerTest):
    def test_split_5D(self, params, ie_device, precision, ir_version, temp_dir):
        self._test(*self.create_split_net(**params, ir_version=ir_version), ie_device, precision, ir_version,
                   temp_dir=temp_dir)

    @pytest.mark.parametrize("params", test_multiple_out)
    def test_split_outputs_order(self, params, ie_device, precision, ir_version, temp_dir):
        self._test(*self.create_split_net_ordered_outputs(**params, ir_version=ir_version), ie_device, precision,
                   ir_version, temp_dir=temp_dir, output_names=params['output_names'])

    @pytest.mark.parametrize("params", test_multiple_out_with_add)
    def test_split_outputs_order_multiple_connection_before_result_case(self,
                                                                        params,
                                                                        ie_device,
                                                                        precision,
                                                                        ir_version,
                                                                        temp_dir):
        self._test(*self.create_split_net_ordered_outputs_with_add(**params, ir_version=ir_version), ie_device,
                   precision, ir_version, temp_dir=temp_dir, output_names=params['output_names'])

    @pytest.mark.parametrize("params", test_multiple_out_with_identity)
    def test_split_outputs_order_multiple_tensors_before_result_case(self,
                                                                     params,
                                                                     ie_device,
                                                                     precision,
                                                                     ir_version,
                                                                     temp_dir):
        self._test(*self.create_split_net_ordered_outputs_multiple_tensor_names(**params, ir_version=ir_version),
                   ie_device, precision, ir_version, temp_dir=temp_dir, output_names=params['output_names'])
@@ -415,29 +415,144 @@ def add_meta_data(net: Element, meta_info: dict):
    meta = SubElement(net, 'meta_data')
    SubElement(meta, 'MO_version').set('value', get_version())
    parameters = SubElement(meta, 'cli_parameters')
    if 'inputs_list' in meta_info:
        del meta_info['inputs_list']
    [SubElement(parameters, str(key)).set('value', str(meta_info[key])) for key in sorted(meta_info.keys()) if
     key not in ('unset', 'quantization_parameters')]
    if 'unset' in meta_info:
        SubElement(parameters, 'unset').set('unset_cli_parameters', ', '.join(sorted(meta_info['unset'])))


def serialize_node(graph: Graph, node: Node, layers: SubElement, edges: SubElement, unsupported: UnsupportedOps):
    if node.kind == 'op' and (not node.has('type') or node.type is None):
        unsupported.add(node)
        return
    if not node.has('IE'):
        return
    try:
        serialize_node_attributes(graph, node, node.IE, layers, edges, unsupported)
    except Error as e:
        raise Error(str(e).replace('<SUB-ELEMENT>', '{} (id = {})'.format(node.soft_get('name'), node.id))) from e


def get_tensor_names_of_result_node(graph):
    result_nodes = graph.get_op_nodes(type='Result')
    result_names_to_tensor_names = {}
    for res_node in result_nodes:

        # After port renumbering port/connection API is not applicable
        assert len(res_node.in_nodes()) > 0, \
            "Result node with name {} has no input node.".format(res_node.soft_get('name'))
        res_data_node = res_node.in_node(0)
        assert len(res_data_node.in_nodes()) > 0, \
            "Data node of Result with name {} has no input node.".format(res_node.soft_get('name'))
        res_in_node = res_data_node.in_node(0)

        # We cannot use out_ports() after port renumbering
        for v, d in res_in_node.get_sorted_outputs():
            port_id = d['out'] - len(res_in_node.in_nodes()) if res_in_node.type != 'Const' else d['out']
            tensor_names = res_in_node.out_port(port_id).get_tensor_names(port_renumber=True)
            result_names_to_tensor_names[res_node.soft_get('name')] = tensor_names
    return result_names_to_tensor_names


def find_result_node_by_name(output_name, result_nodes, result_names_to_tensor_names):
    for res_node in result_nodes:
        res_name = res_node.soft_get('name')
        tensor_names = result_names_to_tensor_names[res_name]
        if output_name in tensor_names:
            # In this case output tensor name is in tensor names list of previous op
            return res_name

    return None


def serialize_network(graph, net_element, unsupported):
    layers = SubElement(net_element, 'layers')
    edges = SubElement(net_element, 'edges')
    if graph is None:
        return
    nodes = sorted(graph.nodes())

    result_nodes = graph.get_op_nodes(type='Result')
    result_names_to_tensor_names = get_tensor_names_of_result_node(graph)

    ordered_results = []
    for output_name in graph.outputs_order:
        node = graph.get_op_nodes(name=output_name)

        if len(node) == 0:
            # As graph does not contain node with name=output_name
            # in the following code we look for output_name among tensor names
            # incoming to Result nodes
            found_result_name = find_result_node_by_name(output_name, result_nodes, result_names_to_tensor_names)

            if found_result_name is not None:
                ordered_results.append(found_result_name)
            else:
                log.warning("Output node with name {} is not found in graph.".format(output_name))
            continue
        node = node[0]

        # In this case Result node has the same name as output tensor
        if node.soft_get('type') == 'Result':
            ordered_results.append(node.soft_get('name'))
            continue

        # Here output data node count is checked. Each port cannot have more than one data node.
        assert len(node.out_nodes()) == 1, "Incorrect graph. Non-Result node with name {} " \
                                           "has no output data node.".format(output_name)

        # After port renumbering port/connection API is not applicable, and output port numbering
        # starts from len(node.in_nodes()).
        data_node = node.out_node(len(node.in_nodes()))

        found_result = False
        for op_node in data_node.out_nodes():
            if op_node.soft_get('type') == 'Result':
                found_result = True
                ordered_results.append(op_node.soft_get('name'))
                break

        if not found_result:
            log.warning("Node that expected to be output with name {} is not connected with Result node.".format(output_name))

    param_nodes = graph.get_op_nodes(type='Parameter')
    serialized_inputs = []
    for input_name in graph.inputs_order:
        node = graph.get_op_nodes(name=input_name)
        if len(node) != 0:
            serialize_node(graph, node[0], layers, edges, unsupported)
            serialized_inputs.append(input_name)
            continue
        found_tensor_name = False
        for param_node in param_nodes:
            param_name = param_node.soft_get('name')
            if not param_node.is_out_port_connected(0):
                continue
            tensor_names = param_node.out_port(0).get_tensor_names(port_renumber=True)
            if input_name in tensor_names:
                # In this case input name is in tensor names list of Parameter op
                serialize_node(graph, param_node, layers, edges, unsupported)
                serialized_inputs.append(param_name)
                found_tensor_name = True
                break

        if not found_tensor_name:
            log.warning("Input node with name {} is not found in graph.".format(param_name))

    for node in nodes:
        node = Node(graph, node)
        if node.kind == 'op' and (not node.has('type') or node.type is None):
            unsupported.add(node)
        if node.soft_get('name') in serialized_inputs:
            continue
        if not node.has('IE'):
        if node.soft_get('name') in ordered_results:
            continue
        try:
            serialize_node_attributes(graph, node, node.IE, layers, edges, unsupported)
        except Error as e:
            raise Error(str(e).replace('<SUB-ELEMENT>', '{} (id = {})'.format(node.soft_get('name'), node.id))) from e
        serialize_node(graph, node, layers, edges, unsupported)

    for output_name in ordered_results:
        node = graph.get_op_nodes(name=output_name)
        assert len(node) == 1, "Output node with name {} is not found in graph.".format(output_name)
        serialize_node(graph, node[0], layers, edges, unsupported)


def generate_ie_ir(graph: Graph, file_name: str, input_names: tuple = (), mean_offset: tuple = (),
@@ -58,6 +58,12 @@ def protobuf2nx(graph: Graph, pb):
    graph_pb = pb.graph
    add_initializers_and_inputs_to_graph(graph, graph_pb, data_nodes_map)

    # Preserve inputs order
    graph.inputs_order = []
    for inp in graph_pb.input:
        name = str(inp.name)
        graph.inputs_order.append(name)

    output_ids = []
    for outp in graph_pb.output:
        name = str(outp.name)
@@ -70,6 +76,9 @@ def protobuf2nx(graph: Graph, pb):
            graph.add_node(name, kind='op', op='FakeOutput', pb=outp)
            output_ids.append(name)

    # Preserve outputs order
    graph.outputs_order = output_ids

    # Go through all nodes in the original model order (because data nodes are defined on-the-fly and order is
    # important)
    for node in graph_pb.node:
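The loop above relies on graph_pb.input / graph_pb.output being repeated protobuf fields whose iteration order is the declaration order in the model. A small self-contained sketch of that property, using the same onnx.helper calls as the tests in this change (the names are illustrative only):

from onnx import helper, TensorProto

# Two inputs declared in a deliberate, non-alphabetical order.
inp_b = helper.make_tensor_value_info('inp_b', TensorProto.FLOAT, [1, 3])
inp_a = helper.make_tensor_value_info('inp_a', TensorProto.FLOAT, [1, 3])
out = helper.make_tensor_value_info('sum', TensorProto.FLOAT, [1, 3])

add = helper.make_node('Add', inputs=['inp_b', 'inp_a'], outputs=['sum'])
graph = helper.make_graph([add], 'order_demo', [inp_b, inp_a], [out])
model = helper.make_model(graph, producer_name='order_demo')

# Declaration order is preserved, so this prints ['inp_b', 'inp_a'].
print([i.name for i in model.graph.input])
print([o.name for o in model.graph.output])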
@@ -32,6 +32,11 @@ class UserDataRepack(FrontReplacementPattern):
        graph.graph['packed_outputs'] = packed_outputs
        graph.graph['freeze_placeholder'] = freeze_placeholder

        if argv.inputs_list is not None and isinstance(argv.inputs_list, list) and len(argv.inputs_list) > 0:
            graph.inputs_order = argv.inputs_list
        if argv.output is not None and isinstance(argv.output, list) and len(argv.output) > 0:
            graph.outputs_order = argv.output

        inputs = list(packed_user_shapes.keys()) \
            if packed_user_shapes is not None and isinstance(packed_user_shapes, dict) else None
        graph.graph['inputs'] = inputs  # save user defined inputs for other extensions
@@ -566,6 +566,8 @@ class Graph(nx.MultiDiGraph):

    unique_id_count = 0
    op_names_statistic = collections.Counter()
    inputs_order = []
    outputs_order = []

    # SAFE API DESCRIPTION
    # all provided methods below are designed to be more safe and convenient
@@ -272,8 +272,9 @@ def arguments_post_parsing(argv: argparse.Namespace):

    argv.output = argv.output.split(',') if argv.output else None

    argv.placeholder_shapes, argv.placeholder_data_types = get_placeholder_shapes(argv.input, argv.input_shape,
                                                                                  argv.batch)
    inputs_list, argv.placeholder_shapes, argv.placeholder_data_types = get_placeholder_shapes(
        argv.input, argv.input_shape, argv.batch)
    argv.inputs_list = inputs_list

    mean_values = parse_tuple_pairs(argv.mean_values)
    scale_values = parse_tuple_pairs(argv.scale_values)
@@ -291,7 +291,8 @@ def get_common_cli_parser(parser: argparse.ArgumentParser = None):
                              default='ERROR')
    common_group.add_argument('--input',
                              help='Quoted list of comma-separated input nodes names with shapes, data types, '
                                   'and values for freezing. The shape and value are specified as space-separated '
                                   'and values for freezing. The order of inputs in converted model is the same as '
                                   'order of specified operation names. The shape and value are specified as space-separated '
                                   'lists. The data type of input node is specified in braces and '
                                   'can have one of the values: f64 (float64), f32 (float32), f16 (float16), '
                                   'i64 (int64), i32 (int32), u8 (uint8), boolean (bool). Data type is optional. '
@@ -308,7 +309,9 @@ def get_common_cli_parser(parser: argparse.ArgumentParser = None):
                                   '"0:node_name1[3 4],node_name2:1[2]{i32}->[20 15]".')
    common_group.add_argument('--output',
                              help='The name of the output operation of the model. ' +
                                   'For TensorFlow*, do not add :0 to this name.')
                                   'For TensorFlow*, do not add :0 to this name.'
                                   'The order of outputs in converted model is the same as order of '
                                   'specified operation names.')
    common_group.add_argument('--mean_values', '-ms',
                              help='Mean values to be used for the input image per channel. ' +
                                   'Values to be provided in the (R,G,B) or [R,G,B] format. ' +
@@ -1130,10 +1133,12 @@ def get_placeholder_shapes(argv_input: str, argv_input_shape: str, argv_batch=No
    placeholder_shapes = dict()
    placeholder_data_types = dict()
    are_shapes_specified_through_input = False
    inputs_list = list()
    if argv_input:
        for input_value in argv_input.split(','):
            node_name, shape, _, data_type = parse_input_value(input_value)
            placeholder_shapes[node_name] = shape
            inputs_list.append(node_name)
            if data_type is not None:
                placeholder_data_types[node_name] = data_type
            if shape is not None:
@@ -1148,10 +1153,11 @@ def get_placeholder_shapes(argv_input: str, argv_input_shape: str, argv_batch=No
                            "parameter is allowed.")

    if are_shapes_specified_through_input:
        return placeholder_shapes, placeholder_data_types
        return inputs_list, placeholder_shapes, placeholder_data_types

    shapes = list()
    inputs = list()
    inputs_list = list()
    placeholder_shapes = None

    range_reg = r'([0-9]*\.\.[0-9]*)'
@@ -1184,8 +1190,10 @@ def get_placeholder_shapes(argv_input: str, argv_input_shape: str, argv_batch=No
                                 shapes)))
    for inp in inputs:
        if '->' not in inp:
            inputs_list.append(inp)
            continue
        shape = placeholder_shapes[inp.split('->')[0]]
        inputs_list.append(inp.split('->')[0])

        if shape is None:
            continue
@@ -1196,7 +1204,7 @@ def get_placeholder_shapes(argv_input: str, argv_input_shape: str, argv_batch=No
    elif argv_input:
        raise Error('Please provide each input layers with an input layer shape. ' + refer_to_faq_msg(58))

    return placeholder_shapes, placeholder_data_types
    return inputs_list, placeholder_shapes, placeholder_data_types


def parse_tuple_pairs(argv_values: str):
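With this change get_placeholder_shapes returns a 3-tuple, and the first element preserves the order in which inputs were named on the command line. A short usage sketch based on the updated signature shown above (the argument strings are illustrative only, and the import path assumes the openvino.tools.mo layout used by the tests below):

from openvino.tools.mo.utils.cli_parser import get_placeholder_shapes

# Inputs named in a deliberate order; shapes given via the --input syntax.
inputs_list, shapes, data_types = get_placeholder_shapes("inp2[1 3],inp1[2 2]{i32}", None)

# inputs_list keeps the command-line order, which later becomes graph.inputs_order.
assert inputs_list == ["inp2", "inp1"]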
@@ -448,16 +448,17 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
    def test_get_shapes_several_inputs_several_shapes(self):
        argv_input = "inp1,inp2"
        input_shapes = "(1,22,333,123), (-1,45,7,1)"
        result, _ = get_placeholder_shapes(argv_input, input_shapes)
        inputs_list, result, _ = get_placeholder_shapes(argv_input, input_shapes)
        exp_res = {'inp1': np.array([1, 22, 333, 123]), 'inp2': np.array([-1, 45, 7, 1])}
        self.assertEqual(list(exp_res.keys()), list(result.keys()))
        self.assertEqual(inputs_list, ["inp1","inp2"])
        for i in exp_res.keys():
            assert np.array_equal(result[i], exp_res[i])

    def test_get_shapes_several_inputs_several_shapes2(self):
        # shapes specified using --input command line parameter and no values
        argv_input = "inp1[1 22 333 123],inp2[-1 45 7 1]"
        result, _ = get_placeholder_shapes(argv_input, None)
        inputs_list, result, _ = get_placeholder_shapes(argv_input, None)
        exp_res = {'inp1': np.array([1, 22, 333, 123]), 'inp2': np.array([-1, 45, 7, 1])}
        self.assertEqual(list(exp_res.keys()), list(result.keys()))
        for i in exp_res.keys():
@@ -466,15 +467,17 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
        placeholder_values_ref = {}
        input_node_names_ref = "inp1,inp2"
        self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
        self.assertEqual(inputs_list, ["inp1","inp2"])
        for i in placeholder_values_ref.keys():
            assert np.array_equal(placeholder_values_res[i], placeholder_values_ref[i])

    def test_get_shapes_and_freezing_with_scalar_and_without_shapes_in_input(self):
        # shapes and value for freezing specified using --input command line parameter
        argv_input = "inp1,inp2->157"
        result_shapes, _ = get_placeholder_shapes(argv_input, None)
        input_list, result_shapes, _ = get_placeholder_shapes(argv_input, None)
        ref_shapes = {'inp1': None, 'inp2': None}
        self.assertEqual(list(ref_shapes.keys()), list(result_shapes.keys()))
        self.assertEqual(input_list, ["inp1","inp2"])
        for i in ref_shapes.keys():
            assert np.array_equal(result_shapes[i], ref_shapes[i])

@@ -488,11 +491,12 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
    def test_get_shapes_and_freezing_with_scalar(self):
        # shapes and value for freezing specified using --input command line parameter
        argv_input = "inp1,inp2[]->157"
        result_shapes, _ = get_placeholder_shapes(argv_input, None)
        input_list, result_shapes, _ = get_placeholder_shapes(argv_input, None)
        ref_shapes = {'inp1': None, 'inp2': ()}
        self.assertEqual(list(ref_shapes.keys()), list(result_shapes.keys()))
        for i in ref_shapes.keys():
            assert np.array_equal(result_shapes[i], ref_shapes[i])
        self.assertEqual(input_list, ["inp1","inp2"])

        placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None)
        placeholder_values_ref = {'inp2': 157}
@@ -504,7 +508,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
    def test_get_shapes_several_inputs_several_shapes3(self):
        # shapes and value for freezing specified using --input command line parameter
        argv_input = "inp1[3 1]->[1.0 2.0 3.0],inp2[3 2 3],inp3[5]->[1.0 1.0 2.0 3.0 5.0]"
        result, _ = get_placeholder_shapes(argv_input, None)
        input_list, result, _ = get_placeholder_shapes(argv_input, None)
        exp_res = {'inp1': np.array([3, 1]), 'inp2': np.array([3, 2, 3]), 'inp3': np.array([5])}
        self.assertEqual(list(exp_res.keys()), list(result.keys()))
        for i in exp_res.keys():
@@ -514,6 +518,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
                                  'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'])}
        input_node_names_ref = "inp1,inp2,inp3"
        self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
        self.assertEqual(input_list, ["inp1","inp2","inp3"])
        for i in placeholder_values_ref.keys():
            assert np.array_equal(placeholder_values_res[i], placeholder_values_ref[i])

@@ -521,7 +526,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
        # shapes specified using --input_shape and values for freezing using --input command line parameter
        argv_input = "inp1->[1.0 2.0 3.0],inp2,inp3->[1.0 1.0 2.0 3.0 5.0]"
        input_shapes = "(3,1), (3,2,3), (5)"
        result, _ = get_placeholder_shapes(argv_input, input_shapes)
        inputs_list, result, _ = get_placeholder_shapes(argv_input, input_shapes)
        exp_res = {'inp1': np.array([3, 1]), 'inp2': np.array([3, 2, 3]), 'inp3': np.array([5])}
        self.assertEqual(list(exp_res.keys()), list(result.keys()))
        for i in exp_res.keys():
@@ -531,6 +536,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
                                  'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'])}
        input_node_names_ref = "inp1,inp2,inp3"
        self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
        self.assertEqual(inputs_list, ["inp1","inp2","inp3"])
        for i in placeholder_values_ref.keys():
            assert np.array_equal(placeholder_values_res[i], placeholder_values_ref[i])
        self.assertEqual(input_node_names_ref, input_node_names_res)
@@ -541,9 +547,10 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
        input_shapes = "(3,1), (3,2,3), (5)"
        argv_freeze_placeholder_with_value = "inp2->[5.0 7.0 3.0],inp4->[100.0 200.0]"

        result, _ = get_placeholder_shapes(argv_input, input_shapes)
        inputs_list, result, _ = get_placeholder_shapes(argv_input, input_shapes)
        exp_res = {'inp1': np.array([3, 1]), 'inp2': np.array([3, 2, 3]), 'inp3': np.array([5])}
        self.assertEqual(list(exp_res.keys()), list(result.keys()))
        self.assertEqual(inputs_list, ["inp1","inp2","inp3"])
        for i in exp_res.keys():
            assert np.array_equal(result[i], exp_res[i])
        placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input,
@@ -560,9 +567,10 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
    def test_get_shapes_several_inputs_several_shapes6(self):
        # 0D value for freezing specified using --input command line parameter without shape
        argv_input = "inp1[3 1]->[1.0 2.0 3.0],inp2[3 2 3],inp3->False"
        result, _ = get_placeholder_shapes(argv_input, None)
        inputs_list, result, _ = get_placeholder_shapes(argv_input, None)
        exp_res = {'inp1': np.array([3, 1]), 'inp2': np.array([3, 2, 3]), 'inp3': None}
        self.assertEqual(list(exp_res.keys()), list(result.keys()))
        self.assertEqual(inputs_list, ["inp1","inp2","inp3"])
        for i in exp_res.keys():
            assert np.array_equal(result[i], exp_res[i])
        placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None)
@@ -574,9 +582,10 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
    def test_get_shapes_several_inputs_several_shapes7(self):
        # 0D shape and value for freezing specified using --input command line parameter
        argv_input = "inp1[3 1]->[1.0 2.0 3.0],inp2[3 2 3],inp3[]->True"
        result, _ = get_placeholder_shapes(argv_input, None)
        inputs_list, result, _ = get_placeholder_shapes(argv_input, None)
        exp_res = {'inp1': np.array([3, 1]), 'inp2': np.array([3, 2, 3]), 'inp3': np.array(False).shape}
        self.assertEqual(list(exp_res.keys()), list(result.keys()))
        self.assertEqual(inputs_list, ["inp1","inp2","inp3"])
        for i in exp_res.keys():
            assert np.array_equal(result[i], exp_res[i])
        placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None)
@@ -587,43 +596,46 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):

    def test_get_shapes_and_data_types1(self):
        argv_input = "inp1[3 1]->[1.0 2.0 3.0],inp2[3 2 3]{i32},inp3[5]{f32}->[1.0 1.0 2.0 3.0 5.0]"
        result_shapes, result_data_types = get_placeholder_shapes(argv_input, "")
        input_list, result_shapes, result_data_types = get_placeholder_shapes(argv_input, "")
        ref_result_shapes = {'inp1': np.array([3, 1]), 'inp2': np.array([3, 2, 3]), 'inp3': np.array([5])}
        ref_result_data_types = {'inp2': np.int32, 'inp3': np.float32}
        self.assertEqual(list(ref_result_shapes.keys()), list(result_shapes.keys()))
        for i in ref_result_shapes.keys():
            assert np.array_equal(result_shapes[i], ref_result_shapes[i])
        self.assertEqual(list(ref_result_data_types.keys()), list(result_data_types.keys()))
        self.assertEqual(input_list, ["inp1","inp2","inp3"])
        for i in ref_result_data_types.keys():
            np.testing.assert_equal(result_data_types[i], ref_result_data_types[i])

    def test_get_shapes_and_data_types_with_input_ports(self):
        argv_input = "1:inp1[3 1]->[1.0 2.0 3.0],inp2[3 2 3]{i32},0:inp3[5]{f32}->[1.0 1.0 2.0 3.0 5.0]"
        result_shapes, result_data_types = get_placeholder_shapes(argv_input, "")
        input_list, result_shapes, result_data_types = get_placeholder_shapes(argv_input, "")
        ref_result_shapes = {'1:inp1': np.array([3, 1]), 'inp2': np.array([3, 2, 3]), '0:inp3': np.array([5])}
        ref_result_data_types = {'inp2': np.int32, '0:inp3': np.float32}
        self.assertEqual(list(ref_result_shapes.keys()), list(result_shapes.keys()))
        for i in ref_result_shapes.keys():
            assert np.array_equal(result_shapes[i], ref_result_shapes[i])
        self.assertEqual(list(ref_result_data_types.keys()), list(result_data_types.keys()))
        self.assertEqual(input_list, ["1:inp1","inp2","0:inp3"])
        for i in ref_result_data_types.keys():
            np.testing.assert_equal(result_data_types[i], ref_result_data_types[i])

    def test_get_shapes_and_data_types_with_output_ports(self):
        argv_input = "inp1:1[3 1]->[1.0 2.0 3.0],inp2[3 2 3]{i32},inp3:4[5]{f32}->[1.0 1.0 2.0 3.0 5.0]"
        result_shapes, result_data_types = get_placeholder_shapes(argv_input, "")
        input_list, result_shapes, result_data_types = get_placeholder_shapes(argv_input, "")
        ref_result_shapes = {'inp1:1': np.array([3, 1]), 'inp2': np.array([3, 2, 3]), 'inp3:4': np.array([5])}
        ref_result_data_types = {'inp2': np.int32, 'inp3:4': np.float32}
        self.assertEqual(list(ref_result_shapes.keys()), list(result_shapes.keys()))
        for i in ref_result_shapes.keys():
            assert np.array_equal(result_shapes[i], ref_result_shapes[i])
        self.assertEqual(list(ref_result_data_types.keys()), list(result_data_types.keys()))
        self.assertEqual(input_list, ["inp1:1","inp2","inp3:4"])
        for i in ref_result_data_types.keys():
            np.testing.assert_equal(result_data_types[i], ref_result_data_types[i])

    def test_get_shapes_and_data_types_shape_only(self):
        argv_input = "placeholder1[3 1],placeholder2,placeholder3"
        result_shapes, result_data_types = get_placeholder_shapes(argv_input, "")
        input_list, result_shapes, result_data_types = get_placeholder_shapes(argv_input, "")
        ref_result_shapes = {'placeholder1': np.array([3, 1]), 'placeholder2': None,
                             'placeholder3': None}
        ref_result_data_types = {}
@@ -631,12 +643,13 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
        for i in ref_result_shapes.keys():
            assert np.array_equal(result_shapes[i], ref_result_shapes[i])
        self.assertEqual(list(ref_result_data_types.keys()), list(result_data_types.keys()))
        self.assertEqual(input_list, ["placeholder1","placeholder2","placeholder3"])
        for i in ref_result_data_types.keys():
            np.testing.assert_equal(result_data_types[i], ref_result_data_types[i])

    def test_get_shapes_and_data_types_shape_with_ports_only(self):
        argv_input = "placeholder1:4[3 1],placeholder2,2:placeholder3"
        result_shapes, result_data_types = get_placeholder_shapes(argv_input, "")
        input_list, result_shapes, result_data_types = get_placeholder_shapes(argv_input, "")
        ref_result_shapes = {'placeholder1:4': np.array([3, 1]), 'placeholder2': None,
                             '2:placeholder3': None}
        ref_result_data_types = {}
@@ -644,12 +657,13 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
        for i in ref_result_shapes.keys():
            assert np.array_equal(result_shapes[i], ref_result_shapes[i])
        self.assertEqual(list(ref_result_data_types.keys()), list(result_data_types.keys()))
        self.assertEqual(input_list, ["placeholder1:4","placeholder2","2:placeholder3"])
        for i in ref_result_data_types.keys():
            np.testing.assert_equal(result_data_types[i], ref_result_data_types[i])

    def test_get_shapes_and_data_types_when_no_freeze_value(self):
        argv_input = "placeholder1{i32}[3 1],placeholder2,placeholder3{i32}"
        result_shapes, result_data_types = get_placeholder_shapes(argv_input, "")
        input_list, result_shapes, result_data_types = get_placeholder_shapes(argv_input, "")
        ref_result_shapes = {'placeholder1': np.array([3, 1]), 'placeholder2': None,
                             'placeholder3': None}
        ref_result_data_types = {'placeholder1': np.int32, 'placeholder3': np.int32}
@@ -657,6 +671,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
        for i in ref_result_shapes.keys():
            assert np.array_equal(result_shapes[i], ref_result_shapes[i])
        self.assertEqual(list(ref_result_data_types.keys()), list(result_data_types.keys()))
        self.assertEqual(input_list, ["placeholder1","placeholder2","placeholder3"])
        for i in ref_result_data_types.keys():
            np.testing.assert_equal(result_data_types[i], ref_result_data_types[i])

@@ -700,30 +715,31 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
    def test_get_shapes_one_input_one_shape(self):
        argv_input = "inp1"
        input_shapes = "(1,22,333,123)"
        result, _ = get_placeholder_shapes(argv_input, input_shapes)
        inputs_list, result, _ = get_placeholder_shapes(argv_input, input_shapes)
        exp_res = {'inp1': np.array([1, 22, 333, 123])}
        self.assertEqual(list(exp_res.keys()), list(result.keys()))
        self.assertEqual(inputs_list, ["inp1"])
        for i in exp_res.keys():
            assert np.array_equal(result[i], exp_res[i])

    def test_get_shapes_no_input_no_shape(self):
        argv_input = ""
        input_shapes = ""
        result, _ = get_placeholder_shapes(argv_input, input_shapes)
        _, result, _ = get_placeholder_shapes(argv_input, input_shapes)
        exp_res = None
        assert np.array_equal(result, exp_res)

    def test_get_shapes_no_input_one_shape(self):
        argv_input = ""
        input_shapes = "(12,4,1)"
        result, _ = get_placeholder_shapes(argv_input, input_shapes)
        _, result, _ = get_placeholder_shapes(argv_input, input_shapes)
        exp_res = np.array([12, 4, 1])
        assert np.array_equal(result, exp_res)

    def test_get_shapes_no_input_one_shape2(self):
        argv_input = ""
        input_shapes = "[12,4,1]"
        result, _ = get_placeholder_shapes(argv_input, input_shapes)
        _, result, _ = get_placeholder_shapes(argv_input, input_shapes)
        exp_res = np.array([12, 4, 1])
        assert np.array_equal(result, exp_res)

@@ -735,9 +751,10 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
    def test_get_shapes_one_input_no_shape(self):
        argv_input = "inp1"
        input_shapes = ""
        result, _ = get_placeholder_shapes(argv_input, input_shapes)
        input_list, result, _ = get_placeholder_shapes(argv_input, input_shapes)
        exp_res = {'inp1': None}
        self.assertEqual(list(exp_res.keys()), list(result.keys()))
        self.assertEqual(input_list, ["inp1"])
        for i in exp_res.keys():
            assert np.array_equal(result[i], exp_res[i])

@@ -794,9 +811,10 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
    def test_get_shapes_one_input_first_neg_shape1(self):
        argv_input = "inp1,inp2"
        input_shapes = "(-1,4,1),(4,6,8)"
        result, _ = get_placeholder_shapes(argv_input, input_shapes)
        inputs_list, result, _ = get_placeholder_shapes(argv_input, input_shapes)
        exp_res = {'inp1': np.array([-1, 4, 1]), 'inp2': np.array([4, 6, 8])}
        self.assertEqual(list(exp_res.keys()), list(result.keys()))
        self.assertEqual(inputs_list, ["inp1","inp2"])
        for i in exp_res.keys():
            assert np.array_equal(result[i], exp_res[i])

@@ -819,16 +837,17 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
    def test_get_shapes_several_inputs_several_partial_shapes(self):
        argv_input = "inp1,inp2"
        input_shapes = "(1,..22,1..100,?), (-1,45..,7,1)"
        result, _ = get_placeholder_shapes(argv_input, input_shapes)
        inputs_list, result, _ = get_placeholder_shapes(argv_input, input_shapes)
        exp_res = {'inp1': (1, (0, 22), (1, 100), -1), 'inp2': (-1, (45, np.iinfo(np.int64).max), 7, 1)}
        self.assertEqual(list(exp_res.keys()), list(result.keys()))
        self.assertEqual(inputs_list, ["inp1","inp2"])
        for i in exp_res.keys():
            assert np.array_equal(result[i], exp_res[i])

    def test_get_shapes_several_inputs_several_partial_shapes2(self):
        # shapes specified using --input command line parameter and no values
        argv_input = "inp1[1 ? 50..100 123],inp2[-1 45.. ..7 1]"
        result, _ = get_placeholder_shapes(argv_input, None)
        inputs_list, result, _ = get_placeholder_shapes(argv_input, None)
        exp_res = {'inp1': (1, -1, (50, 100), 123), 'inp2': (-1, (45,np.iinfo(np.int64).max), (0, 7), 1)}
        self.assertEqual(list(exp_res.keys()), list(result.keys()))
        for i in exp_res.keys():
@@ -837,13 +856,14 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
        placeholder_values_ref = {}
        input_node_names_ref = "inp1,inp2"
        self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
        self.assertEqual(inputs_list, ["inp1","inp2"])
        for i in placeholder_values_ref.keys():
            assert np.array_equal(placeholder_values_res[i], placeholder_values_ref[i])

    def test_get_shapes_several_inputs_several_partial_shapes3(self):
        # shapes and value for freezing specified using --input command line parameter
        argv_input = "inp1[3 1]->[1.0 2.0 3.0],inp2[3.. ..2 5..10 ? -1],inp3[5]->[1.0 1.0 2.0 3.0 5.0]"
        result, _ = get_placeholder_shapes(argv_input, None)
        inputs_list, result, _ = get_placeholder_shapes(argv_input, None)
        exp_res = {'inp1': (3, 1), 'inp2': ((3, np.iinfo(np.int64).max), (0, 2), (5, 10), -1, -1), 'inp3': (5,)}
        self.assertEqual(list(exp_res.keys()), list(result.keys()))
        for i in exp_res.keys():
@@ -852,6 +872,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
        placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'])}
        input_node_names_ref = "inp1,inp2,inp3"
        self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
        self.assertEqual(inputs_list, ["inp1","inp2","inp3"])
        for i in placeholder_values_ref.keys():
            assert np.array_equal(placeholder_values_res[i], placeholder_values_ref[i])

@@ -859,7 +880,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
        # shapes specified using --input_shape and values for freezing using --input command line parameter
        argv_input = "inp1->[1.0 2.0 3.0],inp2,inp3->[1.0 1.0 2.0 3.0 5.0]"
        input_shapes = "(3,1), (3..,..2,5..10,?,-1), (5)"
        result, _ = get_placeholder_shapes(argv_input, input_shapes)
        inputs_list, result, _ = get_placeholder_shapes(argv_input, input_shapes)
        exp_res = {'inp1': (3, 1), 'inp2': ((3, np.iinfo(np.int64).max), (0, 2), (5, 10), -1, -1), 'inp3': (5,)}
        self.assertEqual(list(exp_res.keys()), list(result.keys()))
        for i in exp_res.keys():
@@ -868,6 +889,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
        placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'])}
        input_node_names_ref = "inp1,inp2,inp3"
        self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
        self.assertEqual(inputs_list, ["inp1","inp2","inp3"])
        for i in placeholder_values_ref.keys():
            assert np.array_equal(placeholder_values_res[i], placeholder_values_ref[i])
        self.assertEqual(input_node_names_ref, input_node_names_res)
@@ -878,7 +900,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
        input_shapes = "(3,1), (3..,..2,5..10,?,-1), (5)"
        argv_freeze_placeholder_with_value = "inp2->[5.0 7.0 3.0],inp4->[100.0 200.0]"

        result, _ = get_placeholder_shapes(argv_input, input_shapes)
        inputs_list, result, _ = get_placeholder_shapes(argv_input, input_shapes)
        exp_res = {'inp1': (3, 1), 'inp2': ((3, np.iinfo(np.int64).max), (0, 2), (5, 10), -1, -1), 'inp3': (5,)}
        self.assertEqual(list(exp_res.keys()), list(result.keys()))
        for i in exp_res.keys():
@@ -888,6 +910,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
                                  'inp2': np.array(['5.0', '7.0', '3.0']), 'inp4': np.array(['100.0', '200.0'])}
        input_node_names_ref = "inp1,inp2,inp3"
        self.assertEqual(sorted(list(placeholder_values_res.keys())), sorted(list(placeholder_values_ref.keys())))
        self.assertEqual(inputs_list, ["inp1","inp2","inp3"])
        for i in placeholder_values_ref.keys():
            assert np.array_equal(placeholder_values_res[i], placeholder_values_ref[i])
        self.assertEqual(input_node_names_ref, input_node_names_res)
@@ -895,7 +918,7 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
    def test_get_shapes_several_inputs_several_partial_shapes6(self):
        # 0D value for freezing specified using --input command line parameter without shape
        argv_input = "inp1[3 1]->[1.0 2.0 3.0],inp2[3.. ..2 5..10 ? -1],inp3->False"
        result, _ = get_placeholder_shapes(argv_input, None)
        inputs_list, result, _ = get_placeholder_shapes(argv_input, None)
        exp_res = {'inp1': (3, 1), 'inp2': ((3, np.iinfo(np.int64).max), (0, 2), (5, 10), -1, -1), 'inp3': None}
        self.assertEqual(list(exp_res.keys()), list(result.keys()))
        for i in exp_res.keys():
@@ -903,13 +926,14 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
        placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None)
        placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': False}
        self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
        self.assertEqual(inputs_list, ["inp1","inp2","inp3"])
        for i in placeholder_values_ref.keys():
            assert np.array_equal(placeholder_values_res[i], placeholder_values_ref[i])

    def test_get_shapes_several_inputs_several_partial_shapes7(self):
        # 0D shape and value for freezing specified using --input command line parameter
        argv_input = "inp1[3 1]->[1.0 2.0 3.0],inp2[3.. ..2 5..10 ? -1],inp3[]->True"
        result, _ = get_placeholder_shapes(argv_input, None)
        inputs_list, result, _ = get_placeholder_shapes(argv_input, None)
        exp_res = {'inp1': (3, 1), 'inp2': ((3, np.iinfo(np.int64).max), (0, 2), (5, 10), -1, -1), 'inp3': np.array(False).shape}
        self.assertEqual(list(exp_res.keys()), list(result.keys()))
        for i in exp_res.keys():
@@ -917,30 +941,33 @@ class TestShapesParsing(UnitTestWithMockedTelemetry):
        placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None)
        placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': True}
        self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
        self.assertEqual(inputs_list, ["inp1","inp2","inp3"])
        for i in placeholder_values_ref.keys():
            assert np.array_equal(placeholder_values_res[i], placeholder_values_ref[i])

    def test_get_shapes_and_data_types_partial_shape_with_input_port(self):
        argv_input = "inp1:1[3 1]->[1.0 2.0 3.0],0:inp2[3.. ..2 5..10 ? -1]{i32},inp3:4[5]{f32}->[1.0 1.0 2.0 3.0 5.0]"
        result_shapes, result_data_types = get_placeholder_shapes(argv_input, "")
        input_list, result_shapes, result_data_types = get_placeholder_shapes(argv_input, "")
        ref_result_shapes = {'inp1:1': np.array([3, 1]), '0:inp2': ((3, np.iinfo(np.int64).max), (0, 2), (5, 10), -1, -1), 'inp3:4': np.array([5])}
        ref_result_data_types = {'0:inp2': np.int32, 'inp3:4': np.float32}
        self.assertEqual(list(ref_result_shapes.keys()), list(result_shapes.keys()))
        for i in ref_result_shapes.keys():
            assert np.array_equal(result_shapes[i], ref_result_shapes[i])
        self.assertEqual(list(ref_result_data_types.keys()), list(result_data_types.keys()))
        self.assertEqual(input_list, ["inp1:1","0:inp2","inp3:4"])
        for i in ref_result_data_types.keys():
            np.testing.assert_equal(result_data_types[i], ref_result_data_types[i])

    def test_get_shapes_and_data_types_partial_shape_with_output_port(self):
        argv_input = "inp1:1[3 1]->[1.0 2.0 3.0],inp2:3[3.. ..2 5..10 ? -1]{i32},inp3:4[5]{f32}->[1.0 1.0 2.0 3.0 5.0]"
        result_shapes, result_data_types = get_placeholder_shapes(argv_input, "")
        input_list, result_shapes, result_data_types = get_placeholder_shapes(argv_input, "")
        ref_result_shapes = {'inp1:1': np.array([3, 1]), 'inp2:3': ((3, np.iinfo(np.int64).max), (0, 2), (5, 10), -1, -1), 'inp3:4': np.array([5])}
        ref_result_data_types = {'inp2:3': np.int32, 'inp3:4': np.float32}
        self.assertEqual(list(ref_result_shapes.keys()), list(result_shapes.keys()))
        for i in ref_result_shapes.keys():
            assert np.array_equal(result_shapes[i], ref_result_shapes[i])
        self.assertEqual(list(ref_result_data_types.keys()), list(result_data_types.keys()))
        self.assertEqual(input_list, ["inp1:1","inp2:3","inp3:4"])
        for i in ref_result_data_types.keys():
            np.testing.assert_equal(result_data_types[i], ref_result_data_types[i])