Provide keep_output_port attribute to add_opoutput function (#6117)

* Implement way to provide keep_output_port attribute to add_opoutput function

* Update tests

* Update comment

* Fake commit to pictures merge problem

* Change default value

* Add type

* Revert "Fake commit to pictures merge problem"

This reverts commit 41850765e0.
This commit is contained in:
Anton Chetverikov 2021-07-05 10:34:44 +03:00 committed by GitHub
parent a65828c172
commit 8a31e8aafb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 10 additions and 7 deletions

View File

@ -154,7 +154,7 @@ class ConvertGroupedStridedSlice(MiddleReplacementPattern):
size_splits.append(l - prev_r)
shape[split_channel_dim] = l - prev_r
data_node = Op._create_data_node(graph, 'fake_data_'+out_nodes[0].name, {'shape': shape})
add_opoutput(graph, data_node.id, 0, False)
add_opoutput(graph, data_node.id, 0, False, keep_output_port=True)
final_data_nodes_list.append(data_node)
prev_r = r
@ -167,7 +167,7 @@ class ConvertGroupedStridedSlice(MiddleReplacementPattern):
shape[split_channel_dim] = input_shape[split_channel_dim] - prev_r
size_splits.append(input_shape[split_channel_dim] - prev_r)
data_node = Op._create_data_node(graph, 'fake_data_'+out_nodes[0].name, {'shape': shape})
add_opoutput(graph, data_node.id, 0, False)
add_opoutput(graph, data_node.id, 0, False, keep_output_port=True)
final_data_nodes_list.append(data_node)
for node in out_nodes:

View File

@ -1032,21 +1032,24 @@ def dict_includes(big: dict, sub_dict: dict, skip_attr_names=[]):
)
def add_opoutput(graph: Graph, node_name: str, port: int, cut: bool = True):
def add_opoutput(graph: Graph, node_name: str, port: int, cut: bool = True, keep_output_port: bool = False):
"""
Creates and connects Result node to node_name port. Cuts existing port if requested.
:param graph: graph to operate with
:param node_name: name of existing node in the graph that we want to add Result to
:param port: output port of node to connect Result to
:param cut: determines way of operating with edge specified by node_name and port
:param keep_output_port: special attribute determines if this operation is saved in IR or not
"""
# we import it here because Op imports add_attrs_props and update_ie_fields from this file
from mo.ops.result import Result
node = Node(graph, node_name)
if cut and len(node.out_edges()) != 0:
opoutput_node = Result(graph).create_node_on_port(node, port, {'name': node_name + '/sink_port_' + str(port)})
opoutput_node = Result(graph).create_node_on_port(node, port, {'name': node_name + '/sink_port_' + str(port),
'keep_output_port': keep_output_port})
else:
opoutput_node = Result(graph).create_node([(node, port)], {'name': node_name + '/sink_port_' + str(port)})
opoutput_node = Result(graph).create_node([(node, port)], {'name': node_name + '/sink_port_' + str(port),
'keep_output_port': keep_output_port})
opoutput_node.in_edge()['data_attrs'] = ['fw_tensor_debug_info']
log.debug('Sink: {} for node {}'.format(opoutput_node.id, node_name))

View File

@ -60,8 +60,8 @@ nodes_attributes = {
'concat_1_data': {'value': None, 'shape': None, 'kind': 'data'},
'op_output': {'kind': 'op', 'op': 'Result'},
'op_output_1': {'kind': 'op', 'op': 'Result'},
'op_output_2': {'kind': 'op', 'op': 'Result'},
'op_output_1': {'kind': 'op', 'op': 'Result', 'keep_output_port': True},
'op_output_2': {'kind': 'op', 'op': 'Result', 'keep_output_port': True},
# Squeeze layers
'sslice_1/Squeeze_shrink': {'type': None, 'value': None, 'kind': 'op', 'op': 'Squeeze'},