[MO] EmptyTensorList transform (#9361)

* Initial change for new transformations

* Update patterns

* Update unsupported operation replacement

* Add input/output normalization passes call

* Update logic

* Refactor output concatenation transform

* Update re_numerate_input_ports and shape infer functions for Loop

* Update comments

* Add back edge removing to output concatenation transformations

* Update comment

* Remove redundant normalization call

* Update supported layers list

* Use routine in check

* Add transformation to run_before list
Anton Chetverikov
2022-01-28 20:53:16 +03:00
committed by GitHub
parent eaa0a68fdb
commit 97a78d0059
4 changed files with 148 additions and 6 deletions


@@ -199,6 +199,7 @@ Some TensorFlow\* operations do not match to any Inference Engine layer, but are
| DepthwiseConv2dNative| |
| Einsum | Supported only with equation that does not contain repeated labels within a subscript |
| Elu | |
+| EmptyTensorList | Supported only when it is part of a sub-graph of the special form |
| Enter | Supported only when it is fused to the TensorIterator layer |
| Equal | |
| Erf | |
@@ -336,6 +337,7 @@ Some TensorFlow\* operations do not match to any Inference Engine layer, but are
| TensorArraySizeV3 | Supported only when it is fused to the TensorIterator layer |
| TensorArrayV3 | Supported only when it is fused to the TensorIterator layer |
| TensorArrayWriteV3 | Supported only when it is fused to the TensorIterator layer |
+| TensorListPushBack | Supported only when it is part of a sub-graph of the special form |
| Tile | |
| TopkV2 | |
| Transpose | |
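
For context, the "sub-graph of the special form" is the TensorFlow 2 While-loop pattern in which an EmptyTensorList container created outside the loop is filled by TensorListPushBack inside the loop body. A minimal, hypothetical snippet that produces such a sub-graph (the function name, shapes and the constant factor are illustrative only, not taken from this PR) could look like this:

```python
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([None, 4], tf.float32)])
def accumulate(x):
    # EmptyTensorList creates the growing container that becomes a Loop input
    lst = tf.raw_ops.EmptyTensorList(element_shape=tf.constant([4]),
                                     max_num_elements=tf.constant(-1),
                                     element_dtype=tf.float32)
    i = tf.constant(0)

    def cond(i, lst):
        return i < tf.shape(x)[0]

    def body(i, lst):
        # TensorListPushBack appends one intermediate result per iteration
        return i + 1, tf.raw_ops.TensorListPushBack(input_handle=lst, tensor=x[i] * 2.0)

    _, lst = tf.while_loop(cond, body, [i, lst])
    return tf.raw_ops.TensorListStack(input_handle=lst, element_shape=tf.constant([4]),
                                      element_dtype=tf.float32)
```

Whether a concrete model is converted still depends on its frozen graph matching the exact pattern handled by the transformation below.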


@@ -5,15 +5,15 @@ import logging as log
import numpy as np

from openvino.tools.mo.front.common.partial_infer.utils import int64_array, mo_array
from openvino.tools.mo.front.common.replacement import FrontReplacementSubgraph
from openvino.tools.mo.front.tf.WhileNormalize import WhileNormalize
from openvino.tools.mo.front.tf.custom_subgraph_call import skip_nodes_by_condition
from openvino.tools.mo.front.tf.graph_utils import create_op_with_const_inputs
from openvino.tools.mo.graph.graph import Graph, Node, rename_nodes
from openvino.tools.mo.middle.pattern_match import find_pattern_matches, inverse_dict
from openvino.tools.mo.ops.const import Const
from openvino.tools.mo.ops.loop import Loop
from openvino.tools.mo.ops.squeeze import Squeeze
from openvino.tools.mo.ops.unsqueeze import Unsqueeze
@@ -65,6 +65,22 @@ class MapFNInputSlicing(FrontReplacementSubgraph):
                   ('increment_iteration_identity', 'increment_iteration_result', {'in': 0})]
        )

+    @staticmethod
+    def get_body_pattern_without_identity():
+        return dict(
+            nodes=[('tensor_list', dict(op='Parameter')),
+                   ('current_iteration', dict(op='Parameter')),
+                   ('slicing', dict(op='TensorListGetItem')),
+                   ('const_increment', dict(op='Const')),
+                   ('increment_iteration', dict(op='Add')),
+                   ('increment_iteration_result', dict(op='Result'))],
+            edges=[('tensor_list', 'slicing', {'in': 0}),
+                   ('current_iteration', 'slicing', {'in': 1}),
+                   ('const_increment', 'increment_iteration', {'in': 1}),
+                   ('current_iteration', 'increment_iteration', {'in': 0}),
+                   ('increment_iteration', 'increment_iteration_result', {'in': 0})]
+        )
+
    @staticmethod
    def transform_map_fn_input_slicing(external_match: dict, internal_match: dict):
""" """
@@ -102,7 +118,9 @@ class MapFNInputSlicing(FrontReplacementSubgraph):
            loop_name = loop_node.soft_get('name', loop_node.id)
            body_graph = loop_node['body']
            body_pattern = MapFNInputSlicing.get_body_pattern()
+            body_pattern_without_identity = MapFNInputSlicing.get_body_pattern_without_identity()
            internal_matches = find_subgraph_match_to_pattern(body_graph, body_pattern)
+            internal_matches += find_subgraph_match_to_pattern(body_graph, body_pattern_without_identity)

            for internal_match in internal_matches:
                # check if TensorListGetItem from the body graph is connected with TensorListFromTensor
@@ -162,6 +180,26 @@ class MapFNOutputConcatenation(FrontReplacementSubgraph):
                   ('increment_iteration_identity', 'increment_iteration_result', {'in': 0})]
        )

+    @staticmethod
+    def get_body_pattern_without_identity():
+        return dict(
+            nodes=[('container', dict(op='Parameter')),
+                   ('current_iteration', dict(op='Parameter')),
+                   ('const_increment', dict(op='Const')),
+                   ('increment_iteration', dict(op='Add')),
+                   ('increment_iteration_result', dict(op='Result')),
+                   ('concatenation', dict(op='TensorListSetItem')),
+                   ('concatenation_result', dict(op='Result'))
+                   ],
+            edges=[('const_increment', 'increment_iteration', {'in': 1}),
+                   ('current_iteration', 'increment_iteration', {'in': 0}),
+                   ('container', 'concatenation', {'in': 0}),
+                   ('current_iteration', 'concatenation', {'in': 1}),
+                   ('concatenation', 'concatenation_result', {'in': 0}),
+                   ('increment_iteration', 'increment_iteration_result', {'in': 0})
+                   ]
+        )
+
    @staticmethod
    def transform_map_fn_output_concatenation(external_match: dict, internal_match: dict):
        """
@@ -208,17 +246,24 @@ class MapFNOutputConcatenation(FrontReplacementSubgraph):
        const_true = Const(body_graph, {'value': mo_array(True, dtype=np.bool)}).create_node()
        exec_cond_node.in_port(0).get_connection().set_source(const_true.out_port(0))

+        # remove back edge
+        for record in loop_node.back_edges:
+            if 'from_layer' in record and record['from_layer'] == list_result_node_layer_id:
+                loop_node.back_edges.remove(record)
+
    def find_and_replace_pattern(self, graph: Graph):
        for loop_node in graph.get_op_nodes(op='Loop'):
            loop_name = loop_node.soft_get('name', loop_node.id)
            body_graph = loop_node['body']
            body_pattern = MapFNOutputConcatenation.get_body_pattern()
+            body_pattern_without_identity = MapFNOutputConcatenation.get_body_pattern_without_identity()
            internal_matches = find_subgraph_match_to_pattern(body_graph, body_pattern)
+            internal_matches += find_subgraph_match_to_pattern(body_graph, body_pattern_without_identity)

            for internal_match in internal_matches:
                # check if TensorListReserve from the main graph is connected with Parameter node from the body graph
                # that is assigned for storing intermediate output results of While Loop. If yes, the transformation
-                # detects intermediate outputs concatentation by this port and can use Loop axis attribute
+                # detects intermediate outputs concatenation by this port and can use Loop axis attribute
                reserve_node = Loop.get_external_nodes_by_internal_id(loop_node,
                                                                      internal_match['container'].internal_layer_id)
                reserve_node = reserve_node[0] if (len(reserve_node) == 1 and
@@ -255,3 +300,97 @@ class MapFNOutputConcatenation(FrontReplacementSubgraph):
                                     internal_match['increment_iteration_result'].internal_layer_id,
                                     internal_match['current_iteration'].internal_layer_id):
                    MapFNOutputConcatenation.transform_map_fn_output_concatenation(external_match, internal_match)
+
+
+class TensorListOutputConcatenation(FrontReplacementSubgraph):
+    """
+    The transformation handles output concatenation in a While loop. It avoids EmptyTensorList and
+    TensorListPushBack operations and replaces the original sub-graph by setting the axis attribute in the Loop node
+    for concatenation of intermediate output results.
+    """
+    enabled = True
+
+    def run_before(self):
+        return [WhileNormalize]
+
+    @staticmethod
+    def get_body_pattern():
+        return dict(
+            nodes=[('container', dict(op='Parameter')),
+                   ('concatenation', dict(op='TensorListPushBack')),
+                   ('concatenation_result', dict(op='Result'))
+                   ],
+            edges=[
+                ('container', 'concatenation', {'in': 0}),
+                ('concatenation', 'concatenation_result', {'in': 0}),
+            ]
+        )
+
+    @staticmethod
+    def transform_tensor_list_output_concatenation(external_match: dict, internal_match: dict):
+        """
+        Transforms TensorFlow 2 output concatenation into use of axis attribute for output port of Loop node
+        :param external_match: a match used for handling a part of the main graph responsible for output concatenation
+        :param internal_match: a match used for handling a part of the body graph responsible for output concatenation
+        """
+        loop_node = external_match['while']
+        empty_tensor_list_node = external_match['reserve']
+        body_graph = loop_node['body']
+
+        tensor_list_push_back_node = internal_match['concatenation']
+        tensor_list_push_back_node_name = tensor_list_push_back_node.soft_get('name', tensor_list_push_back_node.id)
+        list_result_node = internal_match['concatenation_result']
+
+        # replace TensorListPushBack with Unsqueeze and use axis attribute for corresponding Result node
+        # to concatenate results from different iterations
+        unsqueeze_list_element = create_op_with_const_inputs(body_graph, Unsqueeze, {1: int64_array(0)},
+                                                             {'name': tensor_list_push_back_node_name +
+                                                                      '/TensorListPushBackUnsqueeze'})
+        tensor_list_push_back_node.in_port(1).get_connection().set_destination(unsqueeze_list_element.in_port(0))
+        tensor_list_push_back_node.out_port(0).get_connection().set_source(unsqueeze_list_element.out_port(0))
+        rename_nodes([(tensor_list_push_back_node, tensor_list_push_back_node_name + '/AbandonedName'),
+                      (unsqueeze_list_element, tensor_list_push_back_node_name)])
+
+        list_result_node_layer_id = list_result_node.internal_layer_id
+        Loop.update_port_map_value_ext(loop_node.output_port_map, 'internal_layer_id', list_result_node_layer_id,
+                                       'axis', 0)
+
+        # disconnect EmptyTensorList node because it is no longer needed for Loop
+        empty_tensor_list_node.out_port(0).disconnect()
+        loop_node.in_port(1).disconnect()
+        empty_tensor_list_node.in_port(1).get_source().connect(loop_node.in_port(1))
+
+        # remove back edge
+        for record in loop_node.back_edges:
+            if 'from_layer' in record and record['from_layer'] == list_result_node_layer_id:
+                loop_node.back_edges.remove(record)
+
+    def find_and_replace_pattern(self, graph: Graph):
+        for loop_node in graph.get_op_nodes(op='Loop'):
+            loop_name = loop_node.soft_get('name', loop_node.id)
+            body_graph = loop_node['body']
+
+            body_pattern = TensorListOutputConcatenation.get_body_pattern()
+            internal_matches = find_subgraph_match_to_pattern(body_graph, body_pattern)
+
+            for internal_match in internal_matches:
+                # check if EmptyTensorList from the main graph is connected with Parameter node from the body graph
+                # that is assigned for storing intermediate output results of While Loop. If yes, the transformation
+                # detects intermediate outputs concatenation by this port and can use Loop axis attribute
+                reserve_node = Loop.get_external_nodes_by_internal_id(loop_node,
+                                                                      internal_match['container'].internal_layer_id)
+                reserve_node = reserve_node[0] if (len(reserve_node) == 1 and
+                                                   reserve_node[0].op == 'EmptyTensorList') else None
+                if reserve_node is None:
+                    log.info("A sub-graph around the loop node {} does not match "
+                             "TensorFlow 2 EmptyTensorList->TensorListPushBack pattern for intermediate "
+                             "outputs concatenation".format(loop_name))
+                    continue
+
+                external_match = {'while': loop_node,
+                                  'reserve': reserve_node}
+                # check that back edges connect Parameter node (or container with intermediate output results)
+                # and concatenation result produced by TensorListPushBack node
+                if Loop.back_edge_exists(loop_node.back_edges, internal_match['concatenation_result'].internal_layer_id,
+                                         internal_match['container'].internal_layer_id):
+                    TensorListOutputConcatenation.transform_tensor_list_output_concatenation(external_match,
+                                                                                             internal_match)
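
To make the effect of the rewrite concrete: after the transformation, the loop body emits one Unsqueeze'd slice per iteration, and the axis=0 entry in the Loop output port map is what makes those slices concatenate into the final output instead of being pushed into a TensorList. A simplified NumPy sketch of that semantics (illustrative only, not MO code):

```python
import numpy as np

# toy per-iteration body outputs, each of shape [4]
iteration_outputs = [np.full(4, i, dtype=np.float32) for i in range(3)]

# the inserted Unsqueeze(axis=0) turns every body output into a [1, 4] slice ...
slices = [np.expand_dims(out, axis=0) for out in iteration_outputs]

# ... and axis=0 on the Loop output port concatenates the slices across iterations,
# producing the same [3, 4] tensor the TensorList would have been stacked into
concatenated = np.concatenate(slices, axis=0)
assert concatenated.shape == (3, 4)
```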


@@ -28,7 +28,8 @@ from openvino.tools.mo.front.TransposeOrderNormalizer import TransposeOrderNorma
from openvino.tools.mo.front.split_normalizer import SqueezeAxis
from openvino.tools.mo.front.tf.CropAndResizeReplacement import CropAndResizeReplacement
from openvino.tools.mo.front.tf.FakeQuantWithMinMaxVars import FakeQuantWithMinMaxVarsToQuantize
-from openvino.tools.mo.front.tf.MapFNTransformation import MapFNInputSlicing, MapFNOutputConcatenation
+from openvino.tools.mo.front.tf.MapFNTransformation import MapFNInputSlicing, MapFNOutputConcatenation,\
+    TensorListOutputConcatenation
from openvino.tools.mo.front.tf.TFSliceToSlice import TFSliceToSliceReplacer
from openvino.tools.mo.front.tf.pad_tf_to_pad import PadTFToPad
from openvino.tools.mo.middle.InsertLayoutPropagationTransposes import mark_as_correct_data_layout, \
@@ -594,7 +595,7 @@ class ObjectDetectionAPITransformationsFinish(FrontReplacementPattern):
    def run_before(self):
        return [Pack, TransposeOrderNormalizer, PadTFToPad, SqueezeAxis, TFSliceToSliceReplacer, MapFNInputSlicing,
-                MapFNOutputConcatenation, CropAndResizeReplacement]
+                MapFNOutputConcatenation, TensorListOutputConcatenation, CropAndResizeReplacement]

    def find_and_replace_pattern(self, graph: Graph):
        pass
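
As a side note on what adding TensorListOutputConcatenation to this run_before list achieves: each transformation's run_before() names transformations that must execute after it, and the pipeline derives an execution order from those constraints. A minimal sketch of that idea (plain Python with graphlib, not the actual MO transformation registry; the transform names are taken from this diff) could be:

```python
from graphlib import TopologicalSorter

# run_before[t] lists transforms that must run after t, i.e. edges t -> successor
run_before = {
    'ObjectDetectionAPITransformationsFinish': ['MapFNInputSlicing',
                                                'MapFNOutputConcatenation',
                                                'TensorListOutputConcatenation'],
    'TensorListOutputConcatenation': ['WhileNormalize'],
}

ts = TopologicalSorter()
for transform, successors in run_before.items():
    for successor in successors:
        ts.add(successor, transform)   # successor depends on transform

print(list(ts.static_order()))
# e.g. ['ObjectDetectionAPITransformationsFinish', 'MapFNInputSlicing',
#       'MapFNOutputConcatenation', 'TensorListOutputConcatenation', 'WhileNormalize']
```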


@@ -446,7 +446,7 @@ class Loop(TensorIterator):
        max_port_id = sorted(loop_node.in_ports().keys())[-1]
        new_port_id = 0
        for port_id in range(max_port_id + 1):
-            if port_id in loop_node.in_ports():
+            if loop_node.is_in_port_connected(port_id):
                if port_id != new_port_id:
                    re_number_input_port(loop_node, port_id, new_port_id)
                new_port_id += 1
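
The updated check makes re-numbering key on whether an input port is actually connected rather than merely present in in_ports(), so a port that a transformation left dangling does not claim a new port number. A standalone sketch of the intended behaviour (plain Python, not the MO API; the helper name is hypothetical):

```python
def renumber_connected_ports(ports):
    """ports: dict port_id -> bool (True if the port is connected)."""
    mapping = {}
    new_port_id = 0
    for port_id in range(max(ports) + 1):
        if ports.get(port_id, False):      # analogue of is_in_port_connected()
            if port_id != new_port_id:
                mapping[port_id] = new_port_id
            new_port_id += 1
    return mapping

# Port 1 exists but is disconnected, so ports 2 and 3 shift down by one
assert renumber_connected_ports({0: True, 1: False, 2: True, 3: True}) == {2: 1, 3: 2}
```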