Add missed attributes for Parameter layer (#5676)
* Add shape and element_type to ir_v10_attrs structure * Add comment * Update add_input_op_input_port_with_data() and add_input_op_output_port_with_data() functions with new GraphAPI * Remove incorrectly added test * Apply comments * Update missed tests * Add more checks to tests Co-authored-by: achetver <anton.chetverikov@.intel.com>
This commit is contained in:
parent
98ac1d04c3
commit
9d3780648f
@ -805,10 +805,16 @@ def add_input_op_input_port_without_data(graph: Graph, node_id: str, input_op, e
|
||||
|
||||
|
||||
def add_input_op_input_port_with_data(graph: Graph, node_id: str, input_op, edge_attrs: dict):
|
||||
input_data_node = input_op.create_node_with_data()
|
||||
input_node = input_data_node.in_node()
|
||||
graph.add_edge(input_data_node.id, node_id, **edge_attrs)
|
||||
update_ie_fields(graph.node[input_node.id])
|
||||
assert graph.stage == 'middle', 'add_input_op_input_port_with_data() function can be used only for graph after ' \
|
||||
'shape inference!'
|
||||
input_node = input_op.create_node(edge_attrs=edge_attrs)
|
||||
node = Node(graph, node_id)
|
||||
|
||||
out_port = input_node.out_port(edge_attrs['out'])
|
||||
out_port.connect(node.in_port(edge_attrs['in']))
|
||||
out_port.data.set_shape(input_node.soft_get('shape', None))
|
||||
input_data_node = input_node.out_node(0)
|
||||
|
||||
log.debug('Input: {} for node {}'.format(input_node.id, node_id))
|
||||
log.debug("Add edge from {} to {}".format(input_node.id, input_data_node.id))
|
||||
log.debug("Add edge from {} to {}".format(input_data_node.id, node_id))
|
||||
@ -831,11 +837,12 @@ def add_input_op_output_port_without_data(graph: Graph, node_id: str, input_op,
|
||||
|
||||
def add_input_op_output_port_with_data(graph: Graph, node_id: str, input_op, port: int):
|
||||
# we assume that after op always data node
|
||||
assert graph.stage == 'middle', 'add_input_op_input_port_with_data() function can be used only for graph after ' \
|
||||
'shape inference!'
|
||||
data_node = Node(graph, node_id).out_node(port)
|
||||
assert data_node.has_valid('kind') and data_node.kind == 'data'
|
||||
input_op.create_node_with_data(data_nodes=data_node)
|
||||
input_node = data_node.in_node()
|
||||
update_ie_fields(graph.node[input_node.id])
|
||||
input_node = input_op.create_node()
|
||||
Node(graph, node_id).out_port(port).get_connection().set_source(input_node.out_port(0))
|
||||
log.debug('Input: {} for node {}'.format(input_node.id, node_id))
|
||||
log.debug("Add edge from {} to {}".format(input_node.id, node_id))
|
||||
return input_node.id
|
||||
|
@ -116,6 +116,7 @@ class TestAddInputOp(unittest.TestCase):
|
||||
|
||||
def test_in_port_with_data(self):
|
||||
graph = build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges)
|
||||
graph.stage = 'middle'
|
||||
new_input_shape = np.array([1, 2, 3, 4])
|
||||
graph_ref = build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges[1:],
|
||||
new_nodes_with_attrs=[('input_node', {'kind': 'op', 'op': 'Parameter',
|
||||
@ -155,6 +156,7 @@ class TestAddInputOp(unittest.TestCase):
|
||||
new_nodes_with_attrs=[('input_data', {'kind': 'data', 'shape': None, 'value': None})],
|
||||
new_edges_with_attrs=[('op_node', 'input_data', {'out': 1, 'in': 0}),
|
||||
('input_data', 'future_input', {'in': 0, 'out': 0})])
|
||||
graph.stage = 'middle'
|
||||
new_input_shape = np.array([1, 2, 3, 4])
|
||||
graph_ref = build_graph_with_attrs(nodes_with_attrs=self.nodes_out, edges_with_attrs=self.edges_out[1:],
|
||||
new_nodes_with_attrs=[('input_node', {'kind': 'op', 'op': 'Parameter',
|
||||
@ -252,6 +254,7 @@ class TestInputAddition(unittest.TestCase):
|
||||
('output_data', 'op_output')
|
||||
]
|
||||
graph = build_graph(nodes, edges)
|
||||
graph.stage = 'middle'
|
||||
add_input_ops(graph=graph, user_defined_inputs=inputs, before_infer=False)
|
||||
new_input = list(graph.in_edges(list(graph.in_edges('conv_1'))[0][0]))[0][0]
|
||||
new_input_data = list(graph.in_edges('conv_1'))[0][0]
|
||||
@ -395,16 +398,17 @@ class TestInputAddition(unittest.TestCase):
|
||||
'old_input': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
||||
'inp_data' : {'kind': 'data', 'shape': shape + 1},
|
||||
'conv_1': {'type': 'Convolution', 'kind': 'op', 'op': 'NotPlaceholder'},
|
||||
'conv_data': {'kind': 'data', 'shape': shape, 'value': None},
|
||||
'conv_data': {'kind': 'data', 'shape': shape, 'value': None, 'data_attr': 'data_attr_value'},
|
||||
'relu_1': {'type': 'ReLU', 'kind': 'op', 'op': 'NotPlaceholder'},
|
||||
}
|
||||
edges = [
|
||||
('old_input', 'inp_data'),
|
||||
('inp_data', 'conv_1'),
|
||||
('conv_1', 'conv_data'),
|
||||
('conv_data', 'relu_1'),
|
||||
('conv_data', 'relu_1', {'edge_attr': 'edge_value'}),
|
||||
]
|
||||
graph = build_graph(nodes, edges)
|
||||
graph.stage = 'middle'
|
||||
add_input_ops(graph=graph, user_defined_inputs=inputs, before_infer=False)
|
||||
|
||||
graph_ref = build_graph(nodes_attrs={'new_input': {'kind': 'op', 'op': 'Parameter', 'shape': shape},
|
||||
@ -412,7 +416,7 @@ class TestInputAddition(unittest.TestCase):
|
||||
edges=[('old_input', 'inp_data'),
|
||||
('inp_data', 'conv_1'),
|
||||
('new_input', 'conv_data'),
|
||||
('conv_data', 'relu_1'),
|
||||
('conv_data', 'relu_1', {'edge_attr': 'edge_value'}),
|
||||
],)
|
||||
# Check that new input is added right (with right ports !)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, last_node='relu_1')
|
||||
@ -427,8 +431,9 @@ class TestInputAddition(unittest.TestCase):
|
||||
new_input = 'conv_1/placeholder_out_port_0'
|
||||
|
||||
self.assertTrue(graph.node[new_input]['is_input'])
|
||||
self.assertTrue((new_input, 'conv_data') in graph.edges())
|
||||
self.assertTrue(('conv_1', 'conv_data') not in graph.edges())
|
||||
|
||||
self.assertTrue(Node(graph, 'relu_1').in_node(0)['data_attr'] == 'data_attr_value')
|
||||
self.assertTrue(Node(graph, 'relu_1').in_edge(0)['edge_attr'] == 'edge_value')
|
||||
|
||||
|
||||
@generator
|
||||
|
Loading…
Reference in New Issue
Block a user