From 9d3780648fb6c91f3f2fc7d5d55450521389a58e Mon Sep 17 00:00:00 2001 From: Anton Chetverikov Date: Fri, 28 May 2021 10:22:06 +0300 Subject: [PATCH] Add missed attributes for Parameter layer (#5676) * Add shape and element_type to ir_v10_attrs structure * Add comment * Update add_input_op_input_port_with_data() and add_input_op_output_port_with_data() functions with new GraphAPI * Remove incorrectly added test * Apply comments * Update missed tests * Add more checks to tests Co-authored-by: achetver --- model-optimizer/mo/front/extractor.py | 21 ++++++++++++------- .../unit_tests/mo/front/extractor_test.py | 15 ++++++++----- 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/model-optimizer/mo/front/extractor.py b/model-optimizer/mo/front/extractor.py index acb5003d266..a0376a6328c 100644 --- a/model-optimizer/mo/front/extractor.py +++ b/model-optimizer/mo/front/extractor.py @@ -805,10 +805,16 @@ def add_input_op_input_port_without_data(graph: Graph, node_id: str, input_op, e def add_input_op_input_port_with_data(graph: Graph, node_id: str, input_op, edge_attrs: dict): - input_data_node = input_op.create_node_with_data() - input_node = input_data_node.in_node() - graph.add_edge(input_data_node.id, node_id, **edge_attrs) - update_ie_fields(graph.node[input_node.id]) + assert graph.stage == 'middle', 'add_input_op_input_port_with_data() function can be used only for graph after ' \ + 'shape inference!' + input_node = input_op.create_node(edge_attrs=edge_attrs) + node = Node(graph, node_id) + + out_port = input_node.out_port(edge_attrs['out']) + out_port.connect(node.in_port(edge_attrs['in'])) + out_port.data.set_shape(input_node.soft_get('shape', None)) + input_data_node = input_node.out_node(0) + log.debug('Input: {} for node {}'.format(input_node.id, node_id)) log.debug("Add edge from {} to {}".format(input_node.id, input_data_node.id)) log.debug("Add edge from {} to {}".format(input_data_node.id, node_id)) @@ -831,11 +837,12 @@ def add_input_op_output_port_without_data(graph: Graph, node_id: str, input_op, def add_input_op_output_port_with_data(graph: Graph, node_id: str, input_op, port: int): # we assume that after op always data node + assert graph.stage == 'middle', 'add_input_op_input_port_with_data() function can be used only for graph after ' \ + 'shape inference!' data_node = Node(graph, node_id).out_node(port) assert data_node.has_valid('kind') and data_node.kind == 'data' - input_op.create_node_with_data(data_nodes=data_node) - input_node = data_node.in_node() - update_ie_fields(graph.node[input_node.id]) + input_node = input_op.create_node() + Node(graph, node_id).out_port(port).get_connection().set_source(input_node.out_port(0)) log.debug('Input: {} for node {}'.format(input_node.id, node_id)) log.debug("Add edge from {} to {}".format(input_node.id, node_id)) return input_node.id diff --git a/model-optimizer/unit_tests/mo/front/extractor_test.py b/model-optimizer/unit_tests/mo/front/extractor_test.py index c7fdbc5079c..27e70db8178 100644 --- a/model-optimizer/unit_tests/mo/front/extractor_test.py +++ b/model-optimizer/unit_tests/mo/front/extractor_test.py @@ -116,6 +116,7 @@ class TestAddInputOp(unittest.TestCase): def test_in_port_with_data(self): graph = build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges) + graph.stage = 'middle' new_input_shape = np.array([1, 2, 3, 4]) graph_ref = build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges[1:], new_nodes_with_attrs=[('input_node', {'kind': 'op', 'op': 'Parameter', @@ -155,6 +156,7 @@ class TestAddInputOp(unittest.TestCase): new_nodes_with_attrs=[('input_data', {'kind': 'data', 'shape': None, 'value': None})], new_edges_with_attrs=[('op_node', 'input_data', {'out': 1, 'in': 0}), ('input_data', 'future_input', {'in': 0, 'out': 0})]) + graph.stage = 'middle' new_input_shape = np.array([1, 2, 3, 4]) graph_ref = build_graph_with_attrs(nodes_with_attrs=self.nodes_out, edges_with_attrs=self.edges_out[1:], new_nodes_with_attrs=[('input_node', {'kind': 'op', 'op': 'Parameter', @@ -252,6 +254,7 @@ class TestInputAddition(unittest.TestCase): ('output_data', 'op_output') ] graph = build_graph(nodes, edges) + graph.stage = 'middle' add_input_ops(graph=graph, user_defined_inputs=inputs, before_infer=False) new_input = list(graph.in_edges(list(graph.in_edges('conv_1'))[0][0]))[0][0] new_input_data = list(graph.in_edges('conv_1'))[0][0] @@ -395,16 +398,17 @@ class TestInputAddition(unittest.TestCase): 'old_input': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, 'inp_data' : {'kind': 'data', 'shape': shape + 1}, 'conv_1': {'type': 'Convolution', 'kind': 'op', 'op': 'NotPlaceholder'}, - 'conv_data': {'kind': 'data', 'shape': shape, 'value': None}, + 'conv_data': {'kind': 'data', 'shape': shape, 'value': None, 'data_attr': 'data_attr_value'}, 'relu_1': {'type': 'ReLU', 'kind': 'op', 'op': 'NotPlaceholder'}, } edges = [ ('old_input', 'inp_data'), ('inp_data', 'conv_1'), ('conv_1', 'conv_data'), - ('conv_data', 'relu_1'), + ('conv_data', 'relu_1', {'edge_attr': 'edge_value'}), ] graph = build_graph(nodes, edges) + graph.stage = 'middle' add_input_ops(graph=graph, user_defined_inputs=inputs, before_infer=False) graph_ref = build_graph(nodes_attrs={'new_input': {'kind': 'op', 'op': 'Parameter', 'shape': shape}, @@ -412,7 +416,7 @@ class TestInputAddition(unittest.TestCase): edges=[('old_input', 'inp_data'), ('inp_data', 'conv_1'), ('new_input', 'conv_data'), - ('conv_data', 'relu_1'), + ('conv_data', 'relu_1', {'edge_attr': 'edge_value'}), ],) # Check that new input is added right (with right ports !) (flag, resp) = compare_graphs(graph, graph_ref, last_node='relu_1') @@ -427,8 +431,9 @@ class TestInputAddition(unittest.TestCase): new_input = 'conv_1/placeholder_out_port_0' self.assertTrue(graph.node[new_input]['is_input']) - self.assertTrue((new_input, 'conv_data') in graph.edges()) - self.assertTrue(('conv_1', 'conv_data') not in graph.edges()) + + self.assertTrue(Node(graph, 'relu_1').in_node(0)['data_attr'] == 'data_attr_value') + self.assertTrue(Node(graph, 'relu_1').in_edge(0)['edge_attr'] == 'edge_value') @generator