[MO] EmptyTensorList transform (#9361)
* Initial change for new transformations * Update patterns * Update unsupported operation replacement * Add input/output normalization passes call * Update logic * Refactor output concatination transform * Update re_numerate_input_ports and shape infer functions for Loop * Update comments * Add back edge removing to output concatenation transformations * Update comment * Remove redundant normallization call * Update supported layers list * Use routine in check * Add transformation to rub_before list
This commit is contained in:
committed by
GitHub
parent
eaa0a68fdb
commit
97a78d0059
@@ -5,15 +5,15 @@ import logging as log
|
||||
|
||||
import numpy as np
|
||||
|
||||
from openvino.tools.mo.front.tf.WhileNormalize import WhileNormalize
|
||||
from openvino.tools.mo.ops.loop import Loop
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import int64_array, mo_array
|
||||
from openvino.tools.mo.front.common.replacement import FrontReplacementSubgraph
|
||||
from openvino.tools.mo.front.tf.WhileNormalize import WhileNormalize
|
||||
from openvino.tools.mo.front.tf.custom_subgraph_call import skip_nodes_by_condition
|
||||
from openvino.tools.mo.front.tf.graph_utils import create_op_with_const_inputs
|
||||
from openvino.tools.mo.graph.graph import Graph, Node, rename_nodes
|
||||
from openvino.tools.mo.middle.pattern_match import find_pattern_matches, inverse_dict
|
||||
from openvino.tools.mo.ops.const import Const
|
||||
from openvino.tools.mo.ops.loop import Loop
|
||||
from openvino.tools.mo.ops.squeeze import Squeeze
|
||||
from openvino.tools.mo.ops.unsqueeze import Unsqueeze
|
||||
|
||||
@@ -65,6 +65,22 @@ class MapFNInputSlicing(FrontReplacementSubgraph):
|
||||
('increment_iteration_identity', 'increment_iteration_result', {'in': 0})]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_body_pattern_without_identity():
|
||||
return dict(
|
||||
nodes=[('tensor_list', dict(op='Parameter')),
|
||||
('current_iteration', dict(op='Parameter')),
|
||||
('slicing', dict(op='TensorListGetItem')),
|
||||
('const_increment', dict(op='Const')),
|
||||
('increment_iteration', dict(op='Add')),
|
||||
('increment_iteration_result', dict(op='Result'))],
|
||||
edges=[('tensor_list', 'slicing', {'in': 0}),
|
||||
('current_iteration', 'slicing', {'in': 1}),
|
||||
('const_increment', 'increment_iteration', {'in': 1}),
|
||||
('current_iteration', 'increment_iteration', {'in': 0}),
|
||||
('increment_iteration', 'increment_iteration_result', {'in': 0})]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def transform_map_fn_input_slicing(external_match: dict, internal_match: dict):
|
||||
"""
|
||||
@@ -102,7 +118,9 @@ class MapFNInputSlicing(FrontReplacementSubgraph):
|
||||
loop_name = loop_node.soft_get('name', loop_node.id)
|
||||
body_graph = loop_node['body']
|
||||
body_pattern = MapFNInputSlicing.get_body_pattern()
|
||||
body_pattern_without_identity = MapFNInputSlicing.get_body_pattern_without_identity()
|
||||
internal_matches = find_subgraph_match_to_pattern(body_graph, body_pattern)
|
||||
internal_matches += find_subgraph_match_to_pattern(body_graph, body_pattern_without_identity)
|
||||
|
||||
for internal_match in internal_matches:
|
||||
# check if TensorListGetItem from the body graph is connected with TensorListFromTensor
|
||||
@@ -162,6 +180,26 @@ class MapFNOutputConcatenation(FrontReplacementSubgraph):
|
||||
('increment_iteration_identity', 'increment_iteration_result', {'in': 0})]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_body_pattern_without_identity():
|
||||
return dict(
|
||||
nodes=[('container', dict(op='Parameter')),
|
||||
('current_iteration', dict(op='Parameter')),
|
||||
('const_increment', dict(op='Const')),
|
||||
('increment_iteration', dict(op='Add')),
|
||||
('increment_iteration_result', dict(op='Result')),
|
||||
('concatenation', dict(op='TensorListSetItem')),
|
||||
('concatenation_result', dict(op='Result'))
|
||||
],
|
||||
edges=[('const_increment', 'increment_iteration', {'in': 1}),
|
||||
('current_iteration', 'increment_iteration', {'in': 0}),
|
||||
('container', 'concatenation', {'in': 0}),
|
||||
('current_iteration', 'concatenation', {'in': 1}),
|
||||
('concatenation', 'concatenation_result', {'in': 0}),
|
||||
('increment_iteration', 'increment_iteration_result', {'in': 0})
|
||||
]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def transform_map_fn_output_concatenation(external_match: dict, internal_match: dict):
|
||||
"""
|
||||
@@ -208,17 +246,24 @@ class MapFNOutputConcatenation(FrontReplacementSubgraph):
|
||||
const_true = Const(body_graph, {'value': mo_array(True, dtype=np.bool)}).create_node()
|
||||
exec_cond_node.in_port(0).get_connection().set_source(const_true.out_port(0))
|
||||
|
||||
# remove back edge
|
||||
for record in loop_node.back_edges:
|
||||
if 'from_layer' in record and record['from_layer'] == list_result_node_layer_id:
|
||||
loop_node.back_edges.remove(record)
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
for loop_node in graph.get_op_nodes(op='Loop'):
|
||||
loop_name = loop_node.soft_get('name', loop_node.id)
|
||||
body_graph = loop_node['body']
|
||||
body_pattern = MapFNOutputConcatenation.get_body_pattern()
|
||||
body_pattern_without_identity = MapFNOutputConcatenation.get_body_pattern_without_identity()
|
||||
internal_matches = find_subgraph_match_to_pattern(body_graph, body_pattern)
|
||||
internal_matches += find_subgraph_match_to_pattern(body_graph, body_pattern_without_identity)
|
||||
|
||||
for internal_match in internal_matches:
|
||||
# check if TensorListReserve from the main graph is connected with Parameter node from the body graph
|
||||
# that is assigned for storing intermediate output results of While Loop. If yes, the transformation
|
||||
# detects intermediate outputs concatentation by this port and can use Loop axis attribute
|
||||
# detects intermediate outputs concatenation by this port and can use Loop axis attribute
|
||||
reserve_node = Loop.get_external_nodes_by_internal_id(loop_node,
|
||||
internal_match['container'].internal_layer_id)
|
||||
reserve_node = reserve_node[0] if (len(reserve_node) == 1 and
|
||||
@@ -255,3 +300,97 @@ class MapFNOutputConcatenation(FrontReplacementSubgraph):
|
||||
internal_match['increment_iteration_result'].internal_layer_id,
|
||||
internal_match['current_iteration'].internal_layer_id):
|
||||
MapFNOutputConcatenation.transform_map_fn_output_concatenation(external_match, internal_match)
|
||||
|
||||
|
||||
class TensorListOutputConcatenation(FrontReplacementSubgraph):
|
||||
"""
|
||||
The transformation handles inputs slicing in While loop. It avoids TensorListPushBack, and EmptyTensorList
|
||||
operations and replaces the original sub-graph by adding axis attribute in Loop node for concatenation of
|
||||
intermediate output results.
|
||||
"""
|
||||
enabled = True
|
||||
|
||||
def run_before(self):
|
||||
return [WhileNormalize]
|
||||
|
||||
@staticmethod
|
||||
def get_body_pattern():
|
||||
return dict(
|
||||
nodes=[('container', dict(op='Parameter')),
|
||||
('concatenation', dict(op='TensorListPushBack')),
|
||||
('concatenation_result', dict(op='Result'))
|
||||
],
|
||||
edges=[
|
||||
('container', 'concatenation', {'in': 0}),
|
||||
('concatenation', 'concatenation_result', {'in': 0}),
|
||||
]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def transform_tensor_list_output_concatenation(external_match: dict, internal_match: dict):
|
||||
"""
|
||||
Transforms TensorFlow 2 output concatenation into use of axis attribute for output port of Loop node
|
||||
:param external_match: a match used for handling a part of the main graph responsible for output concatenation
|
||||
:param internal_match: a match used for handling a part of the body graph responsible for output concatenation
|
||||
"""
|
||||
loop_node = external_match['while']
|
||||
empty_tensor_list_node = external_match['reserve']
|
||||
body_graph = loop_node['body']
|
||||
|
||||
tensor_list_push_back_node = internal_match['concatenation']
|
||||
tensor_list_push_back_node_name = tensor_list_push_back_node.soft_get('name', tensor_list_push_back_node.id)
|
||||
list_result_node = internal_match['concatenation_result']
|
||||
|
||||
# replace TensorListPushBack with Unsqueeze and use axis attribute for corresponding Result node
|
||||
# to concatenate results from different iterations
|
||||
unsqueeze_list_element = create_op_with_const_inputs(body_graph, Unsqueeze, {1: int64_array(0)},
|
||||
{'name': tensor_list_push_back_node_name +
|
||||
'/TensorListPushBackUnsqueeze'})
|
||||
tensor_list_push_back_node.in_port(1).get_connection().set_destination(unsqueeze_list_element.in_port(0))
|
||||
tensor_list_push_back_node.out_port(0).get_connection().set_source(unsqueeze_list_element.out_port(0))
|
||||
rename_nodes([(tensor_list_push_back_node, tensor_list_push_back_node_name + '/AbandonedName'),
|
||||
(unsqueeze_list_element, tensor_list_push_back_node_name)])
|
||||
list_result_node_layer_id = list_result_node.internal_layer_id
|
||||
Loop.update_port_map_value_ext(loop_node.output_port_map, 'internal_layer_id', list_result_node_layer_id,
|
||||
'axis', 0)
|
||||
|
||||
# disconnect EmptyTensorList node because it is no longer needed for Loop
|
||||
empty_tensor_list_node.out_port(0).disconnect()
|
||||
|
||||
loop_node.in_port(1).disconnect()
|
||||
empty_tensor_list_node.in_port(1).get_source().connect(loop_node.in_port(1))
|
||||
|
||||
# remove back edge
|
||||
for record in loop_node.back_edges:
|
||||
if 'from_layer' in record and record['from_layer'] == list_result_node_layer_id:
|
||||
loop_node.back_edges.remove(record)
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
for loop_node in graph.get_op_nodes(op='Loop'):
|
||||
loop_name = loop_node.soft_get('name', loop_node.id)
|
||||
body_graph = loop_node['body']
|
||||
body_pattern = TensorListOutputConcatenation.get_body_pattern()
|
||||
internal_matches = find_subgraph_match_to_pattern(body_graph, body_pattern)
|
||||
|
||||
for internal_match in internal_matches:
|
||||
# check if EmptyTensorList from the main graph is connected with Parameter node from the body graph
|
||||
# that is assigned for storing intermediate output results of While Loop. If yes, the transformation
|
||||
# detects intermediate outputs concatenation by this port and can use Loop axis attribute
|
||||
reserve_node = Loop.get_external_nodes_by_internal_id(loop_node,
|
||||
internal_match['container'].internal_layer_id)
|
||||
reserve_node = reserve_node[0] if (len(reserve_node) == 1 and
|
||||
reserve_node[0].op == 'EmptyTensorList') else None
|
||||
if reserve_node is None:
|
||||
log.info("A sub-graph around the loop node {} does not match "
|
||||
"TensorFlow 2 EmptyTensorList->TensorListPushBack pattern for intermediate "
|
||||
"outputs concatenation".format(loop_name))
|
||||
continue
|
||||
|
||||
external_match = {'while': loop_node,
|
||||
'reserve': reserve_node}
|
||||
# check that back edges connect Parameter node (or container with intermediate output results)
|
||||
# and concatenation result produced by TensorListPushBack node
|
||||
if Loop.back_edge_exists(loop_node.back_edges, internal_match['concatenation_result'].internal_layer_id,
|
||||
internal_match['container'].internal_layer_id):
|
||||
TensorListOutputConcatenation.transform_tensor_list_output_concatenation(external_match,
|
||||
internal_match)
|
||||
|
||||
@@ -28,7 +28,8 @@ from openvino.tools.mo.front.TransposeOrderNormalizer import TransposeOrderNorma
|
||||
from openvino.tools.mo.front.split_normalizer import SqueezeAxis
|
||||
from openvino.tools.mo.front.tf.CropAndResizeReplacement import CropAndResizeReplacement
|
||||
from openvino.tools.mo.front.tf.FakeQuantWithMinMaxVars import FakeQuantWithMinMaxVarsToQuantize
|
||||
from openvino.tools.mo.front.tf.MapFNTransformation import MapFNInputSlicing, MapFNOutputConcatenation
|
||||
from openvino.tools.mo.front.tf.MapFNTransformation import MapFNInputSlicing, MapFNOutputConcatenation,\
|
||||
TensorListOutputConcatenation
|
||||
from openvino.tools.mo.front.tf.TFSliceToSlice import TFSliceToSliceReplacer
|
||||
from openvino.tools.mo.front.tf.pad_tf_to_pad import PadTFToPad
|
||||
from openvino.tools.mo.middle.InsertLayoutPropagationTransposes import mark_as_correct_data_layout, \
|
||||
@@ -594,7 +595,7 @@ class ObjectDetectionAPITransformationsFinish(FrontReplacementPattern):
|
||||
|
||||
def run_before(self):
|
||||
return [Pack, TransposeOrderNormalizer, PadTFToPad, SqueezeAxis, TFSliceToSliceReplacer, MapFNInputSlicing,
|
||||
MapFNOutputConcatenation, CropAndResizeReplacement]
|
||||
MapFNOutputConcatenation, TensorListOutputConcatenation, CropAndResizeReplacement]
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
pass
|
||||
|
||||
@@ -446,7 +446,7 @@ class Loop(TensorIterator):
|
||||
max_port_id = sorted(loop_node.in_ports().keys())[-1]
|
||||
new_port_id = 0
|
||||
for port_id in range(max_port_id + 1):
|
||||
if port_id in loop_node.in_ports():
|
||||
if loop_node.is_in_port_connected(port_id):
|
||||
if port_id != new_port_id:
|
||||
re_number_input_port(loop_node, port_id, new_port_id)
|
||||
new_port_id += 1
|
||||
|
||||
Reference in New Issue
Block a user