[MO] Implement TensorFlow 2 While and Keras RNN support in MO (#3573)

* [MO] Implement TensorFlow 2 While support in MO

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Add extractors for both While and StatelessWhile and make minor changes

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Improve update_body_graph function and manage graph names properly

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Fix the map of original parameter names from the body and cond graphs

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Implement a draft version of TF2 Keras RNN support

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Implement Keras LSTM and GRU support in MO

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Improve code for Keras RNN support

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Finalize implementation of TF2 Keras RNN support in MO

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Apply the first part of the comments after review #1

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Avoid use of explicit values of port indices in the transformation

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Finalize code after the first-round review

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Apply comments after the second-round review

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
Roman Kazantsev 2021-01-21 17:39:57 +03:00 committed by GitHub
parent 61ccde700f
commit bacb8420f0
11 changed files with 618 additions and 13 deletions

View File

@@ -404,6 +404,7 @@ extensions/front/tf/identity_ext.py
extensions/front/tf/identityN_to_identity.py
extensions/front/tf/InterpolateTransposes.py
extensions/front/tf/IteratorGetNext_ext.py
extensions/front/tf/KerasRNNTransformation.py
extensions/front/tf/log_softmax_ext.py
extensions/front/tf/LookupTableInsert_ext.py
extensions/front/tf/LoopCond_ext.py
@@ -483,6 +484,8 @@ extensions/front/tf/UnpackPackReverseInputChannels.py
extensions/front/tf/variable_ext.py
extensions/front/tf/variables_values_freezing.py
extensions/front/tf/WhereDecomposition.py
extensions/front/tf/while_ext.py
extensions/front/tf/WhileNormalize.py
extensions/front/tf/yolo_v1.json
extensions/front/tf/yolo_v1_tiny.json
extensions/front/tf/yolo_v2.json

View File

@@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -15,7 +15,6 @@
"""
import logging as log
from collections import defaultdict
from copy import copy
import numpy as np
@@ -125,6 +124,11 @@ class RemoveConstToResult(BackReplacementPattern):
"""
enabled = True
force_clean_up = True
# TODO: remove this transformation once all plugins support constant value networks.
# Do not run recursively, since a Const->Result sub-graph can be encountered in a body graph of a Loop node,
# and this sub-graph is needed to avoid the dynamism created by the Loop node
# when the axis attribute is used in the output port map
run_not_recursively = True
@staticmethod
def pattern():

View File

@@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -22,6 +22,9 @@ from mo.graph.graph import Graph
class StandaloneConstEraser(FrontReplacementSubgraph):
enabled = True
# TODO: remove this transformation once all plugins support constant value networks.
# It is not run recursively, since a Const->Result sub-graph can be encountered in a body graph of a Loop node
run_not_recursively = True
@staticmethod
def pattern():

View File

@@ -0,0 +1,268 @@
"""
Copyright (C) 2017-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
from extensions.front.tf.WhileNormalize import WhileNormalize
from extensions.ops.loop import Loop
from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.replacement import FrontReplacementSubgraph
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph, Node, rename_nodes
from mo.middle.pattern_match import find_pattern_matches, inverse_dict
from mo.ops.const import Const
from mo.ops.squeeze import Squeeze
from mo.ops.unsqueeze import Unsqueeze
def compute_input_port_idx(req_node: Node, loop_node: Node):
"""
Computes the input port index through which the requested node is passed to the Loop node
:param req_node: the node whose input port index is requested
:param loop_node: the node that can receive input data from the requested node through some input port
:return: the input port index, or None if the nodes are not connected
"""
for destination in req_node.out_port(0).get_destinations():
if loop_node.id == destination.node.id:
return destination.idx
return None
def find_subgraph_match_to_pattern(graph: Graph, body_pattern: dict):
"""
Finds sub-graph matches corresponding to the pattern in the graph
:param graph: the graph in which to search for matching sub-graphs
:param body_pattern: a pattern
:return: a list of sub-graph matches
"""
matches = []
for match in find_pattern_matches(graph, **body_pattern):
match = inverse_dict(match)
for k in match:
match[k] = Node(graph, match[k])
matches.append(match)
return matches
class KerasRNNInputSlicing(FrontReplacementSubgraph):
"""
The transformation detects a TensorFlow 2 pattern that corresponds to sequential slicing of the input.
It avoids TensorListFromTensor and TensorListGetItem operations and replaces the original sub-graph
by adding the axis attribute to the corresponding input port of the Loop node.
The transformation is applicable to TensorFlow 2 Keras Simple RNN, GRU, and LSTM layers.
"""
enabled = True
def run_before(self):
return [WhileNormalize]
@staticmethod
def pattern(**kwargs):
return dict(
nodes=[('unstack', dict(op='TensorListFromTensor')),
('while', dict(op='Loop'))],
edges=[('unstack', 'while')]
)
@staticmethod
def get_body_pattern():
return dict(
nodes=[('tensor_list', dict(op='Parameter')),
('current_iteration', dict(op='Parameter')),
('slicing', dict(op='TensorListGetItem')),
('const_increment', dict(op='Const')),
('increment_iteration', dict(op='Add')),
('increment_iteration_identity', dict(op='Identity')),
('increment_iteration_result', dict(op='Result'))],
edges=[('tensor_list', 'slicing', {'in': 0}),
('current_iteration', 'slicing', {'in': 1}),
('const_increment', 'increment_iteration', {'in': 1}),
('current_iteration', 'increment_iteration', {'in': 0}),
('increment_iteration', 'increment_iteration_identity', {'in': 0}),
('increment_iteration_identity', 'increment_iteration_result', {'in': 0})]
)
@staticmethod
def transform_keras_rnn_input_slicing(external_match: dict, internal_match: dict):
"""
Transforms TensorFlow 2 input slicing into use of the axis attribute for an input port of the Loop node
:param external_match: a match used for handling the part of the main graph responsible for input slicing
:param internal_match: a match used for handling the part of the body graph responsible for input slicing
"""
loop_node = external_match['while']
unstack_node = external_match['unstack']
body_graph = loop_node['body']
tensor_list_get_item_node = internal_match['slicing']
unstack_placeholder = internal_match['tensor_list']
tensor_list_get_item_node_name = tensor_list_get_item_node.soft_get('name', tensor_list_get_item_node.id)
# 1. process the body graph to avoid unsupported operations: TensorListGetItem and TensorListSetItem
# replace TensorListGetItem with a Squeeze node and iterate over slices using the axis attribute of the input port
squeeze_list_element = create_op_with_const_inputs(body_graph, Squeeze, {1: int64_array(0)},
{'name': 'TensorListGetItemSqueeze'})
tensor_list_get_item_node.in_port(0).get_connection().set_destination(squeeze_list_element.in_port(0))
tensor_list_get_item_node.out_port(0).get_connection().set_source(squeeze_list_element.out_port(0))
rename_nodes([(tensor_list_get_item_node, tensor_list_get_item_node_name + '/AbandonedName'),
(squeeze_list_element, tensor_list_get_item_node_name)])
unstack_placeholder_layer_id = unstack_placeholder.internal_layer_id
Loop.update_port_map_value_ext(loop_node.input_port_map, 'internal_layer_id', unstack_placeholder_layer_id,
'axis', 0)
# 2. process the vicinity of the Loop node in the main graph to avoid unsupported operations:
# TensorListFromTensor, TensorListReserve, and TensorListStack
# remove TensorListFromTensor and pass the tensor to the Loop node as is
unstack_node.out_port(0).get_connection().set_source(unstack_node.in_port(0).get_connection().get_source())
def replace_sub_graph(self, graph: Graph, external_match: dict):
loop_node = external_match['while']
body_graph = loop_node['body']
body_pattern = KerasRNNInputSlicing.get_body_pattern()
internal_matches = find_subgraph_match_to_pattern(body_graph, body_pattern)
# the case of multiple matches is not handled since it is not clear how to select the corresponding match
if len(internal_matches) == 1:
internal_match = internal_matches[0]
loop_node = external_match['while']
unstack_port_idx = compute_input_port_idx(external_match['unstack'], loop_node)
# check that back edges connect correct Parameter and Result nodes in the body
# check connections between body input ports and external inputs ports of Loop node
if Loop.back_edge_exists(loop_node.back_edges,
internal_match['increment_iteration_result'].internal_layer_id,
internal_match['current_iteration'].internal_layer_id) and \
Loop.inter_edge_exists(loop_node.input_port_map, unstack_port_idx,
internal_match['tensor_list'].internal_layer_id):
# the sub-graph is transformed only if the inter-graph checks have passed
KerasRNNInputSlicing.transform_keras_rnn_input_slicing(external_match, internal_match)
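To illustrate why a Squeeze over axis 0 can stand in for TensorListGetItem: once axis=0 is set in the Loop input port map, each iteration receives a one-element slice of the sequence tensor, and Squeeze recovers the per-step input. A standalone numpy sketch with made-up shapes:

import numpy as np

x = np.random.rand(10, 2, 8)         # hypothetical [time, batch, features] sequence
i = 3                                # some iteration index
slice_i = x[i:i + 1]                 # what the Loop delivers per iteration with axis=0, shape (1, 2, 8)
step_input = np.squeeze(slice_i, 0)  # what the inserted Squeeze produces, shape (2, 8)
assert np.array_equal(step_input, x[i])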
class KerasRNNOutputConcatenation(FrontReplacementSubgraph):
"""
The transformation detects a TensorFlow 2 pattern that corresponds to the concatenation of intermediate results
generated in each iteration of the While operation.
It avoids TensorListReserve, TensorListStack, and TensorListSetItem operations and replaces the original sub-graph
by adding the axis attribute to the corresponding output port of the Loop node.
The transformation is applicable to TensorFlow 2 Keras Simple RNN, GRU, and LSTM layers.
"""
enabled = True
def run_before(self):
return [WhileNormalize]
@staticmethod
def pattern(**kwargs):
return dict(
nodes=[('reserve', dict(op='TensorListReserve')),
('while', dict(op='Loop')),
('stack', dict(op='TensorListStack'))],
edges=[('reserve', 'while'),
('while', 'stack')]
)
@staticmethod
def get_body_pattern():
return dict(
nodes=[('container', dict(op='Parameter')),
('current_iteration', dict(op='Parameter')),
('const_increment', dict(op='Const')),
('increment_iteration', dict(op='Add')),
('increment_iteration_identity', dict(op='Identity')),
('increment_iteration_result', dict(op='Result')),
('concatenation', dict(op='TensorListSetItem')),
('concatenation_identity', dict(op='Identity')),
('concatenation_result', dict(op='Result')),
],
edges=[('const_increment', 'increment_iteration', {'in': 1}),
('current_iteration', 'increment_iteration', {'in': 0}),
('container', 'concatenation', {'in': 0}),
('current_iteration', 'concatenation', {'in': 1}),
('concatenation', 'concatenation_identity', {'in': 0}),
('concatenation_identity', 'concatenation_result', {'in': 0}),
('increment_iteration', 'increment_iteration_identity', {'in': 0}),
('increment_iteration_identity', 'increment_iteration_result', {'in': 0})]
)
@staticmethod
def transform_keras_rnn_output_concatenation(external_match: dict, internal_match: dict):
"""
Transforms TensorFlow 2 output concatenation into use of the axis attribute for an output port of the Loop node
:param external_match: a match used for handling the part of the main graph responsible for output concatenation
:param internal_match: a match used for handling the part of the body graph responsible for output concatenation
"""
loop_node = external_match['while']
stack_node = external_match['stack']
list_reserve_node = external_match['reserve']
body_graph = loop_node['body']
tensor_list_set_item_node = internal_match['concatenation']
tensor_list_set_item_node_name = tensor_list_set_item_node.soft_get('name', tensor_list_set_item_node.id)
list_result_node = internal_match['concatenation_result']
# replace TensorListSetItem with Unsqueeze and use the axis attribute for the corresponding Result node
# to concatenate results from different iterations
unsqueeze_list_element = create_op_with_const_inputs(body_graph, Unsqueeze, {1: int64_array(0)},
{'name': 'TensorListSetItemUnsqueeze'})
tensor_list_set_item_node.in_port(2).get_connection().set_destination(unsqueeze_list_element.in_port(0))
tensor_list_set_item_node.out_port(0).get_connection().set_source(unsqueeze_list_element.out_port(0))
rename_nodes([(tensor_list_set_item_node, tensor_list_set_item_node_name + '/AbandonedName'),
(unsqueeze_list_element, tensor_list_set_item_node_name)])
list_result_node_layer_id = list_result_node.internal_layer_id
Loop.update_port_map_value_ext(loop_node.output_port_map, 'internal_layer_id', list_result_node_layer_id,
'axis', 0)
# remove TensorListStack, bypassing the node, since the result from the Loop node is already concatenated
stack_node.out_port(0).get_connection().set_source(stack_node.in_port(0).get_connection().get_source())
# disconnect ListReserve node because it is no longer needed for Loop
list_reserve_node.out_port(0).disconnect()
# connect the number of iterations (the trip count), which can be obtained from the second input of ListReserve
# create a constant with the True value for execution_condition so that IE can ignore the execution condition
# and perform trip_count iterations. This approach with a known trip count value avoids dynamism.
loop_node.in_port(1).disconnect()
list_reserve_node.in_port(1).get_source().connect(loop_node.in_port(1))
for record in loop_node.output_port_map:
if 'purpose' in record and record['purpose'] == 'execution_condition':
exec_cond_layer_id = record['internal_layer_id']
exec_cond_node = Loop.get_body_node_by_internal_id(loop_node, exec_cond_layer_id)
const_true = Const(body_graph, {'value': np.array(True, dtype=np.bool)}).create_node()
exec_cond_node.in_port(0).get_connection().set_source(const_true.out_port(0))
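The output side mirrors the input case: with axis=0 set on the Loop output port, the per-iteration results produced by the inserted Unsqueeze are concatenated along the new leading axis, which is what TensorListSetItem followed by TensorListStack used to compute. A standalone numpy sketch with made-up shapes:

import numpy as np

steps = [np.random.rand(2, 16) for _ in range(10)]  # hypothetical per-iteration outputs
stacked = np.concatenate([np.expand_dims(s, 0) for s in steps], axis=0)
assert stacked.shape == (10, 2, 16)                 # what TensorListStack would have produced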
def replace_sub_graph(self, graph: Graph, external_match: dict):
loop_node = external_match['while']
body_graph = loop_node['body']
body_pattern = KerasRNNOutputConcatenation.get_body_pattern()
internal_matches = find_subgraph_match_to_pattern(body_graph, body_pattern)
if len(internal_matches) == 1:
internal_match = internal_matches[0]
reserve_port_idx = compute_input_port_idx(external_match['reserve'], loop_node)
stack_port_idx = external_match['stack'].in_port(0).get_source().idx
# check that back edges connect correct Parameter and Result nodes in the body
# check connections between body input ports and external inputs ports of Loop node
# check connections between body output ports and external output ports of Loop node
if Loop.back_edge_exists(loop_node.back_edges, internal_match['concatenation_result'].internal_layer_id,
internal_match['container'].internal_layer_id) and \
Loop.back_edge_exists(loop_node.back_edges,
internal_match['increment_iteration_result'].internal_layer_id,
internal_match['current_iteration'].internal_layer_id) and \
Loop.inter_edge_exists(loop_node.input_port_map, reserve_port_idx,
internal_match['container'].internal_layer_id) and \
Loop.inter_edge_exists(loop_node.output_port_map, stack_port_idx,
internal_match['concatenation_result'].internal_layer_id):
KerasRNNOutputConcatenation.transform_keras_rnn_output_concatenation(external_match, internal_match)
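For reference, a minimal Keras model whose frozen graph contains the TensorListFromTensor -> Loop -> TensorListStack pattern that both transformations above target; a sketch assuming TensorFlow 2.x, with made-up layer sizes and paths:

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.LSTM(16, return_sequences=True, input_shape=(10, 8)),  # return_sequences exercises output concatenation
])
model.save('lstm_saved_model')  # hypothetical path, SavedModel format
# then convert with, e.g.: mo.py --saved_model_dir lstm_saved_model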

View File

@@ -0,0 +1,53 @@
"""
Copyright (C) 2017-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
from extensions.ops.loop import Loop
from mo.front.common.replacement import FrontReplacementSubgraph
from mo.graph.graph import Graph, Node
from mo.ops.const import Const
class WhileNormalize(FrontReplacementSubgraph):
"""
Normalizes inputs of the Loop node that replaces the TensorFlow 2 While operation:
1) Remove external input port for current iteration
2) Move trip count from port #1 to port #0
3) Occupy port #1 for execution condition
"""
enabled = True
def find_and_replace_pattern(self, graph: Graph):
for node in graph.get_op_nodes(op='Loop'):
self.normalize_loop_node(graph, node)
@staticmethod
def normalize_loop_node(graph: Graph, loop_node: Node):
loop_name = loop_node.soft_get('name', loop_node.id)
# disconnect current iteration from external port #0 and move trip count to this port
loop_node.in_port(0).disconnect()
loop_node.in_port(1).get_connection().add_destination(loop_node.in_port(0))
Loop.update_port_map_value(loop_node.input_port_map, 'external_port_id', 1, 0)
# connect execution condition port
exec_cond_node = Const(graph, {'name': loop_name + '/ExecutionConditionValue',
'value': np.array(True, dtype=np.bool)}).create_node()
loop_node.in_port(1).get_connection().set_source(exec_cond_node.out_port(0))
loop_node.body.clean_up()
Loop.normalize_input_output_ports(loop_node)
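The external port shuffle above boils down to rewriting external_port_id values in the input port map, which is a plain list of dicts. A standalone sketch with hypothetical records (the 'purpose' key and the -1 value are illustrative):

input_port_map = [
    {'external_port_id': -1, 'internal_layer_id': 0, 'purpose': 'current_iteration'},  # no external input
    {'external_port_id': 1, 'internal_layer_id': 1},  # trip count, to be moved to external port 0
    {'external_port_id': 2, 'internal_layer_id': 2},  # a loop-carried variable
]
# mirrors Loop.update_port_map_value(input_port_map, 'external_port_id', 1, 0)
for record in input_port_map:
    if record['external_port_id'] == 1:
        record['external_port_id'] = 0
# external port 1 is now free to be fed by the constant True execution condition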

View File

@@ -0,0 +1,207 @@
"""
Copyright (C) 2017-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import copy
from extensions.front.onnx.loop_ext import connect_body_input, connect_body_output
from extensions.ops.loop import Loop
from extensions.ops.parameter import Parameter
from mo.front.common.register_custom_ops import check_for_duplicates
from mo.front.extractor import extract_node_attrs, FrontExtractorOp
from mo.front.tf.extractor import tf_op_extractor, tf_op_extractors
from mo.front.tf.extractors.utils import tf_dtype_extractor
from mo.graph.graph import add_opoutput, Graph, Node
from mo.ops.op import PermuteAttrs
def update_body_graph(body_graph: Graph, subgraph_proto: dict,
body_parameter_names: list, body_results: list):
"""
Updates the loop body graph with a sub-graph (from the body or condition function)
:param body_graph: the loop body graph to be updated
:param subgraph_proto: a sub-graph in protobuf format to be added to the loop body graph
:param body_parameter_names: a (unchanged) list of parameter names in the loop body graph
:param body_results: a list of Result nodes that is extended with the Result nodes created for the sub-graph
"""
# create a map from a node name in the original model to a name in the loop body graph, assuming
# that names in the original model are unique
# initially, the map contains names of the parameters that are common to the body and condition graphs
map_original_name = {}
for idx, pb_node in enumerate(subgraph_proto['input_arg']):
map_original_name[pb_node.name] = body_parameter_names[idx]
# walk through all nodes (non-parameter and non-result ones) and add them to the loop body graph
for pb_node in subgraph_proto['node_def']:
# create an NX node
id = body_graph.unique_id(pb_node.name)
map_original_name[pb_node.name] = id
body_graph.add_node(id, pb=pb_node, kind='op')
# add incoming edges based on data_nodes_map
for dst_port, inp in enumerate(pb_node.input):
orig_src_id = inp.split(":")[0]
src_id = map_original_name[orig_src_id]
src_port = 0 if len(inp.split(":")) == 1 else int(inp.split(":")[-1])
assert (body_graph.has_node(src_id))
edge_attrs = {
'out': src_port,
'in': dst_port,
'name': src_id,
'fw_tensor_debug_info': [(src_id, src_port)],
'in_attrs': ['in', 'name'],
'out_attrs': ['out', 'name'],
'data_attrs': ['fw_tensor_debug_info']
}
body_graph.add_edge(src_id, id, **edge_attrs)
# create Result nodes in the loop body graph
for output in subgraph_proto['output_arg']:
output_name = subgraph_proto['ret'][output.name]
orig_src_id = output_name.split(":")[0]
src_id = map_original_name[orig_src_id]
src_port = 0 if len(output_name.split(":")) == 1\
else int(output_name.split(":")[-1])
assert body_graph.has_node(src_id), 'The body graph does not contain output with name "{}"'.format(
src_id)
body_results.append(Node(body_graph, add_opoutput(body_graph, src_id, src_port, False)))
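The input strings parsed above use TensorFlow's "node:port" notation, where output port 0 may be left implicit. A quick standalone sketch with a made-up node name:

inp = 'while/lstm_cell/BiasAdd:1'  # hypothetical producer with an explicit output port
orig_src_id = inp.split(':')[0]
src_port = 0 if len(inp.split(':')) == 1 else int(inp.split(':')[-1])
assert (orig_src_id, src_port) == ('while/lstm_cell/BiasAdd', 1)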
class WhileExtractor(FrontExtractorOp):
"""
The While operation is a variation of the while_loop primitive from the TensorFlow 2 Python API.
While can have stateful operations in its body and condition graphs, but they do not influence inference, so
the logic for handling While and StatelessWhile (see below) is the same.
"""
op = 'While'
enabled = True
@classmethod
def extract(cls, loop_node):
Loop.update_node_stat(loop_node, {})
loop_name = loop_node.soft_get('name', loop_node.id)
# check that required body and condition functions exist in the graph library
main_graph = loop_node.graph
body_graph_name = loop_node.pb.attr['body'].func.name
cond_graph_name = loop_node.pb.attr['cond'].func.name
assert 'library' in main_graph.graph, 'The graph does not contain a library that is required ' \
'by node with name "{}".'.format(loop_name)
library_graph = main_graph.graph['library']
assert body_graph_name in library_graph, 'The library does not contain a function with name "{}" ' \
'that is required by node ' \
'with name "{}".'.format(body_graph_name, loop_name)
body_graph_proto = library_graph[body_graph_name]
assert cond_graph_name in library_graph, 'The library does not contain a function with name "{}" ' \
'that is required by node ' \
'with name "{}".'.format(cond_graph_name, loop_name)
cond_graph_proto = library_graph[cond_graph_name]
body_graph = Graph()
# fill the body graph
for attr_key in main_graph.graph.keys():
if attr_key != 'library':
body_graph.graph[attr_key] = copy.deepcopy(main_graph.graph[attr_key])
else:
# it is sufficient to have a link to the library
body_graph.graph['library'] = main_graph.graph['library']
loop_node['body'] = body_graph
# create Parameter nodes for the body graph
body_parameters = []
body_parameter_names = []
for idx, pb_node in enumerate(body_graph_proto['input_arg']):
param_id = body_graph.unique_id(pb_node.name)
body_graph.add_node(param_id, name=param_id, kind='op', op='Parameter', pb=None, shape=None)
parameter_node = Node(body_graph, pb_node.name)
Parameter.update_node_stat(parameter_node,
{'data_type': tf_dtype_extractor(pb_node.type),
'permute_attrs': PermuteAttrs().update_attrs(attrs=[('shape', 'output:0')])}
)
body_parameters.append(parameter_node)
body_parameter_names.append(param_id)
# update the loop body graph with the body function graph
body_results = []
update_body_graph(body_graph, body_graph_proto, body_parameter_names, body_results)
# update the loop body graph with the condition function graph
update_body_graph(body_graph, cond_graph_proto, body_parameter_names, body_results)
# add the 'internal_layer_id' attribute, which is mandatory for nodes of the loop body graph
for idx, body_node in enumerate(body_graph.get_op_nodes()):
body_node['internal_layer_id'] = idx
body_graph.stage = 'front'
# Currently,
# Loop Inputs Order:
# 0 - current iteration
# 1 - trip count
# 2.. - "loop carried" dependencies variables
#
# Body Inputs Order:
# 0 - current iteration
# 1 - trip count
# 2.. - "loop carried" dependencies variables
#
# Body Outputs Order:
# 0 - current iteration
# 1 - trip count
# 2.. - "loop carried" dependencies variables
#
# Loop Outputs Order:
# 0 - current iteration
# 1 - trip count
# 2.. - "loop carried" dependencies variables
#
# so inputs must be reordered and execution condition must be created in the front transformation
# to be aligned with the specification
# connect external input ports with body parameter nodes, except the current iteration one,
# since it must be disconnected from the external port
for idx in range(1, len(body_parameters)):
connect_body_input(loop_node, idx, body_parameters[idx])
# mark current iteration input Parameter node and execution condition Result node
Loop.mark_current_iteration_parameter_node(loop_node, body_parameters[0])
Loop.mark_execution_condition_result_node(loop_node, body_results[-1])
# connect back edges in the body except current iteration
for idx in range(1, len(body_parameters)):
Loop.add_back_edge(loop_node, body_parameters[idx], body_results[idx])
# connect body outputs with Loop operation output ports except the execution condition result
for idx in range(len(body_results)-1):
connect_body_output(loop_node, idx, body_results[idx])
# run the function that parses body node attributes, similar to what is done for the main graph
extract_node_attrs(body_graph, lambda node: tf_op_extractor(node, check_for_duplicates(tf_op_extractors)))
return cls.enabled
class StatelessWhileExtractor(FrontExtractorOp):
"""
The StatelessWhile operation is a variation of the while_loop primitive from the TensorFlow 2 Python API.
StatelessWhile does not have stateful operations in the body and condition graphs.
"""
op = 'StatelessWhile'
enabled = True
@classmethod
def extract(cls, loop_node):
WhileExtractor.extract(loop_node)
return cls.enabled
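For context, a minimal TF2 snippet (assuming tensorflow 2.x is installed) that produces a StatelessWhile node whose body and cond live as functions in the graph library, which is exactly what these extractors consume:

import tensorflow as tf

@tf.function
def count_up(x):
    i = tf.constant(0)
    while i < 10:  # AutoGraph lowers this loop to a functional While/StatelessWhile op
        x += 1
        i += 1
    return x

graph_def = count_up.get_concrete_function(tf.TensorSpec([], tf.int32)).graph.as_graph_def()
print([n.op for n in graph_def.node if 'While' in n.op])       # e.g. ['StatelessWhile']
print([f.signature.name for f in graph_def.library.function])  # the body and cond functions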

View File

@@ -1,5 +1,5 @@
"""
Copyright (C) 2020 Intel Corporation
Copyright (C) 2020-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -42,6 +42,7 @@ from mo.utils.utils import refer_to_faq_msg
class TFLoader(Loader):
enabled = True
run_not_recursively = True
def load(self, graph: Graph):
argv = graph.graph['cmd_params']

View File

@@ -1,5 +1,5 @@
"""
Copyright (C) 2017-2020 Intel Corporation
Copyright (C) 2017-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -94,6 +94,9 @@ class Loop(TensorIterator):
loop_port_idx = record['external_port_id']
if loop_port_idx != -1:
input_shape = loop_node.in_port(loop_port_idx).get_connection().get_source().data.get_shape()
slice_axis = record['axis']
if slice_axis is not None:
input_shape[slice_axis] = 1
body_node.shape = input_shape
log.debug('Updated shape for the body node with internal_id "{}" with value {}'
''.format(record['internal_layer_id'], body_node.shape))
@@ -155,6 +158,8 @@
num_iterations = loop_node.in_port(0).data.get_value()
if num_iterations is not None:
num_iterations = num_iterations.item(0)
if num_iterations < 0:
return None
return num_iterations
@staticmethod
@@ -317,9 +322,57 @@
if record[attr] == original_value:
record[attr] = new_value
matched += 1
assert matched == 1, 'More than one record in the portmap for attr "{}" wil original value "{}"' \
assert matched == 1, 'More than one record in the portmap for attr "{}" with original value "{}"' \
''.format(attr, original_value)
@staticmethod
def update_port_map_value_ext(port_map: dict, layer_id_attr: str, layer_id_value: int,
updated_attr: str, new_attr_value: int):
"""
Updates the value of the requested attribute for a certain layer id in the port map
:param port_map: a map of external ports to internal layer ids
:param layer_id_attr: the layer id attribute by which the record is looked up
:param layer_id_value: the layer id value for which to update the attribute
:param updated_attr: the name of the attribute to update
:param new_attr_value: the new value of the attribute
"""
matched = 0
for record in port_map:
if record.get(layer_id_attr) == layer_id_value:
record[updated_attr] = new_attr_value
matched += 1
assert matched == 1, 'More than one record in the portmap for attr "{}" with original value "{}"' \
''.format(layer_id_attr, layer_id_value)
@staticmethod
def back_edge_exists(back_edges_map: dict, from_layer: int, to_layer: int):
"""
Checks if a back edge connecting the specified nodes exists in the back_edges_map
:param back_edges_map: a map in which to search for the specified back edge
:param from_layer: id of the Result node that the back edge starts from
:param to_layer: id of the Parameter node that the back edge ends at
:return: True or False
"""
for back_edge in back_edges_map:
if back_edge['from_layer'] == from_layer and back_edge['to_layer'] == to_layer:
return True
return False
@staticmethod
def inter_edge_exists(port_map: dict, external_port_id: int, internal_layer_id: int):
"""
Checks if an inter-graph edge (i.e. an edge between the main graph and the body graph) exists
:param port_map: a port map in which to search for the inter-graph edge
:param external_port_id: the port index from/to which the edge goes
:param internal_layer_id: the layer id from/to which the edge goes
:return: True or False
"""
for i_port in port_map:
if i_port['external_port_id'] == external_port_id and \
i_port['internal_layer_id'] == internal_layer_id:
return True
return False
@staticmethod
def re_numerate_input_ports(loop_node: Node):
"""
@@ -372,7 +425,8 @@
new_port_id += 1
for port_idx_to_remove in reversed(range(new_port_id, max_port_id + 1)):
loop_node.delete_output_port(port_idx_to_remove)
if port_idx_to_remove in loop_node.out_ports().keys():
loop_node.delete_output_port(port_idx_to_remove)
@staticmethod
def remove_unused_ops_from_port_map(loop_node: Node, port_map: dict, port_map_attr: str, dir: [None, str] = None):
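The port maps and back-edge maps that the helpers added above traverse are plain lists of dicts. A standalone sketch with hypothetical layer ids, assuming the MO sources are importable:

from extensions.ops.loop import Loop

input_port_map = [{'external_port_id': 0, 'internal_layer_id': 3, 'axis': None},  # trip count
                  {'external_port_id': 2, 'internal_layer_id': 5, 'axis': 0}]     # sliced input
back_edges = [{'from_layer': 11, 'to_layer': 5}]  # Result 11 feeds Parameter 5 on the next iteration

assert Loop.inter_edge_exists(input_port_map, 2, 5)
assert Loop.back_edge_exists(back_edges, 11, 5)
Loop.update_port_map_value_ext(input_port_map, 'internal_layer_id', 5, 'axis', None)  # drops slicing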

View File

@@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -272,6 +272,17 @@ def protobuf_attrs(pb:tf_v1.NodeDef):
def protobuf2nx(graph, pb: tf_v1.GraphDef):
fill_graph_with_nodes(graph, pb.node, get_id=lambda pb: pb.name, get_attrs=protobuf_attrs)
# Create a library with auxiliary functions used in TensorFlow 2 operations
if hasattr(pb, 'library') and hasattr(pb.library, 'function'):
graph.graph['library'] = {}
for library_function in pb.library.function:
function_name = library_function.signature.name
graph.graph['library'][function_name] = {}
graph.graph['library'][function_name]['input_arg'] = library_function.signature.input_arg
graph.graph['library'][function_name]['output_arg'] = library_function.signature.output_arg
graph.graph['library'][function_name]['node_def'] = library_function.node_def
graph.graph['library'][function_name]['ret'] = library_function.ret
# initial order of nodes in the GraphDef. It is used to specify order in
# which merged nodes are added to the generated sub-graph GraphDef for the TensorFlow offload feature.
graph.graph['initial_nodes_order'] = [node.name for node in pb.node]
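The fields copied into graph.graph['library'] come straight from the FunctionDef entries of the GraphDef. A standalone sketch (assuming tensorflow 2.x) that generates a While body/cond pair and prints those fields:

import tensorflow as tf

@tf.function
def loop(x):
    for _ in tf.range(5):  # lowered to a functional While with body/cond FunctionDefs
        x += 1.0
    return x

gd = loop.get_concrete_function(tf.TensorSpec([], tf.float32)).graph.as_graph_def()
for fn in gd.library.function:
    print(fn.signature.name,
          [a.name for a in fn.signature.input_arg],
          [a.name for a in fn.signature.output_arg],
          len(fn.node_def), dict(fn.ret))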

View File

@@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -120,7 +120,7 @@ class Node:
# no handling of control flow edges -- TODO
control_flow = False
if not skip_if_absent and idx not in self.out_ports(control_flow=control_flow):
raise Error("Input port with index {} doesn't exist in node {}.".format(idx, self.soft_get('name')))
raise Error("Output port with index {} doesn't exist in node {}.".format(idx, self.soft_get('name')))
if not self.out_port(idx).disconnected():
self.out_port(idx).disconnect()
del self._out_ports[idx]

View File

@@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -21,6 +21,7 @@ from collections import namedtuple
import networkx as nx
import numpy as np
from mo.front.common.partial_infer.utils import int64_array
from mo.front.extractor import add_attrs_props, update_ie_fields
from mo.graph.graph import Node, Graph
from mo.utils import class_registration
@@ -445,7 +446,7 @@ class PermuteAttrs:
# Exclude 3D shapes from permutation process: identity permutation
perm = list(range(0, dims_number))
inv = PermuteAttrs.get_inverse_permutation(perm)
return PermuteAttrs.Permutation(perm=np.array(perm), inv=np.array(inv))
return PermuteAttrs.Permutation(perm=int64_array(perm), inv=int64_array(inv))
@staticmethod
def get_nchw_to_nhwc_permutation(dims_number: int):
@@ -456,4 +457,4 @@
# Exclude 3D shapes from permutation process: identity permutation
perm = list(range(0, dims_number))
inv = PermuteAttrs.get_inverse_permutation(perm)
return PermuteAttrs.Permutation(perm=np.array(perm), inv=np.array(inv))
return PermuteAttrs.Permutation(perm=int64_array(perm), inv=int64_array(inv))