[MO] Implement TensorFlow 2 While and Keras RNN support in MO (#3573)
* [MO] Implement TensorFlow 2 While support in MO Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com> * Add extractors for both While and StatelessWhile and do minor changes Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com> * Improve update_body_graph function and manage graph names properly Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com> * Fix a map for original name of parameters from body and cond Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com> * Implement draft version of support of TF2 Keras RNN Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com> * Implement Keras LSTM and GRU support in MO Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com> * Improve code for Keras RNN support Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com> * Finalize implementation of TF2 Keras RNN support in MO Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com> * Apply the first part of the comments after review #1 Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com> * Avoid use of explicit values of port indices in the transformation Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com> * Finalize code after the first-round review Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com> * Apply comments after the second-round review Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
parent
61ccde700f
commit
bacb8420f0
@ -404,6 +404,7 @@ extensions/front/tf/identity_ext.py
|
||||
extensions/front/tf/identityN_to_identity.py
|
||||
extensions/front/tf/InterpolateTransposes.py
|
||||
extensions/front/tf/IteratorGetNext_ext.py
|
||||
extensions/front/tf/KerasRNNTransformation.py
|
||||
extensions/front/tf/log_softmax_ext.py
|
||||
extensions/front/tf/LookupTableInsert_ext.py
|
||||
extensions/front/tf/LoopCond_ext.py
|
||||
@ -483,6 +484,8 @@ extensions/front/tf/UnpackPackReverseInputChannels.py
|
||||
extensions/front/tf/variable_ext.py
|
||||
extensions/front/tf/variables_values_freezing.py
|
||||
extensions/front/tf/WhereDecomposition.py
|
||||
extensions/front/tf/while_ext.py
|
||||
extensions/front/tf/WhileNormalize.py
|
||||
extensions/front/tf/yolo_v1.json
|
||||
extensions/front/tf/yolo_v1_tiny.json
|
||||
extensions/front/tf/yolo_v2.json
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Copyright (C) 2018-2020 Intel Corporation
|
||||
Copyright (C) 2018-2021 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -15,7 +15,6 @@
|
||||
"""
|
||||
import logging as log
|
||||
from collections import defaultdict
|
||||
from copy import copy
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -125,6 +124,11 @@ class RemoveConstToResult(BackReplacementPattern):
|
||||
"""
|
||||
enabled = True
|
||||
force_clean_up = True
|
||||
# TODO: remove this transformation once all plugins support constant value network.
|
||||
# Do not run recursively since Const->Result sub-graph can be encountered in a body graph of Loop node
|
||||
# and this sub-graph is needed to avoid dynamism created by Loop node
|
||||
# in case using axis in output port map
|
||||
run_not_recursively = True
|
||||
|
||||
@staticmethod
|
||||
def pattern():
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Copyright (C) 2018-2020 Intel Corporation
|
||||
Copyright (C) 2018-2021 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -22,6 +22,9 @@ from mo.graph.graph import Graph
|
||||
|
||||
class StandaloneConstEraser(FrontReplacementSubgraph):
|
||||
enabled = True
|
||||
# TODO: remove this transformation once all plugins support constant value network.
|
||||
# Now it avoids to be run recursively since Const->Result sub-graph can be encountered in a body graph of Loop node
|
||||
run_not_recursively = True
|
||||
|
||||
@staticmethod
|
||||
def pattern():
|
||||
|
268
model-optimizer/extensions/front/tf/KerasRNNTransformation.py
Normal file
268
model-optimizer/extensions/front/tf/KerasRNNTransformation.py
Normal file
@ -0,0 +1,268 @@
|
||||
"""
|
||||
Copyright (C) 2017-2021 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.front.tf.WhileNormalize import WhileNormalize
|
||||
from extensions.ops.loop import Loop
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.common.replacement import FrontReplacementSubgraph
|
||||
from mo.front.tf.graph_utils import create_op_with_const_inputs
|
||||
from mo.graph.graph import Graph, Node, rename_nodes
|
||||
from mo.middle.pattern_match import find_pattern_matches, inverse_dict
|
||||
from mo.ops.const import Const
|
||||
from mo.ops.squeeze import Squeeze
|
||||
from mo.ops.unsqueeze import Unsqueeze
|
||||
|
||||
|
||||
def compute_input_port_idx(req_node: Node, loop_node: Node):
|
||||
"""
|
||||
Computes input port index by which requested node is passed to Loop node
|
||||
:param req_node: a node for which to find input port index is requested
|
||||
:param loop_node: a node that can receive input data from requested node by some input port
|
||||
:return: input port index
|
||||
"""
|
||||
for destination in req_node.out_port(0).get_destinations():
|
||||
if loop_node.id == destination.node.id:
|
||||
return destination.idx
|
||||
return None
|
||||
|
||||
|
||||
def find_subgraph_match_to_pattern(graph: Graph, body_pattern: dict):
|
||||
"""
|
||||
Finds sub-graph matches corresponding pattern in graph
|
||||
:param graph: a graph where to search for matched sub-graph
|
||||
:param body_pattern: a pattern
|
||||
:return: a list of sub-graph matches
|
||||
"""
|
||||
matches = []
|
||||
for match in find_pattern_matches(graph, **body_pattern):
|
||||
match = inverse_dict(match)
|
||||
for k in match:
|
||||
match[k] = Node(graph, match[k])
|
||||
matches.append(match)
|
||||
|
||||
return matches
|
||||
|
||||
|
||||
class KerasRNNInputSlicing(FrontReplacementSubgraph):
|
||||
"""
|
||||
The transformation detects TensorFlow 2 pattern that corresponds to subsequent slicing of input.
|
||||
It avoids TensorListFromTensor and TensorFlowGetItem operations and replaces the original sub-graph
|
||||
by adding axis attribute for corresponding input port of Loop node.
|
||||
The transformation is applicable to TensorFlow 2 Keras Simple RNN, GRU, and LSTM layers.
|
||||
"""
|
||||
enabled = True
|
||||
|
||||
def run_before(self):
|
||||
return [WhileNormalize]
|
||||
|
||||
@staticmethod
|
||||
def pattern(**kwargs):
|
||||
return dict(
|
||||
nodes=[('unstack', dict(op='TensorListFromTensor')),
|
||||
('while', dict(op='Loop'))],
|
||||
edges=[('unstack', 'while')]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_body_pattern():
|
||||
return dict(
|
||||
nodes=[('tensor_list', dict(op='Parameter')),
|
||||
('current_iteration', dict(op='Parameter')),
|
||||
('slicing', dict(op='TensorListGetItem')),
|
||||
('const_increment', dict(op='Const')),
|
||||
('increment_iteration', dict(op='Add')),
|
||||
('increment_iteration_identity', dict(op='Identity')),
|
||||
('increment_iteration_result', dict(op='Result'))],
|
||||
edges=[('tensor_list', 'slicing', {'in': 0}),
|
||||
('current_iteration', 'slicing', {'in': 1}),
|
||||
('const_increment', 'increment_iteration', {'in': 1}),
|
||||
('current_iteration', 'increment_iteration', {'in': 0}),
|
||||
('increment_iteration', 'increment_iteration_identity', {'in': 0}),
|
||||
('increment_iteration_identity', 'increment_iteration_result', {'in': 0})]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def transform_keras_rnn_input_slicing(external_match: dict, internal_match: dict):
|
||||
"""
|
||||
Transforms TensorFlow 2 input slicing into use of axis attribute for input port of Loop node
|
||||
:param external_match: a match used for handling a part of the main graph responsible for input slicing
|
||||
:param internal_match: a match used for handling a part of the body graph responsible for input slicing
|
||||
"""
|
||||
loop_node = external_match['while']
|
||||
unstack_node = external_match['unstack']
|
||||
body_graph = loop_node['body']
|
||||
|
||||
tensor_list_get_item_node = internal_match['slicing']
|
||||
unstack_placeholder = internal_match['tensor_list']
|
||||
tensor_list_get_item_node_name = tensor_list_get_item_node.soft_get('name', tensor_list_get_item_node.id)
|
||||
|
||||
# 1. process the body graph to avoid unsupported operations: TensorListGetItem and TensorListSetItem
|
||||
# replace TensorListGetItem with Squeeze node and iterate through slices using axis for input port
|
||||
squeeze_list_element = create_op_with_const_inputs(body_graph, Squeeze, {1: int64_array(0)},
|
||||
{'name': 'TensorListGetItemSqueeze'})
|
||||
tensor_list_get_item_node.in_port(0).get_connection().set_destination(squeeze_list_element.in_port(0))
|
||||
tensor_list_get_item_node.out_port(0).get_connection().set_source(squeeze_list_element.out_port(0))
|
||||
rename_nodes([(tensor_list_get_item_node, tensor_list_get_item_node_name + '/AbandonedName'),
|
||||
(squeeze_list_element, tensor_list_get_item_node_name)])
|
||||
unstack_placeholder_layer_id = unstack_placeholder.internal_layer_id
|
||||
Loop.update_port_map_value_ext(loop_node.input_port_map, 'internal_layer_id', unstack_placeholder_layer_id,
|
||||
'axis', 0)
|
||||
|
||||
# 2. process locality of Loop node in the main graph to avoid unsupported operations:
|
||||
# TensorListFromTensor, TensorListReserve, and TensorListStack
|
||||
# remove TensorListFromTensor and pass a tensor to Loop as is
|
||||
unstack_node.out_port(0).get_connection().set_source(unstack_node.in_port(0).get_connection().get_source())
|
||||
|
||||
def replace_sub_graph(self, graph: Graph, external_match: dict):
|
||||
loop_node = external_match['while']
|
||||
body_graph = loop_node['body']
|
||||
body_pattern = KerasRNNInputSlicing.get_body_pattern()
|
||||
internal_matches = find_subgraph_match_to_pattern(body_graph, body_pattern)
|
||||
|
||||
# a case of multiple matches is not handled since it is not clear how to select corresponding match
|
||||
if len(internal_matches) == 1:
|
||||
internal_match = internal_matches[0]
|
||||
loop_node = external_match['while']
|
||||
unstack_port_idx = compute_input_port_idx(external_match['unstack'], loop_node)
|
||||
# check that back edges connect correct Parameter and Result nodes in the body
|
||||
# check connections between body input ports and external inputs ports of Loop node
|
||||
if Loop.back_edge_exists(loop_node.back_edges,
|
||||
internal_match['increment_iteration_result'].internal_layer_id,
|
||||
internal_match['current_iteration'].internal_layer_id) and \
|
||||
Loop.inter_edge_exists(loop_node.input_port_map, unstack_port_idx,
|
||||
internal_match['tensor_list'].internal_layer_id):
|
||||
# only if inter-graph match passed it starts to process the sub-graph
|
||||
KerasRNNInputSlicing.transform_keras_rnn_input_slicing(external_match, internal_match)
|
||||
|
||||
|
||||
class KerasRNNOutputConcatenation(FrontReplacementSubgraph):
|
||||
"""
|
||||
The transformation detects TensorFlow 2 pattern that corresponds to concatenation of intermediate results
|
||||
generated in each iteration of While operation.
|
||||
It avoids TensorListReserve, TensorListStack, and TensorListSetItem operations and replaces the original sub-graph
|
||||
by adding axis attribute for corresponding output port of Loop node.
|
||||
The transformation is applicable to TensorFlow 2 Keras Simple RNN, GRU, and LSTM layers.
|
||||
"""
|
||||
enabled = True
|
||||
|
||||
def run_before(self):
|
||||
return [WhileNormalize]
|
||||
|
||||
@staticmethod
|
||||
def pattern(**kwargs):
|
||||
return dict(
|
||||
nodes=[('reserve', dict(op='TensorListReserve')),
|
||||
('while', dict(op='Loop')),
|
||||
('stack', dict(op='TensorListStack'))],
|
||||
edges=[('reserve', 'while'),
|
||||
('while', 'stack')]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_body_pattern():
|
||||
return dict(
|
||||
nodes=[('container', dict(op='Parameter')),
|
||||
('current_iteration', dict(op='Parameter')),
|
||||
('const_increment', dict(op='Const')),
|
||||
('increment_iteration', dict(op='Add')),
|
||||
('increment_iteration_identity', dict(op='Identity')),
|
||||
('increment_iteration_result', dict(op='Result')),
|
||||
('concatenation', dict(op='TensorListSetItem')),
|
||||
('concatenation_identity', dict(op='Identity')),
|
||||
('concatenation_result', dict(op='Result')),
|
||||
],
|
||||
edges=[('const_increment', 'increment_iteration', {'in': 1}),
|
||||
('current_iteration', 'increment_iteration', {'in': 0}),
|
||||
('container', 'concatenation', {'in': 0}),
|
||||
('current_iteration', 'concatenation', {'in': 1}),
|
||||
('concatenation', 'concatenation_identity', {'in': 0}),
|
||||
('concatenation_identity', 'concatenation_result', {'in': 0}),
|
||||
('increment_iteration', 'increment_iteration_identity', {'in': 0}),
|
||||
('increment_iteration_identity', 'increment_iteration_result', {'in': 0})]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def transform_keras_rnn_output_concatenation(external_match: dict, internal_match: dict):
|
||||
"""
|
||||
Transforms TensorFlow 2 output concatenation into use of axis attribute for output port of Loop node
|
||||
:param external_match: a match used for handling a part of the main graph responsible for output concatenation
|
||||
:param internal_match: a match used for handling a part of the body graph responsible for output concatenation
|
||||
"""
|
||||
loop_node = external_match['while']
|
||||
stack_node = external_match['stack']
|
||||
list_reserve_node = external_match['reserve']
|
||||
body_graph = loop_node['body']
|
||||
|
||||
tensor_list_set_item_node = internal_match['concatenation']
|
||||
tensor_list_set_item_node_name = tensor_list_set_item_node.soft_get('name', tensor_list_set_item_node.id)
|
||||
list_result_node = internal_match['concatenation_result']
|
||||
|
||||
# replace TensorListSetItem with Unsqueeze and use axis attribute for corresponding Result node
|
||||
# to concatenate results from different iterations
|
||||
unsqueeze_list_element = create_op_with_const_inputs(body_graph, Unsqueeze, {1: int64_array(0)},
|
||||
{'name': 'TensorListSetItemUnsqueeze'})
|
||||
tensor_list_set_item_node.in_port(2).get_connection().set_destination(unsqueeze_list_element.in_port(0))
|
||||
tensor_list_set_item_node.out_port(0).get_connection().set_source(unsqueeze_list_element.out_port(0))
|
||||
rename_nodes([(tensor_list_set_item_node, tensor_list_set_item_node_name + '/AbandonedName'),
|
||||
(unsqueeze_list_element, tensor_list_set_item_node_name)])
|
||||
list_result_node_layer_id = list_result_node.internal_layer_id
|
||||
Loop.update_port_map_value_ext(loop_node.output_port_map, 'internal_layer_id', list_result_node_layer_id,
|
||||
'axis', 0)
|
||||
|
||||
# remove TensorListStack to by-pass the node since the result from the Loop node is already concatenated
|
||||
stack_node.out_port(0).get_connection().set_source(stack_node.in_port(0).get_connection().get_source())
|
||||
|
||||
# disconnect ListReserve node because it is no longer needed for Loop
|
||||
list_reserve_node.out_port(0).disconnect()
|
||||
|
||||
# connect a number of iterations with trip count that can be received from the second input of ListReserve
|
||||
# create a constant network with True value for execution_condition so that IE can ignore execution condition
|
||||
# and perform trip_counts iterations. This approach with known trip count value allows to avoid dynamism.
|
||||
loop_node.in_port(1).disconnect()
|
||||
list_reserve_node.in_port(1).get_source().connect(loop_node.in_port(1))
|
||||
for record in loop_node.output_port_map:
|
||||
if 'purpose' in record and record['purpose'] == 'execution_condition':
|
||||
exec_cond_layer_id = record['internal_layer_id']
|
||||
exec_cond_node = Loop.get_body_node_by_internal_id(loop_node, exec_cond_layer_id)
|
||||
const_true = Const(body_graph, {'value': np.array(True, dtype=np.bool)}).create_node()
|
||||
exec_cond_node.in_port(0).get_connection().set_source(const_true.out_port(0))
|
||||
|
||||
def replace_sub_graph(self, graph: Graph, external_match: dict):
|
||||
loop_node = external_match['while']
|
||||
body_graph = loop_node['body']
|
||||
body_pattern = KerasRNNOutputConcatenation.get_body_pattern()
|
||||
|
||||
internal_matches = find_subgraph_match_to_pattern(body_graph, body_pattern)
|
||||
|
||||
if len(internal_matches) == 1:
|
||||
internal_match = internal_matches[0]
|
||||
reserve_port_idx = compute_input_port_idx(external_match['reserve'], loop_node)
|
||||
stack_port_idx = external_match['stack'].in_port(0).get_source().idx
|
||||
# check that back edges connect correct Parameter and Result nodes in the body
|
||||
# check connections between body input ports and external inputs ports of Loop node
|
||||
# check connections between body output ports and external output ports of Loop node
|
||||
if Loop.back_edge_exists(loop_node.back_edges, internal_match['concatenation_result'].internal_layer_id,
|
||||
internal_match['container'].internal_layer_id) and \
|
||||
Loop.back_edge_exists(loop_node.back_edges,
|
||||
internal_match['increment_iteration_result'].internal_layer_id,
|
||||
internal_match['current_iteration'].internal_layer_id) and \
|
||||
Loop.inter_edge_exists(loop_node.input_port_map, reserve_port_idx,
|
||||
internal_match['container'].internal_layer_id) and \
|
||||
Loop.inter_edge_exists(loop_node.output_port_map, stack_port_idx,
|
||||
internal_match['concatenation_result'].internal_layer_id):
|
||||
KerasRNNOutputConcatenation.transform_keras_rnn_output_concatenation(external_match, internal_match)
|
53
model-optimizer/extensions/front/tf/WhileNormalize.py
Normal file
53
model-optimizer/extensions/front/tf/WhileNormalize.py
Normal file
@ -0,0 +1,53 @@
|
||||
"""
|
||||
Copyright (C) 2017-2021 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.ops.loop import Loop
|
||||
from mo.front.common.replacement import FrontReplacementSubgraph
|
||||
from mo.graph.graph import Graph, Node
|
||||
from mo.ops.const import Const
|
||||
|
||||
|
||||
class WhileNormalize(FrontReplacementSubgraph):
|
||||
"""
|
||||
Normalize inputs for Loop replacing TensorFlow 2 While operation:
|
||||
1) Remove external input port for current iteration
|
||||
2) Move trip count from port #1 to port #0
|
||||
3) Occupy port #1 for execution condition
|
||||
"""
|
||||
enabled = True
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
for node in graph.get_op_nodes(op='Loop'):
|
||||
self.normalize_loop_node(graph, node)
|
||||
|
||||
@staticmethod
|
||||
def normalize_loop_node(graph: Graph, loop_node: Node):
|
||||
loop_name = loop_node.soft_get('name', loop_node.id)
|
||||
|
||||
# disconnect current iteration from external port #0 and move trip count to this port
|
||||
loop_node.in_port(0).disconnect()
|
||||
loop_node.in_port(1).get_connection().add_destination(loop_node.in_port(0))
|
||||
Loop.update_port_map_value(loop_node.input_port_map, 'external_port_id', 1, 0)
|
||||
|
||||
# connect execution condition port
|
||||
exec_cond_node = Const(graph, {'name': loop_name + '/ExecutionConditionValue',
|
||||
'value': np.array(True, dtype=np.bool)}).create_node()
|
||||
loop_node.in_port(1).get_connection().set_source(exec_cond_node.out_port(0))
|
||||
|
||||
loop_node.body.clean_up()
|
||||
Loop.normalize_input_output_ports(loop_node)
|
207
model-optimizer/extensions/front/tf/while_ext.py
Normal file
207
model-optimizer/extensions/front/tf/while_ext.py
Normal file
@ -0,0 +1,207 @@
|
||||
"""
|
||||
Copyright (C) 2017-2021 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
import copy
|
||||
|
||||
from extensions.front.onnx.loop_ext import connect_body_input, connect_body_output
|
||||
from extensions.ops.loop import Loop
|
||||
from extensions.ops.parameter import Parameter
|
||||
from mo.front.common.register_custom_ops import check_for_duplicates
|
||||
from mo.front.extractor import extract_node_attrs, FrontExtractorOp
|
||||
from mo.front.tf.extractor import tf_op_extractor, tf_op_extractors
|
||||
from mo.front.tf.extractors.utils import tf_dtype_extractor
|
||||
from mo.graph.graph import add_opoutput, Graph, Node
|
||||
from mo.ops.op import PermuteAttrs
|
||||
|
||||
|
||||
def update_body_graph(body_graph: Graph, subgraph_proto: dict,
|
||||
body_parameter_names: list, body_results: list):
|
||||
"""
|
||||
Updates the loop body graph with a sub-graph (for body or condition functions)
|
||||
:param body_graph: a loop body graph to be updated
|
||||
:param subgraph_proto: a sub-graph in a protobuf format to be added into the loop body graph
|
||||
:param body_parameter_names: a (unchanged) list of parameters in the loop body graph
|
||||
:param body_results: a list of Result nodes that is extended with a list from a sub-graph
|
||||
"""
|
||||
# create a map from a node name in original model to a name in a loop body graph assuming
|
||||
# that names in the original model are unique
|
||||
# initially, the map contains names for parameters that are common for the body and condition graphs
|
||||
map_original_name = {}
|
||||
for idx, pb_node in enumerate(subgraph_proto['input_arg']):
|
||||
map_original_name[pb_node.name] = body_parameter_names[idx]
|
||||
|
||||
# walk through all nodes (non-parameter and non-result nodes) and add into the loop body graph
|
||||
for pb_node in subgraph_proto['node_def']:
|
||||
# create an NX node
|
||||
id = body_graph.unique_id(pb_node.name)
|
||||
map_original_name[pb_node.name] = id
|
||||
body_graph.add_node(id, pb=pb_node, kind='op')
|
||||
|
||||
# add incoming edges based on data_nodes_map
|
||||
for dst_port, inp in enumerate(pb_node.input):
|
||||
orig_src_id = inp.split(":")[0]
|
||||
src_id = map_original_name[orig_src_id]
|
||||
src_port = 0 if len(inp.split(":")) == 1 else int(inp.split(":")[-1])
|
||||
assert (body_graph.has_node(src_id))
|
||||
edge_attrs = {
|
||||
'out': src_port,
|
||||
'in': dst_port,
|
||||
'name': src_id,
|
||||
'fw_tensor_debug_info': [(src_id, src_port)],
|
||||
'in_attrs': ['in', 'name'],
|
||||
'out_attrs': ['out', 'name'],
|
||||
'data_attrs': ['fw_tensor_debug_info']
|
||||
}
|
||||
body_graph.add_edge(src_id, id, **edge_attrs)
|
||||
|
||||
# create Result nodes in the loop body graph
|
||||
for output in subgraph_proto['output_arg']:
|
||||
output_name = subgraph_proto['ret'][output.name]
|
||||
orig_src_id = output_name.split(":")[0]
|
||||
src_id = map_original_name[orig_src_id]
|
||||
src_port = 0 if len(output_name.split(":")) == 1\
|
||||
else int(output_name.split(":")[-1])
|
||||
assert body_graph.has_node(src_id), 'The body graph does not contain output with name "{}"'.format(
|
||||
src_id)
|
||||
body_results.append(Node(body_graph, add_opoutput(body_graph, src_id, src_port, False)))
|
||||
|
||||
|
||||
class WhileExtractor(FrontExtractorOp):
|
||||
"""
|
||||
The While operation is a variation of the while_loop primitive from TensorFlow 2 Python API.
|
||||
While can have stateful operations in the body and condition graphs that does not influence on inference so
|
||||
the logic for handling While and StatelessWhile (see below) is the same.
|
||||
"""
|
||||
op = 'While'
|
||||
enabled = True
|
||||
|
||||
@classmethod
|
||||
def extract(cls, loop_node):
|
||||
Loop.update_node_stat(loop_node, {})
|
||||
loop_name = loop_node.soft_get('name', loop_node.id)
|
||||
|
||||
# check that required body and condition functions exist in the graph library
|
||||
main_graph = loop_node.graph
|
||||
body_graph_name = loop_node.pb.attr['body'].func.name
|
||||
cond_graph_name = loop_node.pb.attr['cond'].func.name
|
||||
assert 'library' in main_graph.graph, 'The graph does not contain a library that is required ' \
|
||||
'by node with name "{}".'.format(loop_name)
|
||||
library_graph = main_graph.graph['library']
|
||||
|
||||
assert body_graph_name in library_graph, 'The library does not contain a function with name "{}" ' \
|
||||
'that is required by node ' \
|
||||
'with name "{}".'.format(body_graph_name, loop_name)
|
||||
body_graph_proto = library_graph[body_graph_name]
|
||||
|
||||
assert cond_graph_name in library_graph, 'The library does not contain a function with name "{}" ' \
|
||||
'that is required by node ' \
|
||||
'with name "{}".'.format(cond_graph_name, loop_name)
|
||||
cond_graph_proto = library_graph[cond_graph_name]
|
||||
|
||||
body_graph = Graph()
|
||||
# fill the body graph
|
||||
for attr_key in main_graph.graph.keys():
|
||||
if attr_key != 'library':
|
||||
body_graph.graph[attr_key] = copy.deepcopy(main_graph.graph[attr_key])
|
||||
else:
|
||||
# it is sufficient to have a link to the library
|
||||
body_graph.graph['library'] = main_graph.graph['library']
|
||||
loop_node['body'] = body_graph
|
||||
|
||||
# create Parameter nodes for the body graph
|
||||
body_parameters = []
|
||||
body_parameter_names = []
|
||||
for idx, pb_node in enumerate(body_graph_proto['input_arg']):
|
||||
param_id = body_graph.unique_id(pb_node.name)
|
||||
body_graph.add_node(param_id, name=param_id, kind='op', op='Parameter', pb=None, shape=None)
|
||||
parameter_node = Node(body_graph, pb_node.name)
|
||||
Parameter.update_node_stat(parameter_node,
|
||||
{'data_type': tf_dtype_extractor(pb_node.type),
|
||||
'permute_attrs': PermuteAttrs().update_attrs(attrs=[('shape', 'output:0')])}
|
||||
)
|
||||
body_parameters.append(parameter_node)
|
||||
body_parameter_names.append(param_id)
|
||||
|
||||
# update the loop body graph with the body function graph
|
||||
body_results = []
|
||||
update_body_graph(body_graph, body_graph_proto, body_parameter_names, body_results)
|
||||
|
||||
# update the loop body graph with the condition function graph
|
||||
update_body_graph(body_graph, cond_graph_proto, body_parameter_names, body_results)
|
||||
|
||||
# add 'internal_layer_id' attribute which is a must have attribute for the loop body node
|
||||
for idx, body_node in enumerate(body_graph.get_op_nodes()):
|
||||
body_node['internal_layer_id'] = idx
|
||||
|
||||
body_graph.stage = 'front'
|
||||
|
||||
# Currently,
|
||||
# Loop Inputs Order:
|
||||
# 0 - current iteration
|
||||
# 1 - trip count
|
||||
# 2.. - "loop carried" dependencies variables
|
||||
#
|
||||
# Body Inputs Order:
|
||||
# 0 - current iteration
|
||||
# 1 - trip count
|
||||
# 2.. - "loop carried" dependencies variables
|
||||
#
|
||||
# Body Outputs Order:
|
||||
# 0 - current iteration
|
||||
# 1 - trip count
|
||||
# 2.. - "loop carried" dependencies variables
|
||||
#
|
||||
# Loop Outputs Order:
|
||||
# 0 - current iteration
|
||||
# 1 - trip count
|
||||
# 2.. - "loop carried" dependencies variables
|
||||
#
|
||||
# so inputs must be reordered and execution condition must be created in the front transformation
|
||||
# to be aligned with the specification
|
||||
|
||||
# connect external input ports with body parameter nodes except current iteration
|
||||
# since it must be disconnected from external port
|
||||
for idx in range(1, len(body_parameters)):
|
||||
connect_body_input(loop_node, idx, body_parameters[idx])
|
||||
|
||||
# mark current iteration input Parameter node and execution condition Result node
|
||||
Loop.mark_current_iteration_parameter_node(loop_node, body_parameters[0])
|
||||
Loop.mark_execution_condition_result_node(loop_node, body_results[-1])
|
||||
|
||||
# connect back edges in the body except current iteration
|
||||
for idx in range(1, len(body_parameters)):
|
||||
Loop.add_back_edge(loop_node, body_parameters[idx], body_results[idx])
|
||||
|
||||
# connect body outputs with Loop operation output ports except the execution condition result
|
||||
for idx in range(len(body_results)-1):
|
||||
connect_body_output(loop_node, idx, body_results[idx])
|
||||
|
||||
# run function to parse body nodes attributes similar to the main graph
|
||||
extract_node_attrs(body_graph, lambda node: tf_op_extractor(node, check_for_duplicates(tf_op_extractors)))
|
||||
return cls.enabled
|
||||
|
||||
|
||||
class StatelessWhileExtractor(FrontExtractorOp):
|
||||
"""
|
||||
The StatelessWhile operation is a variation of the while_loop primitive from TensorFlow 2 Python API.
|
||||
StatelessWhile does not have stateful operations in the body and condition graphs.
|
||||
"""
|
||||
op = 'StatelessWhile'
|
||||
enabled = True
|
||||
|
||||
@classmethod
|
||||
def extract(cls, loop_node):
|
||||
WhileExtractor.extract(loop_node)
|
||||
return cls.enabled
|
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Copyright (C) 2020 Intel Corporation
|
||||
Copyright (C) 2020-2021 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -42,6 +42,7 @@ from mo.utils.utils import refer_to_faq_msg
|
||||
|
||||
class TFLoader(Loader):
|
||||
enabled = True
|
||||
run_not_recursively = True
|
||||
|
||||
def load(self, graph: Graph):
|
||||
argv = graph.graph['cmd_params']
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Copyright (C) 2017-2020 Intel Corporation
|
||||
Copyright (C) 2017-2021 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -94,6 +94,9 @@ class Loop(TensorIterator):
|
||||
loop_port_idx = record['external_port_id']
|
||||
if loop_port_idx != -1:
|
||||
input_shape = loop_node.in_port(loop_port_idx).get_connection().get_source().data.get_shape()
|
||||
slice_axis = record['axis']
|
||||
if slice_axis is not None:
|
||||
input_shape[slice_axis] = 1
|
||||
body_node.shape = input_shape
|
||||
log.debug('Updated shape for the body node with internal_id "{}" with value {}'
|
||||
''.format(record['internal_layer_id'], body_node.shape))
|
||||
@ -155,6 +158,8 @@ class Loop(TensorIterator):
|
||||
num_iterations = loop_node.in_port(0).data.get_value()
|
||||
if num_iterations is not None:
|
||||
num_iterations = num_iterations.item(0)
|
||||
if num_iterations < 0:
|
||||
return None
|
||||
return num_iterations
|
||||
|
||||
@staticmethod
|
||||
@ -317,9 +322,57 @@ class Loop(TensorIterator):
|
||||
if record[attr] == original_value:
|
||||
record[attr] = new_value
|
||||
matched += 1
|
||||
assert matched == 1, 'More than one record in the portmap for attr "{}" wil original value "{}"' \
|
||||
assert matched == 1, 'More than one record in the portmap for attr "{}" with original value "{}"' \
|
||||
''.format(attr, original_value)
|
||||
|
||||
@staticmethod
|
||||
def update_port_map_value_ext(port_map: dict, layer_id_attr: str, layer_id_value: int,
|
||||
updated_attr: str, new_attr_value: int):
|
||||
"""
|
||||
Updates a value of requested attribute for a certain layer id in a port map
|
||||
:param port_map: a map of external ports to internal layer ids
|
||||
:param layer_id_attr: layer id attribute for which to update attribute
|
||||
:param layer_id_value: layer id value for which to update attribute
|
||||
:param updated_attr: a name of attribute which to update
|
||||
:param new_attr_value: new value of attribute
|
||||
"""
|
||||
matched = 0
|
||||
for record in port_map:
|
||||
if record.get(layer_id_attr) == layer_id_value:
|
||||
record[updated_attr] = new_attr_value
|
||||
matched += 1
|
||||
assert matched == 1, 'More than one record in the portmap for attr "{}" with original value "{}"' \
|
||||
''.format(layer_id_attr, layer_id_value)
|
||||
|
||||
@staticmethod
|
||||
def back_edge_exists(back_edges_map: dict, from_layer: int, to_layer: int):
|
||||
"""
|
||||
Checks if a back edge exists in the back_edges_map connecting specific nodes
|
||||
:param back_edges_map: a map where to search for specified back edge
|
||||
:param from_layer: id of Result node that belongs a back edge
|
||||
:param to_layer: id of Parameter node that belongs a back edge
|
||||
:return: True or False
|
||||
"""
|
||||
for back_edge in back_edges_map:
|
||||
if back_edge['from_layer'] == from_layer and back_edge['to_layer'] == to_layer:
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def inter_edge_exists(port_map: dict, external_port_id: int, internal_layer_id: int):
|
||||
"""
|
||||
Check if inter-graph edge (i.e. an edge between the main graph and body graph) exists
|
||||
:param port_map: a port map where to search for inter-graph edge
|
||||
:param external_port_id: port index from/to which edge goes
|
||||
:param internal_layer_id: layer id from/to which edge goes
|
||||
:return: True or False
|
||||
"""
|
||||
for i_port in port_map:
|
||||
if i_port['external_port_id'] == external_port_id and \
|
||||
i_port['internal_layer_id'] == internal_layer_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def re_numerate_input_ports(loop_node: Node):
|
||||
"""
|
||||
@ -372,7 +425,8 @@ class Loop(TensorIterator):
|
||||
new_port_id += 1
|
||||
|
||||
for port_idx_to_remove in reversed(range(new_port_id, max_port_id + 1)):
|
||||
loop_node.delete_output_port(port_idx_to_remove)
|
||||
if port_idx_to_remove in loop_node.out_ports().keys():
|
||||
loop_node.delete_output_port(port_idx_to_remove)
|
||||
|
||||
@staticmethod
|
||||
def remove_unused_ops_from_port_map(loop_node: Node, port_map: dict, port_map_attr: str, dir: [None, str] = None):
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Copyright (C) 2018-2020 Intel Corporation
|
||||
Copyright (C) 2018-2021 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -272,6 +272,17 @@ def protobuf_attrs(pb:tf_v1.NodeDef):
|
||||
|
||||
def protobuf2nx(graph, pb: tf_v1.GraphDef):
|
||||
fill_graph_with_nodes(graph, pb.node, get_id=lambda pb: pb.name, get_attrs=protobuf_attrs)
|
||||
|
||||
# Create a library with auxiliary functions used in TensorFlow 2 operations
|
||||
if hasattr(pb, 'library') and hasattr(pb.library, 'function'):
|
||||
graph.graph['library'] = {}
|
||||
for library_function in pb.library.function:
|
||||
function_name = library_function.signature.name
|
||||
graph.graph['library'][function_name] = {}
|
||||
graph.graph['library'][function_name]['input_arg'] = library_function.signature.input_arg
|
||||
graph.graph['library'][function_name]['output_arg'] = library_function.signature.output_arg
|
||||
graph.graph['library'][function_name]['node_def'] = library_function.node_def
|
||||
graph.graph['library'][function_name]['ret'] = library_function.ret
|
||||
# initial order of nodes in the GraphDef. It is used to specify order in
|
||||
# which merged nodes are added to the generated sub-graph GraphDef for the TensorFlow offload feature.
|
||||
graph.graph['initial_nodes_order'] = [node.name for node in pb.node]
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Copyright (C) 2018-2020 Intel Corporation
|
||||
Copyright (C) 2018-2021 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -120,7 +120,7 @@ class Node:
|
||||
# no handling of control flow edges -- TODO
|
||||
control_flow = False
|
||||
if not skip_if_absent and idx not in self.out_ports(control_flow=control_flow):
|
||||
raise Error("Input port with index {} doesn't exist in node {}.".format(idx, self.soft_get('name')))
|
||||
raise Error("Output port with index {} doesn't exist in node {}.".format(idx, self.soft_get('name')))
|
||||
if not self.out_port(idx).disconnected():
|
||||
self.out_port(idx).disconnect()
|
||||
del self._out_ports[idx]
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Copyright (C) 2018-2020 Intel Corporation
|
||||
Copyright (C) 2018-2021 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -21,6 +21,7 @@ from collections import namedtuple
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.extractor import add_attrs_props, update_ie_fields
|
||||
from mo.graph.graph import Node, Graph
|
||||
from mo.utils import class_registration
|
||||
@ -445,7 +446,7 @@ class PermuteAttrs:
|
||||
# Exclude 3D shapes from permutation process: identity permutation
|
||||
perm = list(range(0, dims_number))
|
||||
inv = PermuteAttrs.get_inverse_permutation(perm)
|
||||
return PermuteAttrs.Permutation(perm=np.array(perm), inv=np.array(inv))
|
||||
return PermuteAttrs.Permutation(perm=int64_array(perm), inv=int64_array(inv))
|
||||
|
||||
@staticmethod
|
||||
def get_nchw_to_nhwc_permutation(dims_number: int):
|
||||
@ -456,4 +457,4 @@ class PermuteAttrs:
|
||||
# Exclude 3D shapes from permutation process: identity permutation
|
||||
perm = list(range(0, dims_number))
|
||||
inv = PermuteAttrs.get_inverse_permutation(perm)
|
||||
return PermuteAttrs.Permutation(perm=np.array(perm), inv=np.array(inv))
|
||||
return PermuteAttrs.Permutation(perm=int64_array(perm), inv=int64_array(inv))
|
||||
|
Loading…
Reference in New Issue
Block a user