Files
openvino/model-optimizer/extensions/front/tf/while_ext.py
2021-03-05 16:41:36 +03:00

213 lines
9.6 KiB
Python

"""
Copyright (C) 2017-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import copy
from extensions.ops.loop import Loop
from extensions.ops.parameter import Parameter
from mo.front.common.register_custom_ops import check_for_duplicates
from mo.front.extractor import extract_node_attrs, FrontExtractorOp
from mo.front.tf.extractor import tf_op_extractor, tf_op_extractors
from mo.front.tf.extractors.utils import tf_dtype_extractor
from mo.graph.graph import add_opoutput, Graph, Node
from mo.ops.op import PermuteAttrs
def _split_tensor_name(tensor_name: str):
    """
    Splits a tensor reference into the producing node name and its output port.

    "node" -> ("node", 0), "node:1" -> ("node", 1). When the reference contains more than one ':',
    the first part is used as the node name and the last part as the port index.

    :param tensor_name: a tensor reference as it appears in a node input list or in the 'ret' map
    :return: a tuple (node name, output port index)
    """
    parts = tensor_name.split(':')
    port = 0 if len(parts) == 1 else int(parts[-1])
    return parts[0], port


def update_body_graph(body_graph: Graph, subgraph_proto: dict,
                      body_parameter_names: list, body_results: list):
    """
    Updates the loop body graph with a sub-graph (for body or condition functions)
    :param body_graph: a loop body graph to be updated
    :param subgraph_proto: a sub-graph in a protobuf format to be added into the loop body graph;
        this function reads its 'input_arg', 'node_def', 'output_arg' and 'ret' entries
    :param body_parameter_names: a (unchanged) list of parameters in the loop body graph; the i-th input
        argument of the sub-graph is mapped to the i-th name in this list
    :param body_results: a list of Result nodes that is extended with a list from a sub-graph
    """
    # create a map from a node name in original model to a name in a loop body graph assuming
    # that names in the original model are unique
    # initially, the map contains names for parameters that are common for the body and condition graphs
    map_original_name = {}
    for idx, pb_node in enumerate(subgraph_proto['input_arg']):
        map_original_name[pb_node.name] = body_parameter_names[idx]

    # walk through all nodes (non-parameter and non-result nodes) and add into the loop body graph;
    # a node input must refer to an input argument or to a node listed earlier in 'node_def',
    # otherwise the map lookup below raises KeyError
    for pb_node in subgraph_proto['node_def']:
        # create an NX node
        node_id = body_graph.unique_id(pb_node.name)
        map_original_name[pb_node.name] = node_id
        body_graph.add_node(node_id, pb=pb_node, kind='op')

        # add incoming edges based on data_nodes_map
        for dst_port, inp in enumerate(pb_node.input):
            # TODO: avoid this temporal workaround for TF 2.4 or higher RNN layers:
            # skip control flow dependency
            if inp.startswith('^'):
                continue
            orig_src_id, src_port = _split_tensor_name(inp)
            src_id = map_original_name[orig_src_id]
            assert body_graph.has_node(src_id)
            edge_attrs = {
                'out': src_port,
                'in': dst_port,
                'name': src_id,
                'fw_tensor_debug_info': [(src_id, src_port)],
                'in_attrs': ['in', 'name'],
                'out_attrs': ['out', 'name'],
                'data_attrs': ['fw_tensor_debug_info']
            }
            body_graph.add_edge(src_id, node_id, **edge_attrs)

    # create Result nodes in the loop body graph, one per output argument, in output argument order
    for output in subgraph_proto['output_arg']:
        orig_src_id, src_port = _split_tensor_name(subgraph_proto['ret'][output.name])
        src_id = map_original_name[orig_src_id]
        assert body_graph.has_node(src_id), 'The body graph does not contain output with name "{}"'.format(
            src_id)
        body_results.append(Node(body_graph, add_opoutput(body_graph, src_id, src_port, False)))
class WhileExtractor(FrontExtractorOp):
    """
    The While operation is a variation of the while_loop primitive from TensorFlow 2 Python API.
    While can have stateful operations in the body and condition graphs that do not influence inference, so
    the logic for handling While and StatelessWhile (see below) is the same.
    """
    op = 'While'
    enabled = True

    @staticmethod
    def _get_library_function(library_graph, function_name: str, loop_name: str):
        """
        Returns the description of a function stored in the graph library.

        :param library_graph: the 'library' attribute of the main graph, indexed by function name
        :param function_name: name of the function to look up
        :param loop_name: name of the While node requiring the function (used in the error message)
        :return: the library entry for the function
        """
        assert function_name in library_graph, 'The library does not contain a function with name "{}" ' \
                                               'that is required by node ' \
                                               'with name "{}".'.format(function_name, loop_name)
        return library_graph[function_name]

    @staticmethod
    def _create_body_graph(main_graph: Graph) -> Graph:
        """
        Creates an empty loop body graph carrying the graph-level attributes of the main graph.

        Every attribute is deep-copied except 'library', which is shared by reference.

        :param main_graph: the graph containing the While node
        :return: the new body graph
        """
        body_graph = Graph()
        for attr_key, attr_value in main_graph.graph.items():
            if attr_key == 'library':
                # it is sufficient to have a link to the library
                body_graph.graph[attr_key] = attr_value
            else:
                body_graph.graph[attr_key] = copy.deepcopy(attr_value)
        return body_graph

    @staticmethod
    def _create_body_parameters(body_graph: Graph, input_args):
        """
        Creates one Parameter node in the body graph per input argument of the body function.

        :param body_graph: the loop body graph to add Parameter nodes to
        :param input_args: the 'input_arg' list of the body function
        :return: a tuple (list of Parameter nodes, list of their node ids), both in input argument order
        """
        body_parameters = []
        body_parameter_names = []
        for pb_node in input_args:
            param_id = body_graph.unique_id(pb_node.name)
            body_graph.add_node(param_id, name=param_id, kind='op', op='Parameter', pb=None, shape=None)
            # address the node by the id actually assigned by unique_id, not by the original name
            parameter_node = Node(body_graph, param_id)
            Parameter.update_node_stat(parameter_node,
                                       {'data_type': tf_dtype_extractor(pb_node.type),
                                        'permute_attrs': PermuteAttrs().update_attrs(attrs=[('shape', 'output:0')])}
                                       )
            body_parameters.append(parameter_node)
            body_parameter_names.append(param_id)
        return body_parameters, body_parameter_names

    @classmethod
    def extract(cls, loop_node):
        """
        Converts a TF While node into a Loop node with a body graph built from the body and condition functions.

        :param loop_node: the While node of the main graph
        :return: the class 'enabled' flag
        """
        Loop.update_node_stat(loop_node, {})
        loop_name = loop_node.soft_get('name', loop_node.id)

        # check that required body and condition functions exist in the graph library
        main_graph = loop_node.graph
        body_graph_name = loop_node.pb.attr['body'].func.name
        cond_graph_name = loop_node.pb.attr['cond'].func.name
        assert 'library' in main_graph.graph, 'The graph does not contain a library that is required ' \
                                              'by node with name "{}".'.format(loop_name)
        library_graph = main_graph.graph['library']
        body_graph_proto = cls._get_library_function(library_graph, body_graph_name, loop_name)
        cond_graph_proto = cls._get_library_function(library_graph, cond_graph_name, loop_name)

        # fill the body graph
        body_graph = cls._create_body_graph(main_graph)
        loop_node['body'] = body_graph

        # create Parameter nodes for the body graph
        body_parameters, body_parameter_names = cls._create_body_parameters(body_graph,
                                                                            body_graph_proto['input_arg'])

        # update the loop body graph with the body function graph
        body_results = []
        update_body_graph(body_graph, body_graph_proto, body_parameter_names, body_results)

        # update the loop body graph with the condition function graph; its Result nodes are appended
        # after the body function results (assumes a single condition output - confirm)
        update_body_graph(body_graph, cond_graph_proto, body_parameter_names, body_results)

        # add 'internal_layer_id' attribute which is a must have attribute for the loop body node
        for idx, body_node in enumerate(body_graph.get_op_nodes()):
            body_node['internal_layer_id'] = idx

        body_graph.stage = 'front'

        # Currently,
        # Loop Inputs Order:
        #   0      - current iteration
        #   1      - trip count
        #   2..    - "loop carried" dependencies variables
        #
        # Body Inputs Order:
        #   0      - current iteration
        #   1      - trip count
        #   2..    - "loop carried" dependencies variables
        #
        # Body Outputs Order:
        #   0      - current iteration
        #   1      - trip count
        #   2..    - "loop carried" dependencies variables
        #
        # Loop Outputs Order:
        #    0     - current iteration
        #    1     - trip count
        #    2..   - "loop carried" dependencies variables
        #
        # so inputs must be reordered and execution condition must be created in the front transformation
        # to be aligned with the specification

        # connect external input ports with body parameter nodes except current iteration
        # since it must be disconnected from external port
        for idx in range(1, len(body_parameters)):
            Loop.connect_body_input(loop_node, idx, body_parameters[idx])

        # mark current iteration input Parameter node and execution condition Result node
        # (the condition function result is the last one appended to body_results)
        Loop.mark_current_iteration_parameter_node(loop_node, body_parameters[0])
        Loop.mark_execution_condition_result_node(loop_node, body_results[-1])

        # connect back edges in the body except current iteration
        for idx in range(1, len(body_parameters)):
            Loop.add_back_edge(loop_node, body_parameters[idx], body_results[idx])

        # connect body outputs with Loop operation output ports except the execution condition result
        for idx in range(len(body_results) - 1):
            Loop.connect_body_output(loop_node, idx, body_results[idx])

        # run function to parse body nodes attributes similar to the main graph;
        # the extractors collection does not depend on the node, so it is prepared once
        extractors = check_for_duplicates(tf_op_extractors)
        extract_node_attrs(body_graph, lambda node: tf_op_extractor(node, extractors))
        return cls.enabled
class StatelessWhileExtractor(FrontExtractorOp):
    """
    Front extractor for the StatelessWhile operation, a variation of the while_loop primitive
    from TensorFlow 2 Python API.

    Unlike While, StatelessWhile has no stateful operations in its body and condition graphs.
    The extraction work is delegated entirely to WhileExtractor.extract.
    """
    enabled = True
    op = 'StatelessWhile'

    @classmethod
    def extract(cls, loop_node):
        # the conversion is shared with While; only this class's own 'enabled' flag is reported back
        WhileExtractor.extract(loop_node)
        is_enabled = cls.enabled
        return is_enabled