[MO] EmptyTensorList transform (#9361)

* Initial change for new transformations

* Update patterns

* Update unsupported operation replacement

* Add input/output normalization passes call

* Update logic

* Refactor output concatination transform

* Update re_numerate_input_ports and shape infer functions for Loop

* Update comments

* Add back edge removing to output concatenation transformations

* Update comment

* Remove redundant normallization call

* Update supported layers list

* Use routine in check

* Add transformation to rub_before list
This commit is contained in:
Anton Chetverikov
2022-01-28 20:53:16 +03:00
committed by GitHub
parent eaa0a68fdb
commit 97a78d0059
4 changed files with 148 additions and 6 deletions

View File

@@ -5,15 +5,15 @@ import logging as log
import numpy as np
from openvino.tools.mo.front.tf.WhileNormalize import WhileNormalize
from openvino.tools.mo.ops.loop import Loop
from openvino.tools.mo.front.common.partial_infer.utils import int64_array, mo_array
from openvino.tools.mo.front.common.replacement import FrontReplacementSubgraph
from openvino.tools.mo.front.tf.WhileNormalize import WhileNormalize
from openvino.tools.mo.front.tf.custom_subgraph_call import skip_nodes_by_condition
from openvino.tools.mo.front.tf.graph_utils import create_op_with_const_inputs
from openvino.tools.mo.graph.graph import Graph, Node, rename_nodes
from openvino.tools.mo.middle.pattern_match import find_pattern_matches, inverse_dict
from openvino.tools.mo.ops.const import Const
from openvino.tools.mo.ops.loop import Loop
from openvino.tools.mo.ops.squeeze import Squeeze
from openvino.tools.mo.ops.unsqueeze import Unsqueeze
@@ -65,6 +65,22 @@ class MapFNInputSlicing(FrontReplacementSubgraph):
('increment_iteration_identity', 'increment_iteration_result', {'in': 0})]
)
@staticmethod
def get_body_pattern_without_identity():
return dict(
nodes=[('tensor_list', dict(op='Parameter')),
('current_iteration', dict(op='Parameter')),
('slicing', dict(op='TensorListGetItem')),
('const_increment', dict(op='Const')),
('increment_iteration', dict(op='Add')),
('increment_iteration_result', dict(op='Result'))],
edges=[('tensor_list', 'slicing', {'in': 0}),
('current_iteration', 'slicing', {'in': 1}),
('const_increment', 'increment_iteration', {'in': 1}),
('current_iteration', 'increment_iteration', {'in': 0}),
('increment_iteration', 'increment_iteration_result', {'in': 0})]
)
@staticmethod
def transform_map_fn_input_slicing(external_match: dict, internal_match: dict):
"""
@@ -102,7 +118,9 @@ class MapFNInputSlicing(FrontReplacementSubgraph):
loop_name = loop_node.soft_get('name', loop_node.id)
body_graph = loop_node['body']
body_pattern = MapFNInputSlicing.get_body_pattern()
body_pattern_without_identity = MapFNInputSlicing.get_body_pattern_without_identity()
internal_matches = find_subgraph_match_to_pattern(body_graph, body_pattern)
internal_matches += find_subgraph_match_to_pattern(body_graph, body_pattern_without_identity)
for internal_match in internal_matches:
# check if TensorListGetItem from the body graph is connected with TensorListFromTensor
@@ -162,6 +180,26 @@ class MapFNOutputConcatenation(FrontReplacementSubgraph):
('increment_iteration_identity', 'increment_iteration_result', {'in': 0})]
)
@staticmethod
def get_body_pattern_without_identity():
return dict(
nodes=[('container', dict(op='Parameter')),
('current_iteration', dict(op='Parameter')),
('const_increment', dict(op='Const')),
('increment_iteration', dict(op='Add')),
('increment_iteration_result', dict(op='Result')),
('concatenation', dict(op='TensorListSetItem')),
('concatenation_result', dict(op='Result'))
],
edges=[('const_increment', 'increment_iteration', {'in': 1}),
('current_iteration', 'increment_iteration', {'in': 0}),
('container', 'concatenation', {'in': 0}),
('current_iteration', 'concatenation', {'in': 1}),
('concatenation', 'concatenation_result', {'in': 0}),
('increment_iteration', 'increment_iteration_result', {'in': 0})
]
)
@staticmethod
def transform_map_fn_output_concatenation(external_match: dict, internal_match: dict):
"""
@@ -208,17 +246,24 @@ class MapFNOutputConcatenation(FrontReplacementSubgraph):
const_true = Const(body_graph, {'value': mo_array(True, dtype=np.bool)}).create_node()
exec_cond_node.in_port(0).get_connection().set_source(const_true.out_port(0))
# remove back edge
for record in loop_node.back_edges:
if 'from_layer' in record and record['from_layer'] == list_result_node_layer_id:
loop_node.back_edges.remove(record)
def find_and_replace_pattern(self, graph: Graph):
for loop_node in graph.get_op_nodes(op='Loop'):
loop_name = loop_node.soft_get('name', loop_node.id)
body_graph = loop_node['body']
body_pattern = MapFNOutputConcatenation.get_body_pattern()
body_pattern_without_identity = MapFNOutputConcatenation.get_body_pattern_without_identity()
internal_matches = find_subgraph_match_to_pattern(body_graph, body_pattern)
internal_matches += find_subgraph_match_to_pattern(body_graph, body_pattern_without_identity)
for internal_match in internal_matches:
# check if TensorListReserve from the main graph is connected with Parameter node from the body graph
# that is assigned for storing intermediate output results of While Loop. If yes, the transformation
# detects intermediate outputs concatentation by this port and can use Loop axis attribute
# detects intermediate outputs concatenation by this port and can use Loop axis attribute
reserve_node = Loop.get_external_nodes_by_internal_id(loop_node,
internal_match['container'].internal_layer_id)
reserve_node = reserve_node[0] if (len(reserve_node) == 1 and
@@ -255,3 +300,97 @@ class MapFNOutputConcatenation(FrontReplacementSubgraph):
internal_match['increment_iteration_result'].internal_layer_id,
internal_match['current_iteration'].internal_layer_id):
MapFNOutputConcatenation.transform_map_fn_output_concatenation(external_match, internal_match)
class TensorListOutputConcatenation(FrontReplacementSubgraph):
"""
The transformation handles inputs slicing in While loop. It avoids TensorListPushBack, and EmptyTensorList
operations and replaces the original sub-graph by adding axis attribute in Loop node for concatenation of
intermediate output results.
"""
enabled = True
def run_before(self):
return [WhileNormalize]
@staticmethod
def get_body_pattern():
return dict(
nodes=[('container', dict(op='Parameter')),
('concatenation', dict(op='TensorListPushBack')),
('concatenation_result', dict(op='Result'))
],
edges=[
('container', 'concatenation', {'in': 0}),
('concatenation', 'concatenation_result', {'in': 0}),
]
)
@staticmethod
def transform_tensor_list_output_concatenation(external_match: dict, internal_match: dict):
"""
Transforms TensorFlow 2 output concatenation into use of axis attribute for output port of Loop node
:param external_match: a match used for handling a part of the main graph responsible for output concatenation
:param internal_match: a match used for handling a part of the body graph responsible for output concatenation
"""
loop_node = external_match['while']
empty_tensor_list_node = external_match['reserve']
body_graph = loop_node['body']
tensor_list_push_back_node = internal_match['concatenation']
tensor_list_push_back_node_name = tensor_list_push_back_node.soft_get('name', tensor_list_push_back_node.id)
list_result_node = internal_match['concatenation_result']
# replace TensorListPushBack with Unsqueeze and use axis attribute for corresponding Result node
# to concatenate results from different iterations
unsqueeze_list_element = create_op_with_const_inputs(body_graph, Unsqueeze, {1: int64_array(0)},
{'name': tensor_list_push_back_node_name +
'/TensorListPushBackUnsqueeze'})
tensor_list_push_back_node.in_port(1).get_connection().set_destination(unsqueeze_list_element.in_port(0))
tensor_list_push_back_node.out_port(0).get_connection().set_source(unsqueeze_list_element.out_port(0))
rename_nodes([(tensor_list_push_back_node, tensor_list_push_back_node_name + '/AbandonedName'),
(unsqueeze_list_element, tensor_list_push_back_node_name)])
list_result_node_layer_id = list_result_node.internal_layer_id
Loop.update_port_map_value_ext(loop_node.output_port_map, 'internal_layer_id', list_result_node_layer_id,
'axis', 0)
# disconnect EmptyTensorList node because it is no longer needed for Loop
empty_tensor_list_node.out_port(0).disconnect()
loop_node.in_port(1).disconnect()
empty_tensor_list_node.in_port(1).get_source().connect(loop_node.in_port(1))
# remove back edge
for record in loop_node.back_edges:
if 'from_layer' in record and record['from_layer'] == list_result_node_layer_id:
loop_node.back_edges.remove(record)
def find_and_replace_pattern(self, graph: Graph):
for loop_node in graph.get_op_nodes(op='Loop'):
loop_name = loop_node.soft_get('name', loop_node.id)
body_graph = loop_node['body']
body_pattern = TensorListOutputConcatenation.get_body_pattern()
internal_matches = find_subgraph_match_to_pattern(body_graph, body_pattern)
for internal_match in internal_matches:
# check if EmptyTensorList from the main graph is connected with Parameter node from the body graph
# that is assigned for storing intermediate output results of While Loop. If yes, the transformation
# detects intermediate outputs concatenation by this port and can use Loop axis attribute
reserve_node = Loop.get_external_nodes_by_internal_id(loop_node,
internal_match['container'].internal_layer_id)
reserve_node = reserve_node[0] if (len(reserve_node) == 1 and
reserve_node[0].op == 'EmptyTensorList') else None
if reserve_node is None:
log.info("A sub-graph around the loop node {} does not match "
"TensorFlow 2 EmptyTensorList->TensorListPushBack pattern for intermediate "
"outputs concatenation".format(loop_name))
continue
external_match = {'while': loop_node,
'reserve': reserve_node}
# check that back edges connect Parameter node (or container with intermediate output results)
# and concatenation result produced by TensorListPushBack node
if Loop.back_edge_exists(loop_node.back_edges, internal_match['concatenation_result'].internal_layer_id,
internal_match['container'].internal_layer_id):
TensorListOutputConcatenation.transform_tensor_list_output_concatenation(external_match,
internal_match)

View File

@@ -28,7 +28,8 @@ from openvino.tools.mo.front.TransposeOrderNormalizer import TransposeOrderNorma
from openvino.tools.mo.front.split_normalizer import SqueezeAxis
from openvino.tools.mo.front.tf.CropAndResizeReplacement import CropAndResizeReplacement
from openvino.tools.mo.front.tf.FakeQuantWithMinMaxVars import FakeQuantWithMinMaxVarsToQuantize
from openvino.tools.mo.front.tf.MapFNTransformation import MapFNInputSlicing, MapFNOutputConcatenation
from openvino.tools.mo.front.tf.MapFNTransformation import MapFNInputSlicing, MapFNOutputConcatenation,\
TensorListOutputConcatenation
from openvino.tools.mo.front.tf.TFSliceToSlice import TFSliceToSliceReplacer
from openvino.tools.mo.front.tf.pad_tf_to_pad import PadTFToPad
from openvino.tools.mo.middle.InsertLayoutPropagationTransposes import mark_as_correct_data_layout, \
@@ -594,7 +595,7 @@ class ObjectDetectionAPITransformationsFinish(FrontReplacementPattern):
def run_before(self):
return [Pack, TransposeOrderNormalizer, PadTFToPad, SqueezeAxis, TFSliceToSliceReplacer, MapFNInputSlicing,
MapFNOutputConcatenation, CropAndResizeReplacement]
MapFNOutputConcatenation, TensorListOutputConcatenation, CropAndResizeReplacement]
def find_and_replace_pattern(self, graph: Graph):
pass

View File

@@ -446,7 +446,7 @@ class Loop(TensorIterator):
max_port_id = sorted(loop_node.in_ports().keys())[-1]
new_port_id = 0
for port_id in range(max_port_id + 1):
if port_id in loop_node.in_ports():
if loop_node.is_in_port_connected(port_id):
if port_id != new_port_id:
re_number_input_port(loop_node, port_id, new_port_id)
new_port_id += 1