fix check for data nodes in emitter.py (#14802)
* fix checking if there are data nodes for operations with several outputs * added unit-test * added ports explicitly, removed redundant lambda * typo in port
This commit is contained in:
parent
745ef24e19
commit
1d5fa360d4
@ -512,8 +512,8 @@ def serialize_network(graph, net_element, unsupported):
|
|||||||
check_and_add_result_name(node.soft_get('name'), ordered_results)
|
check_and_add_result_name(node.soft_get('name'), ordered_results)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Here output data node count is checked. Each port cannot have more than one data node.
|
# Here output data node count is checked. Output Op nodes must have at least one data node
|
||||||
assert len(node.out_nodes()) == 1, "Incorrect graph. Non-Result node with name {} " \
|
assert len(node.out_nodes()) >= 1, "Incorrect graph. Non-Result node with name {} " \
|
||||||
"has no output data node.".format(output_name)
|
"has no output data node.".format(output_name)
|
||||||
|
|
||||||
# After port renumbering port/connection API is not applicable, and output port numbering
|
# After port renumbering port/connection API is not applicable, and output port numbering
|
||||||
|
@ -3,16 +3,25 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import defusedxml.ElementTree as ET
|
import defusedxml.ElementTree as ET
|
||||||
|
import numpy as np
|
||||||
from defusedxml import defuse_stdlib
|
from defusedxml import defuse_stdlib
|
||||||
|
|
||||||
import numpy as np
|
from openvino.tools.mo.back.ie_ir_ver_2.emitter import soft_get, xml_shape, serialize_runtime_info, serialize_network, \
|
||||||
|
port_renumber
|
||||||
from openvino.tools.mo.back.ie_ir_ver_2.emitter import soft_get, xml_shape, serialize_runtime_info
|
from openvino.tools.mo.front.common.partial_infer.utils import int64_array
|
||||||
from openvino.tools.mo.graph.graph import Node
|
from openvino.tools.mo.graph.graph import Node
|
||||||
|
from openvino.tools.mo.middle.passes.infer import partial_infer, type_infer
|
||||||
|
from openvino.tools.mo.ops.gather import Gather
|
||||||
|
from openvino.tools.mo.ops.parameter import Parameter
|
||||||
|
from openvino.tools.mo.ops.pooling import Pooling
|
||||||
|
from openvino.tools.mo.ops.result import Result
|
||||||
from openvino.tools.mo.utils.error import Error
|
from openvino.tools.mo.utils.error import Error
|
||||||
from openvino.tools.mo.utils.runtime_info import RTInfo, OldAPIMapOrder, OldAPIMapElementType
|
from openvino.tools.mo.utils.runtime_info import RTInfo, OldAPIMapOrder, OldAPIMapElementType
|
||||||
from unit_tests.utils.graph import build_graph, result, regular_op
|
from openvino.tools.mo.utils.unsupported_ops import UnsupportedOps
|
||||||
|
from unit_tests.utils.graph import valued_const_with_data, result, regular_op_with_empty_data, connect, \
|
||||||
|
shaped_parameter, build_graph, regular_op
|
||||||
|
|
||||||
# defuse_stdlib provide patched version of xml.etree.ElementTree which allows to use objects from xml.etree.ElementTree
|
# defuse_stdlib provide patched version of xml.etree.ElementTree which allows to use objects from xml.etree.ElementTree
|
||||||
# in a safe manner without including unsafe xml.etree.ElementTree
|
# in a safe manner without including unsafe xml.etree.ElementTree
|
||||||
@ -117,3 +126,143 @@ class TestSerializeRTInfo(unittest.TestCase):
|
|||||||
self.assertTrue("value=\"0,3,1,2\"" in serialize_res)
|
self.assertTrue("value=\"0,3,1,2\"" in serialize_res)
|
||||||
self.assertTrue(serialize_res.startswith("b'<net><rt_info>"))
|
self.assertTrue(serialize_res.startswith("b'<net><rt_info>"))
|
||||||
self.assertTrue(serialize_res.endswith("</rt_info></net>'"))
|
self.assertTrue(serialize_res.endswith("</rt_info></net>'"))
|
||||||
|
|
||||||
|
|
||||||
|
class TestSerialize(unittest.TestCase):
|
||||||
|
@staticmethod
|
||||||
|
def build_graph_with_gather():
|
||||||
|
nodes = {
|
||||||
|
**shaped_parameter('data', int64_array([3, 3]), {'data_type': np.float32, 'type': Parameter.op}),
|
||||||
|
**shaped_parameter('indices', int64_array([1, 2]), {'data_type': np.float32, 'type': Parameter.op}),
|
||||||
|
**valued_const_with_data('axis', int64_array(1)),
|
||||||
|
**regular_op_with_empty_data('gather', {'op': 'Gather', 'batch_dims': 0, 'infer': Gather.infer,
|
||||||
|
'type': Gather.op}),
|
||||||
|
**result('res'),
|
||||||
|
}
|
||||||
|
|
||||||
|
edges = [
|
||||||
|
*connect('data', '0:gather'),
|
||||||
|
*connect('indices', '1:gather'),
|
||||||
|
*connect('axis', '2:gather'),
|
||||||
|
*connect('gather', 'res'),
|
||||||
|
]
|
||||||
|
|
||||||
|
graph = build_graph(nodes, edges)
|
||||||
|
|
||||||
|
data_node = Node(graph, 'data')
|
||||||
|
Parameter.update_node_stat(data_node, {})
|
||||||
|
indices_node = Node(graph, 'indices')
|
||||||
|
Parameter.update_node_stat(indices_node, {})
|
||||||
|
|
||||||
|
gather_node = Node(graph, 'gather')
|
||||||
|
Gather.update_node_stat(gather_node, {})
|
||||||
|
|
||||||
|
res_node = Node(graph, 'res')
|
||||||
|
Result.update_node_stat(res_node, {})
|
||||||
|
|
||||||
|
partial_infer(graph)
|
||||||
|
type_infer(graph)
|
||||||
|
|
||||||
|
return graph
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def build_graph_with_maxpool():
|
||||||
|
graph = build_graph(
|
||||||
|
nodes_attrs={
|
||||||
|
'input': {'kind': 'op', 'op': 'Parameter', 'name': 'node', 'infer': Parameter.infer,
|
||||||
|
'shape': [1, 3, 10, 10]},
|
||||||
|
'input_data': {'kind': 'data', 'value': None, 'shape': None},
|
||||||
|
|
||||||
|
'pool': {'kind': 'op', 'type': 'MaxPool', 'infer': Pooling.infer,
|
||||||
|
'window': np.array([1, 1, 2, 2]), 'stride': np.array([1, 1, 2, 2]),
|
||||||
|
'pad': np.array([[0, 0], [0, 0], [0, 0], [1, 1]]),
|
||||||
|
'pad_spatial_shape': np.array([[0, 0], [1, 1]]),
|
||||||
|
'pool_method': 'max', 'exclude_pad': False, 'global_pool': False,
|
||||||
|
'output_spatial_shape': None, 'output_shape': None,
|
||||||
|
'kernel_spatial': np.array([2, 2]), 'spatial_dims': np.array([2, 3]),
|
||||||
|
'channel_dims': np.array([1]), 'batch_dims': np.array([0]),
|
||||||
|
'pooling_convention': 'full', 'dilation': np.array([1, 1, 2, 2]),
|
||||||
|
'auto_pad': 'valid'},
|
||||||
|
|
||||||
|
'pool_data': {'kind': 'data', 'value': None, 'shape': None},
|
||||||
|
'pool_data_added': {'kind': 'data', 'value': None, 'shape': None},
|
||||||
|
'result': {'kind': 'op', 'op': 'Result'},
|
||||||
|
'result_added': {'kind': 'op', 'op': 'Result'}
|
||||||
|
},
|
||||||
|
edges=[
|
||||||
|
('input', 'input_data'),
|
||||||
|
('input_data', 'pool'),
|
||||||
|
('pool', 'pool_data', {'out': 0}),
|
||||||
|
('pool_data', 'result'),
|
||||||
|
('pool', 'pool_data_added', {'out': 1}),
|
||||||
|
('pool_data_added', 'result_added')
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
input_node = Node(graph, 'input')
|
||||||
|
Parameter.update_node_stat(input_node, {})
|
||||||
|
|
||||||
|
pool_node = Node(graph, 'pool')
|
||||||
|
Pooling.update_node_stat(pool_node, {'pool_method': 'max'})
|
||||||
|
|
||||||
|
result_node = Node(graph, 'result')
|
||||||
|
Result.update_node_stat(result_node, {})
|
||||||
|
result_added_node = Node(graph, 'result_added')
|
||||||
|
Result.update_node_stat(result_added_node, {})
|
||||||
|
|
||||||
|
partial_infer(graph)
|
||||||
|
type_infer(graph)
|
||||||
|
return graph
|
||||||
|
|
||||||
|
def test_gather(self):
|
||||||
|
graph = self.build_graph_with_gather()
|
||||||
|
|
||||||
|
net = Element('net')
|
||||||
|
graph.outputs_order = ['gather']
|
||||||
|
unsupported = UnsupportedOps(graph)
|
||||||
|
port_renumber(graph)
|
||||||
|
|
||||||
|
serialize_network(graph, net, unsupported)
|
||||||
|
xml_string = str(tostring(net))
|
||||||
|
self.assertTrue("type=\"Parameter\"" in xml_string)
|
||||||
|
self.assertTrue("type=\"Result\"" in xml_string)
|
||||||
|
self.assertTrue("type=\"Gather\"" in xml_string)
|
||||||
|
|
||||||
|
def test_maxpool(self):
|
||||||
|
graph = self.build_graph_with_maxpool()
|
||||||
|
|
||||||
|
net = Element('net')
|
||||||
|
graph.outputs_order = ['pool']
|
||||||
|
unsupported = UnsupportedOps(graph)
|
||||||
|
port_renumber(graph)
|
||||||
|
serialize_network(graph, net, unsupported)
|
||||||
|
xml_string = str(tostring(net))
|
||||||
|
self.assertTrue("type=\"Parameter\"" in xml_string)
|
||||||
|
self.assertTrue("type=\"Result\"" in xml_string)
|
||||||
|
self.assertTrue("type=\"Pooling\"" in xml_string)
|
||||||
|
|
||||||
|
def test_maxpool_raises(self):
|
||||||
|
graph = self.build_graph_with_maxpool()
|
||||||
|
|
||||||
|
pool_node = Node(graph, 'pool')
|
||||||
|
result_node = Node(graph, 'result')
|
||||||
|
result_added_node = Node(graph, 'result_added')
|
||||||
|
pool_out_1 = Node(graph, 'pool_data')
|
||||||
|
pool_out_2 = Node(graph, 'pool_data_added')
|
||||||
|
|
||||||
|
# when operation does not have output data nodes Exception should be raised
|
||||||
|
graph.remove_edge(pool_node.id, pool_out_1.id)
|
||||||
|
graph.remove_edge(pool_node.id, pool_out_2.id)
|
||||||
|
graph.remove_edge(pool_out_1.id, result_node.id)
|
||||||
|
graph.remove_edge(pool_out_2.id, result_added_node.id)
|
||||||
|
|
||||||
|
graph.remove_node(result_node.id)
|
||||||
|
graph.remove_node(result_added_node.id)
|
||||||
|
|
||||||
|
net = Element('net')
|
||||||
|
graph.outputs_order = ['pool']
|
||||||
|
unsupported = UnsupportedOps(graph)
|
||||||
|
port_renumber(graph)
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(AssertionError, "Incorrect graph. Non-Result node.*"):
|
||||||
|
serialize_network(graph, net, unsupported)
|
||||||
|
Loading…
Reference in New Issue
Block a user