fix check for data nodes in emitter.py (#14802)
* fix checking if there are data nodes for operations with several outputs * added unit-test * added ports explicitly, removed redundant lambda * typo in port
This commit is contained in:
parent
745ef24e19
commit
1d5fa360d4
@ -512,8 +512,8 @@ def serialize_network(graph, net_element, unsupported):
|
||||
check_and_add_result_name(node.soft_get('name'), ordered_results)
|
||||
continue
|
||||
|
||||
# Here output data node count is checked. Each port cannot have more than one data node.
|
||||
assert len(node.out_nodes()) == 1, "Incorrect graph. Non-Result node with name {} " \
|
||||
# Here output data node count is checked. Output Op nodes must have at least one data node
|
||||
assert len(node.out_nodes()) >= 1, "Incorrect graph. Non-Result node with name {} " \
|
||||
"has no output data node.".format(output_name)
|
||||
|
||||
# After port renumbering port/connection API is not applicable, and output port numbering
|
||||
|
@ -3,16 +3,25 @@
|
||||
|
||||
import unittest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import defusedxml.ElementTree as ET
|
||||
import numpy as np
|
||||
from defusedxml import defuse_stdlib
|
||||
|
||||
import numpy as np
|
||||
|
||||
from openvino.tools.mo.back.ie_ir_ver_2.emitter import soft_get, xml_shape, serialize_runtime_info
|
||||
from openvino.tools.mo.back.ie_ir_ver_2.emitter import soft_get, xml_shape, serialize_runtime_info, serialize_network, \
|
||||
port_renumber
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import int64_array
|
||||
from openvino.tools.mo.graph.graph import Node
|
||||
from openvino.tools.mo.middle.passes.infer import partial_infer, type_infer
|
||||
from openvino.tools.mo.ops.gather import Gather
|
||||
from openvino.tools.mo.ops.parameter import Parameter
|
||||
from openvino.tools.mo.ops.pooling import Pooling
|
||||
from openvino.tools.mo.ops.result import Result
|
||||
from openvino.tools.mo.utils.error import Error
|
||||
from openvino.tools.mo.utils.runtime_info import RTInfo, OldAPIMapOrder, OldAPIMapElementType
|
||||
from unit_tests.utils.graph import build_graph, result, regular_op
|
||||
from openvino.tools.mo.utils.unsupported_ops import UnsupportedOps
|
||||
from unit_tests.utils.graph import valued_const_with_data, result, regular_op_with_empty_data, connect, \
|
||||
shaped_parameter, build_graph, regular_op
|
||||
|
||||
# defuse_stdlib provide patched version of xml.etree.ElementTree which allows to use objects from xml.etree.ElementTree
|
||||
# in a safe manner without including unsafe xml.etree.ElementTree
|
||||
@ -117,3 +126,143 @@ class TestSerializeRTInfo(unittest.TestCase):
|
||||
self.assertTrue("value=\"0,3,1,2\"" in serialize_res)
|
||||
self.assertTrue(serialize_res.startswith("b'<net><rt_info>"))
|
||||
self.assertTrue(serialize_res.endswith("</rt_info></net>'"))
|
||||
|
||||
|
||||
class TestSerialize(unittest.TestCase):
|
||||
@staticmethod
|
||||
def build_graph_with_gather():
|
||||
nodes = {
|
||||
**shaped_parameter('data', int64_array([3, 3]), {'data_type': np.float32, 'type': Parameter.op}),
|
||||
**shaped_parameter('indices', int64_array([1, 2]), {'data_type': np.float32, 'type': Parameter.op}),
|
||||
**valued_const_with_data('axis', int64_array(1)),
|
||||
**regular_op_with_empty_data('gather', {'op': 'Gather', 'batch_dims': 0, 'infer': Gather.infer,
|
||||
'type': Gather.op}),
|
||||
**result('res'),
|
||||
}
|
||||
|
||||
edges = [
|
||||
*connect('data', '0:gather'),
|
||||
*connect('indices', '1:gather'),
|
||||
*connect('axis', '2:gather'),
|
||||
*connect('gather', 'res'),
|
||||
]
|
||||
|
||||
graph = build_graph(nodes, edges)
|
||||
|
||||
data_node = Node(graph, 'data')
|
||||
Parameter.update_node_stat(data_node, {})
|
||||
indices_node = Node(graph, 'indices')
|
||||
Parameter.update_node_stat(indices_node, {})
|
||||
|
||||
gather_node = Node(graph, 'gather')
|
||||
Gather.update_node_stat(gather_node, {})
|
||||
|
||||
res_node = Node(graph, 'res')
|
||||
Result.update_node_stat(res_node, {})
|
||||
|
||||
partial_infer(graph)
|
||||
type_infer(graph)
|
||||
|
||||
return graph
|
||||
|
||||
@staticmethod
|
||||
def build_graph_with_maxpool():
|
||||
graph = build_graph(
|
||||
nodes_attrs={
|
||||
'input': {'kind': 'op', 'op': 'Parameter', 'name': 'node', 'infer': Parameter.infer,
|
||||
'shape': [1, 3, 10, 10]},
|
||||
'input_data': {'kind': 'data', 'value': None, 'shape': None},
|
||||
|
||||
'pool': {'kind': 'op', 'type': 'MaxPool', 'infer': Pooling.infer,
|
||||
'window': np.array([1, 1, 2, 2]), 'stride': np.array([1, 1, 2, 2]),
|
||||
'pad': np.array([[0, 0], [0, 0], [0, 0], [1, 1]]),
|
||||
'pad_spatial_shape': np.array([[0, 0], [1, 1]]),
|
||||
'pool_method': 'max', 'exclude_pad': False, 'global_pool': False,
|
||||
'output_spatial_shape': None, 'output_shape': None,
|
||||
'kernel_spatial': np.array([2, 2]), 'spatial_dims': np.array([2, 3]),
|
||||
'channel_dims': np.array([1]), 'batch_dims': np.array([0]),
|
||||
'pooling_convention': 'full', 'dilation': np.array([1, 1, 2, 2]),
|
||||
'auto_pad': 'valid'},
|
||||
|
||||
'pool_data': {'kind': 'data', 'value': None, 'shape': None},
|
||||
'pool_data_added': {'kind': 'data', 'value': None, 'shape': None},
|
||||
'result': {'kind': 'op', 'op': 'Result'},
|
||||
'result_added': {'kind': 'op', 'op': 'Result'}
|
||||
},
|
||||
edges=[
|
||||
('input', 'input_data'),
|
||||
('input_data', 'pool'),
|
||||
('pool', 'pool_data', {'out': 0}),
|
||||
('pool_data', 'result'),
|
||||
('pool', 'pool_data_added', {'out': 1}),
|
||||
('pool_data_added', 'result_added')
|
||||
]
|
||||
)
|
||||
|
||||
input_node = Node(graph, 'input')
|
||||
Parameter.update_node_stat(input_node, {})
|
||||
|
||||
pool_node = Node(graph, 'pool')
|
||||
Pooling.update_node_stat(pool_node, {'pool_method': 'max'})
|
||||
|
||||
result_node = Node(graph, 'result')
|
||||
Result.update_node_stat(result_node, {})
|
||||
result_added_node = Node(graph, 'result_added')
|
||||
Result.update_node_stat(result_added_node, {})
|
||||
|
||||
partial_infer(graph)
|
||||
type_infer(graph)
|
||||
return graph
|
||||
|
||||
def test_gather(self):
|
||||
graph = self.build_graph_with_gather()
|
||||
|
||||
net = Element('net')
|
||||
graph.outputs_order = ['gather']
|
||||
unsupported = UnsupportedOps(graph)
|
||||
port_renumber(graph)
|
||||
|
||||
serialize_network(graph, net, unsupported)
|
||||
xml_string = str(tostring(net))
|
||||
self.assertTrue("type=\"Parameter\"" in xml_string)
|
||||
self.assertTrue("type=\"Result\"" in xml_string)
|
||||
self.assertTrue("type=\"Gather\"" in xml_string)
|
||||
|
||||
def test_maxpool(self):
|
||||
graph = self.build_graph_with_maxpool()
|
||||
|
||||
net = Element('net')
|
||||
graph.outputs_order = ['pool']
|
||||
unsupported = UnsupportedOps(graph)
|
||||
port_renumber(graph)
|
||||
serialize_network(graph, net, unsupported)
|
||||
xml_string = str(tostring(net))
|
||||
self.assertTrue("type=\"Parameter\"" in xml_string)
|
||||
self.assertTrue("type=\"Result\"" in xml_string)
|
||||
self.assertTrue("type=\"Pooling\"" in xml_string)
|
||||
|
||||
def test_maxpool_raises(self):
|
||||
graph = self.build_graph_with_maxpool()
|
||||
|
||||
pool_node = Node(graph, 'pool')
|
||||
result_node = Node(graph, 'result')
|
||||
result_added_node = Node(graph, 'result_added')
|
||||
pool_out_1 = Node(graph, 'pool_data')
|
||||
pool_out_2 = Node(graph, 'pool_data_added')
|
||||
|
||||
# when operation does not have output data nodes Exception should be raised
|
||||
graph.remove_edge(pool_node.id, pool_out_1.id)
|
||||
graph.remove_edge(pool_node.id, pool_out_2.id)
|
||||
graph.remove_edge(pool_out_1.id, result_node.id)
|
||||
graph.remove_edge(pool_out_2.id, result_added_node.id)
|
||||
|
||||
graph.remove_node(result_node.id)
|
||||
graph.remove_node(result_added_node.id)
|
||||
|
||||
net = Element('net')
|
||||
graph.outputs_order = ['pool']
|
||||
unsupported = UnsupportedOps(graph)
|
||||
port_renumber(graph)
|
||||
|
||||
with self.assertRaisesRegex(AssertionError, "Incorrect graph. Non-Result node.*"):
|
||||
serialize_network(graph, net, unsupported)
|
||||
|
Loading…
Reference in New Issue
Block a user