Mo implementation for If with tf extractor (#6662)

* Add tf2.x impl for If

* Fix ir_engine

* Fix opset

* Fix BOM file

* Added new test

* Fix comments

* Add subgraph_utils

* Fix comments

* Fix transform

* code refactoring

* Fix description

* rewrite support for empty tensor in if

* added onnx extractor

* delete onnx_if

* fix bug with fake_outputs

* Fix test

* Fix control_flow and fix commentaries

* create method results_mapping_and_finding_fake_outputs(output_nodes_in_subgraph,
This commit is contained in:
Eugeny Volosenkov 2021-08-19 10:13:21 +03:00 committed by GitHub
parent 506148cff6
commit 38022c4cd6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 750 additions and 129 deletions

View File

@ -436,6 +436,7 @@ extensions/front/tf/GatherTree_ext.py
extensions/front/tf/GNMT_DynamicSequenceLengths.py extensions/front/tf/GNMT_DynamicSequenceLengths.py
extensions/front/tf/identity_ext.py extensions/front/tf/identity_ext.py
extensions/front/tf/identityN_to_identity.py extensions/front/tf/identityN_to_identity.py
extensions/front/tf/if_ext.py
extensions/front/tf/InterpolateTransposes.py extensions/front/tf/InterpolateTransposes.py
extensions/front/tf/IteratorGetNext_ext.py extensions/front/tf/IteratorGetNext_ext.py
extensions/front/tf/log_softmax_ext.py extensions/front/tf/log_softmax_ext.py
@ -701,6 +702,7 @@ extensions/ops/GRU.py
extensions/ops/GRUCell.py extensions/ops/GRUCell.py
extensions/ops/hard_sigmoid.py extensions/ops/hard_sigmoid.py
extensions/ops/identity.py extensions/ops/identity.py
extensions/ops/If.py
extensions/ops/instance_normalization.py extensions/ops/instance_normalization.py
extensions/ops/interp.py extensions/ops/interp.py
extensions/ops/interpolate.py extensions/ops/interpolate.py
@ -927,6 +929,7 @@ mo/front/tf/extractors/native_tf.py
mo/front/tf/extractors/pack.py mo/front/tf/extractors/pack.py
mo/front/tf/extractors/random_uniform.py mo/front/tf/extractors/random_uniform.py
mo/front/tf/extractors/strided_slice.py mo/front/tf/extractors/strided_slice.py
mo/front/tf/extractors/subgraph_utils.py
mo/front/tf/extractors/utils.py mo/front/tf/extractors/utils.py
mo/front/tf/graph_utils.py mo/front/tf/graph_utils.py
mo/front/tf/loader.py mo/front/tf/loader.py
@ -1050,6 +1053,7 @@ mo/utils/ir_reader/extenders/experimental_extender.py
mo/utils/ir_reader/extenders/ExtractImagePatches_extender.py mo/utils/ir_reader/extenders/ExtractImagePatches_extender.py
mo/utils/ir_reader/extenders/fakequantize_extender.py mo/utils/ir_reader/extenders/fakequantize_extender.py
mo/utils/ir_reader/extenders/GRUCell_extender.py mo/utils/ir_reader/extenders/GRUCell_extender.py
mo/utils/ir_reader/extenders/if_extender.py
mo/utils/ir_reader/extenders/interpolate_extender.py mo/utils/ir_reader/extenders/interpolate_extender.py
mo/utils/ir_reader/extenders/loop_extender.py mo/utils/ir_reader/extenders/loop_extender.py
mo/utils/ir_reader/extenders/LSTMCell_extender.py mo/utils/ir_reader/extenders/LSTMCell_extender.py

View File

@ -0,0 +1,69 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from extensions.ops.If import If
from extensions.ops.parameter import Parameter
from mo.front.common.register_custom_ops import check_for_duplicates
from mo.front.extractor import FrontExtractorOp, extract_node_attrs
from mo.front.tf.extractor import tf_op_extractor, tf_op_extractors
from mo.front.tf.extractors.subgraph_utils import update_body_graph, convert_graph_inputs_to_parameters, \
get_graph_proto, create_internal_graph
from mo.graph.graph import Node, Graph
def extract_if(cls, if_node: Node):
If.update_node_stat(if_node, {})
# check that required body and condition functions exist in the graph library
main_graph = if_node.graph
then_graph_proto = get_graph_proto(main_graph, 'then_branch', if_node)
else_graph_proto = get_graph_proto(main_graph, 'else_branch', if_node)
then_graph = create_internal_graph(main_graph)
if_node['then_graph'] = then_graph
else_graph = create_internal_graph(main_graph)
if_node['else_graph'] = else_graph
# create Parameter nodes for the then/else graphs
for input_index, (body_graph, body_graph_proto) in enumerate(zip((then_graph, else_graph), (then_graph_proto,
else_graph_proto))):
body_parameters, body_parameter_names = convert_graph_inputs_to_parameters(body_graph, body_graph_proto)
# update the If body graph with the body function graph
body_results = []
update_body_graph(body_graph, body_graph_proto, body_parameter_names, body_results)
body_graph.stage = 'front'
# connect external input ports with body parameter nodes except input with condition
for idx in range(0, len(body_parameters)):
If.connect_body_input(if_node, not input_index, idx + 1, body_parameters[idx])
# connect body outputs with If operation output ports
for idx in range(len(body_results)):
If.connect_body_output(if_node, not input_index, idx, body_results[idx])
# run function to parse body nodes attributes similar to the main graph
extract_node_attrs(body_graph, lambda node: tf_op_extractor(node, check_for_duplicates(tf_op_extractors)))
return cls.enabled
class IfExtractor(FrontExtractorOp):
op = 'If'
enabled = True
@classmethod
def extract(cls, if_node: Node):
return extract_if(cls, if_node)
class StatelessIfExtractor(FrontExtractorOp):
op = 'StatelessIf'
enabled = True
@classmethod
def extract(cls, if_node: Node):
return extract_if(cls, if_node)

View File

@ -1,68 +1,14 @@
# Copyright (C) 2018-2021 Intel Corporation # Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import copy
from extensions.ops.loop import Loop from extensions.ops.loop import Loop
from extensions.ops.parameter import Parameter from extensions.ops.parameter import Parameter
from mo.front.common.register_custom_ops import check_for_duplicates from mo.front.common.register_custom_ops import check_for_duplicates
from mo.front.extractor import extract_node_attrs, FrontExtractorOp from mo.front.extractor import extract_node_attrs, FrontExtractorOp
from mo.front.tf.extractor import tf_op_extractor, tf_op_extractors, create_tf_edge from mo.front.tf.extractor import tf_op_extractor, tf_op_extractors, create_tf_edge
from mo.front.tf.extractors.utils import tf_dtype_extractor from mo.front.tf.extractors.subgraph_utils import update_body_graph, convert_graph_inputs_to_parameters, \
get_graph_proto, create_internal_graph
from mo.graph.graph import add_opoutput, Graph, Node from mo.graph.graph import add_opoutput, Graph, Node
from mo.ops.op import PermuteAttrs
def update_body_graph(body_graph: Graph, subgraph_proto: dict,
body_parameter_names: list, body_results: list):
"""
Updates the loop body graph with a sub-graph (for body or condition functions)
:param body_graph: a loop body graph to be updated
:param subgraph_proto: a sub-graph in a protobuf format to be added into the loop body graph
:param body_parameter_names: a (unchanged) list of parameters in the loop body graph
:param body_results: a list of Result nodes that is extended with a list from a sub-graph
"""
# create a map from a node name in original model to a name in a loop body graph assuming
# that names in the original model are unique
# initially, the map contains names for parameters that are common for the body and condition graphs
map_original_name = {}
for idx, pb_node in enumerate(subgraph_proto['input_arg']):
map_original_name[pb_node.name] = body_parameter_names[idx]
# walk through all nodes (non-parameter and non-result nodes) and add into the loop body graph
for pb_node in subgraph_proto['node_def']:
# create an NX node
id = body_graph.unique_id(pb_node.name)
map_original_name[pb_node.name] = id
body_graph.add_node(id, pb=pb_node, kind='op')
if hasattr(body_graph, 'op_names_statistic') and hasattr(pb_node, 'op'):
body_graph.op_names_statistic[pb_node.op] += 1
# add incoming edges based on data_nodes_map
for dst_port, inp in enumerate(pb_node.input):
orig_src_id = inp.split(":")[0]
# TODO: avoid this temporal workaround for TF 2.4 or higher RNN layers:
# skip control flow dependency
if orig_src_id[0] == '^':
continue
src_id = map_original_name[orig_src_id]
src_port = 0 if len(inp.split(":")) == 1 else int(inp.split(":")[-1])
assert (body_graph.has_node(src_id))
body_graph.add_edges_from([create_tf_edge(src_id + ":" + str(src_port), id, dst_port)])
# create Result nodes in the loop body graph
for output in subgraph_proto['output_arg']:
output_name = subgraph_proto['ret'][output.name]
orig_src_id = output_name.split(":")[0]
src_id = map_original_name[orig_src_id]
src_port = 0 if len(output_name.split(":")) == 1\
else int(output_name.split(":")[-1])
assert body_graph.has_node(src_id), 'The body graph does not contain output with name "{}"'.format(
src_id)
body_results.append(Node(body_graph, add_opoutput(body_graph, src_id, src_port, False)))
class WhileExtractor(FrontExtractorOp): class WhileExtractor(FrontExtractorOp):
@ -77,49 +23,16 @@ class WhileExtractor(FrontExtractorOp):
@classmethod @classmethod
def extract(cls, loop_node): def extract(cls, loop_node):
Loop.update_node_stat(loop_node, {}) Loop.update_node_stat(loop_node, {})
loop_name = loop_node.soft_get('name', loop_node.id)
# check that required body and condition functions exist in the graph library # check that required body and condition functions exist in the graph library
main_graph = loop_node.graph main_graph = loop_node.graph
body_graph_name = loop_node.pb.attr['body'].func.name body_graph_proto = get_graph_proto(main_graph, 'body', loop_node)
cond_graph_name = loop_node.pb.attr['cond'].func.name cond_graph_proto = get_graph_proto(main_graph, 'cond', loop_node)
assert 'library' in main_graph.graph, 'The graph does not contain a library that is required ' \
'by node with name "{}".'.format(loop_name)
library_graph = main_graph.graph['library']
assert body_graph_name in library_graph, 'The library does not contain a function with name "{}" ' \ body_graph = create_internal_graph(main_graph)
'that is required by node ' \
'with name "{}".'.format(body_graph_name, loop_name)
body_graph_proto = library_graph[body_graph_name]
assert cond_graph_name in library_graph, 'The library does not contain a function with name "{}" ' \
'that is required by node ' \
'with name "{}".'.format(cond_graph_name, loop_name)
cond_graph_proto = library_graph[cond_graph_name]
body_graph = Graph()
# fill the body graph
for attr_key in main_graph.graph.keys():
if attr_key != 'library':
body_graph.graph[attr_key] = copy.deepcopy(main_graph.graph[attr_key])
else:
# it is sufficient to have a link to the library
body_graph.graph['library'] = main_graph.graph['library']
loop_node['body'] = body_graph loop_node['body'] = body_graph
# create Parameter nodes for the body graph # create Parameter nodes for the body graph
body_parameters = [] body_parameters, body_parameter_names = convert_graph_inputs_to_parameters(body_graph, body_graph_proto)
body_parameter_names = []
for idx, pb_node in enumerate(body_graph_proto['input_arg']):
param_id = body_graph.unique_id(pb_node.name)
body_graph.add_node(param_id, name=param_id, kind='op', op='Parameter', pb=None, shape=None)
parameter_node = Node(body_graph, pb_node.name)
Parameter.update_node_stat(parameter_node,
{'data_type': tf_dtype_extractor(pb_node.type),
'permute_attrs': PermuteAttrs().update_attrs(attrs=[('shape', 'output:0')])}
)
body_parameters.append(parameter_node)
body_parameter_names.append(param_id)
# update the loop body graph with the body function graph # update the loop body graph with the body function graph
body_results = [] body_results = []

View File

@ -138,8 +138,8 @@ def graph_or_sub_graph_has_nhwc_ops(graph: Graph):
NHWC_conv_detected = True NHWC_conv_detected = True
break break
# for the Loop node we need to check that the body does not contain marker ops as well if node.has('sub_graphs'):
if node.op == 'Loop': for sub_graph_name in node['sub_graphs']:
NHWC_conv_detected |= graph_or_sub_graph_has_nhwc_ops(node.body) NHWC_conv_detected |= graph_or_sub_graph_has_nhwc_ops(node.soft_get(sub_graph_name))
# TODO check for If op when it is implemented
return NHWC_conv_detected return NHWC_conv_detected

View File

@ -0,0 +1,328 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import logging as log
import numpy as np
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Node, Graph
from mo.middle.passes.infer import partial_infer
from mo.ops.op import Op
class If(Op):
"""
If operation is an operation which has an input with condition which defines what sub-graph "then" or "else" to be
executed.
"""
op = 'If'
enabled = False
def __init__(self, graph: Graph, attrs: dict):
base_attrs = {
'type': self.op,
'op': self.op,
'then_graph': None, # an Graph object with a "then" body sub-graph (condition is True)
'else_graph': None, # an Graph object with a "else" body sub-graph (condition is False)
'sub_graphs': ['then_graph', 'else_graph'], # built-in attribute with all sub-graphs
'version': 'opset8',
'infer': self.infer,
'type_infer': self.type_infer,
}
base_attrs.update(attrs)
super().__init__(graph, base_attrs, attrs)
def port_map_attrs(self):
return [
'external_port_id',
'internal_layer_id'
]
@staticmethod
def connect_body_input(if_node: Node, condition: bool, if_input_port_idx: int, body_parameter: Node):
"""
Update the specified body parameter and connect it with If input
:param if_node: the If node
:param condition: the boolean defining a condition (then/else) graph to add connect the body
:param if_input_port_idx: the input port index to connect
:param body_parameter: the body parameter node to connect
:return: None
"""
assert if_node.soft_get('op') == 'If'
assert body_parameter.soft_get('op') == 'Parameter'
sub_graph = if_node.then_graph if condition else if_node.else_graph
assert body_parameter.id in sub_graph
body_parameter['input_id'] = if_input_port_idx
@staticmethod
def connect_body_output(if_node: Node, condition: bool, if_output_port_idx: int, internal_result: Node):
"""
Update the specified output port and connect it with If output
:param if_node: the If node
:param condition: the boolean defining a condition (then/else) graph to add connect the body
:param if_output_port_idx: the output port index to connect
:param internal_result: the body Result node to connect
:return: None
"""
assert if_node.soft_get('op') == 'If'
assert internal_result.soft_get('op') == 'Result'
sub_graph = if_node.then_graph if condition else if_node.else_graph
assert internal_result.id in sub_graph
internal_result['output_id'] = if_output_port_idx
@staticmethod
def update_body_parameters_type(if_node: Node, condition: bool):
"""
Update the data type for If body Parameter nodes based on data type of the outer graph nodes producing data
for them.
:param if_node: The If node
:param condition: the boolean defining a condition (then/else) graph
:return: None
"""
assert if_node.soft_get('type') == 'If'
subgraph = if_node.then_graph if condition else if_node.else_graph
for node in subgraph.get_op_nodes():
if node.has('input_id'):
assert node.soft_get('type') == 'Parameter'
input_port_id = node['input_id']
input_type = if_node.in_port(input_port_id).get_data_type()
node.data_type = input_type
log.debug('Updated data type for the body node with name "{}" with value {}'
.format(node.name, node.data_type))
@staticmethod
def update_body_parameters_shape(if_node: Node, condition: bool):
"""
Update shape for If body parameters.
:param if_node: The If node
:param condition: the boolean defining a condition (then/else) graph to add connect the body
:return: None
"""
subgraph = if_node.then_graph if condition else if_node.else_graph
for node in subgraph.get_op_nodes():
if node.has('input_id'):
assert node.soft_get('type') == 'Parameter'
input_port_id = node['input_id']
input_shape = if_node.in_port(input_port_id).data.get_shape()
if node.soft_get('shape', None) is None:
node['shape'] = None
node.shape = input_shape.copy()
log.debug('Updated shape for the body node with name "{}" with value {}'
.format(node.soft_get('name', node.soft_get('id')), node.shape))
@staticmethod
def results_mapping_and_finding_fake_outputs(output_nodes_in_subgraph, branch_name, outputs_mapping):
"""
This method checked result nodes in subgraph and set map between output from If operation and internal subgraph
result. Also This method return True if internal graph has fake results.
:param output_nodes_in_subgraph: Result node with attribute 'output_id'
:param branch_name: name of subgraph
:param outputs_mapping: map between If operation output ID and subgraph results
:return: True if all results of subgraph are empty tensors
"""
graph_contain_fake_outputs = True
for output_node in output_nodes_in_subgraph:
assert output_node.soft_get('type') == 'Result'
port_id = output_node['output_id']
assert port_id in outputs_mapping.keys(), 'Incorrect mapping then_graph outputs with {0} outputs! ' \
'Can\'t find port with ID {1} in If operation.' \
.format(output_node.name, port_id)
outputs_mapping[port_id][branch_name] = output_node
out_node_shape = output_node.in_port(0).data.get_shape()
graph_contain_fake_outputs = graph_contain_fake_outputs and np.any(out_node_shape == 0)
return graph_contain_fake_outputs
@staticmethod
def update_if_output_ports_shape(if_node: Node):
"""
Update shape and values for If output ports.
:param if_node: The If node to update output ports and shapes
:return: None
"""
then_outputs = [node for node in if_node.then_graph.get_op_nodes() if node.has('output_id')]
else_outputs = [node for node in if_node.else_graph.get_op_nodes() if node.has('output_id')]
outputs_mapping = {}
outputs_number = len(if_node.out_ports())
if outputs_number == 0 and len(if_node.out_ports(control_flow=True)) != 0:
# Some models have if with control flow outputs.
# These shape inference for such ifs
# TODO: need to rethink and redo support for control flow edges in if operation
for node in if_node.out_nodes(control_flow=True).values():
node.shape = int64_array([])
return
for port_id in if_node.out_ports().keys():
outputs_mapping[port_id] = {}
# variables then_contains_fake_outputs/else_contains_fake_outputs contains True value
# if all outputs from then_body/else_body have shape [0]. It means then_body/else_body does not return data
# and further shape_inference for this branch is not possible.
# TODO: exclude support fake_outputs from this code when we will support shape_inference with empty tensors
then_contains_fake_outputs = \
If.results_mapping_and_finding_fake_outputs(then_outputs, 'then_graph', outputs_mapping)
else_contains_fake_outputs = \
If.results_mapping_and_finding_fake_outputs(else_outputs, 'else_graph', outputs_mapping)
# use_then_shape is True when else_body or when both bodies do not return data. If use_then_shape is True If's
# outputs will have the same shapes as then_body results
use_then_shape = else_contains_fake_outputs or not then_contains_fake_outputs
for port_id in outputs_mapping:
then_else_nodes = outputs_mapping[port_id]
assert 'then_graph' in then_else_nodes.keys(), 'then_graph does not connect with If.out_port[{0}] ' \
'in {1} node!'.format(port_id, if_node.name)
assert 'else_graph' in then_else_nodes.keys(), 'else_graph does not connect with If.out_port[{0}] ' \
'in {1} node!'.format(port_id, if_node.name)
then_shape = then_else_nodes['then_graph'].in_port(0).data.get_shape()
else_shape = then_else_nodes['else_graph'].in_port(0).data.get_shape()
if not (then_shape == else_shape).all():
log.debug("If node {0} has dynamic output [{1}] because output shape from then_graph is {2} and "
"else_graph {3}".format(if_node.name, port_id, then_shape, else_shape))
if_node.out_port(port_id).data.set_shape(then_shape if use_then_shape else else_shape)
@staticmethod
def update_if_output_ports_type(if_node: Node):
"""
Update types for If output ports.
:param if_node: The If node to update output ports and types
:return: None
"""
then_outputs = [node for node in if_node.then_graph.get_op_nodes() if node.has('output_id')]
else_outputs = [node for node in if_node.else_graph.get_op_nodes() if node.has('output_id')]
outputs_mapping = {}
outputs_number = len(if_node.out_ports())
assert outputs_number == len(then_outputs), 'Incorrect number outputs in then_graph of If with"' \
'name {0}! then_graph must has {1} outputs' \
.format(if_node.name, outputs_number)
assert outputs_number == len(else_outputs), 'Incorrect number outputs in else_graph of If with"' \
'name {0}! else_graph must has {1} outputs' \
.format(if_node.name, outputs_number)
for port_id in if_node.out_ports().keys():
outputs_mapping[port_id] = {}
port_ids = outputs_mapping.keys()
for then_output_node in then_outputs:
assert then_output_node.soft_get('type') == 'Result'
port_id = then_output_node['output_id']
assert port_id in port_ids, 'Incorrect mapping then_graph outputs with {0} outputs! ' \
'Can\'t find port with ID {1} in If operation.' \
.format(then_output_node.name, port_id)
outputs_mapping[port_id]['then_graph'] = then_output_node
for else_output_node in else_outputs:
assert else_output_node.soft_get('type') == 'Result'
port_id = else_output_node['output_id']
assert port_id in port_ids, 'Incorrect mapping then_graph outputs with {0} outputs! ' \
'Can\'t find port with ID {1} in If operation.' \
.format(else_output_node.name, port_id)
outputs_mapping[port_id]['else_graph'] = else_output_node
for port_id in outputs_mapping:
then_else_nodes = outputs_mapping[port_id]
assert 'then_graph' in then_else_nodes.keys(), 'then_graph does not connect with If.out_port[{0}] ' \
'in {1} node!'.format(port_id, if_node.name)
assert 'else_graph' in then_else_nodes.keys(), 'else_graph does not connect with If.out_port[{0}] ' \
'in {1} node!'.format(port_id, if_node.name)
then_type = then_else_nodes['then_graph'].in_port(0).get_data_type()
else_type = then_else_nodes['else_graph'].in_port(0).get_data_type()
assert then_type == else_type, 'Cannot get type for if.out_port[{0}]! ' \
'Types in then_graph and else_graph are not equal!'.format(port_id)
if_node.out_port(port_id).set_data_type(then_type)
@staticmethod
def re_numerate_internal_id_and_get_if_id(if_node):
"""
This method is called before IR generation. This method sets internal_layer_id.
:param if_node: The If node where is necessary to set internal_layer_id in bodies.
:return: if_node
"""
then_graph_nodes = if_node.then_graph.nodes()
for idx in range(len(if_node.then_graph.get_op_nodes())):
then_graph_nodes[idx]['internal_layer_id'] = idx
else_graph_nodes = if_node.else_graph.nodes()
for idx in range(len(if_node.else_graph.get_op_nodes())):
else_graph_nodes[idx]['internal_layer_id'] = idx
return if_node.node
def substitute_ie_attrs(self, new_attrs: dict):
"""
Replace standard list of attribute in layer/data by attributes
delivered by backend_attrs
"""
port_map_attrs = self.port_map_attrs()
new_attrs.update({
'IE': [(
'layer',
[('id', lambda node: self.re_numerate_internal_id_and_get_if_id(node)), 'name', 'type', 'version'],
[
'@ports',
('then_port_map', [], [
('@list', lambda node: self.generate_port_map(node, True, 'in'),
('input', port_map_attrs, [])),
('@list', lambda node: self.generate_port_map(node, True, 'out'),
('output', port_map_attrs, [])),
]),
('else_port_map', [], [
('@list', lambda node: self.generate_port_map(node, False, 'in'),
('input', port_map_attrs, [])),
('@list', lambda node: self.generate_port_map(node, False, 'out'),
('output', port_map_attrs, [])),
]),
('then_body', [], [('@network', 'then_graph')]),
('else_body', [], [('@network', 'else_graph')]),
])]
})
@staticmethod
def generate_port_map(if_node: Node, condition: bool, dir: str):
"""
Extract port_map attributes from if_node and its subgraphs attributes.
:param if_node: The If node
:param condition: the boolean defining a condition (then/else) graph
:param dir: the str value defining type (for inputs or for putputs) of port_map
:return: port_map -> list of dictionaries with to values(external_port_id or internal_layer_id)
"""
port_map = []
subgraph = if_node.then_graph if condition else if_node.else_graph
name_of_connection = 'input_id' if dir == 'in' else 'output_id'
for internal_node in subgraph.get_op_nodes():
if internal_node.has(name_of_connection):
port_map.append({'external_port_id': internal_node[name_of_connection],
'internal_layer_id': internal_node['internal_layer_id']})
return port_map
@staticmethod
def infer(if_node: Node):
If.update_body_parameters_shape(if_node, True)
If.update_body_parameters_shape(if_node, False)
partial_infer(if_node.then_graph)
partial_infer(if_node.else_graph)
If.update_if_output_ports_shape(if_node)
@staticmethod
def type_infer(if_node: Node):
from mo.middle.passes.infer import type_infer
If.update_body_parameters_type(if_node, True)
If.update_body_parameters_type(if_node, False)
type_infer(if_node.then_graph)
type_infer(if_node.else_graph)
If.update_if_output_ports_type(if_node)

View File

@ -0,0 +1,108 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import copy
from extensions.ops.parameter import Parameter
from mo.front.tf.extractor import tf_op_extractor, tf_op_extractors, create_tf_edge
from mo.front.tf.extractors.utils import tf_dtype_extractor
from mo.graph.graph import Graph, Node, add_opoutput
from mo.ops.op import PermuteAttrs
def update_body_graph(body_graph: Graph, subgraph_proto: dict,
body_parameter_names: list, body_results: list):
"""
Updates the loop body graph with a sub-graph (for body or condition functions)
:param body_graph: a loop body graph to be updated
:param subgraph_proto: a sub-graph in a protobuf format to be added into the loop body graph
:param body_parameter_names: a (unchanged) list of parameters in the loop body graph
:param body_results: a list of Result nodes that is extended with a list from a sub-graph
"""
# create a map from a node name in original model to a name in a loop body graph assuming
# that names in the original model are unique
# initially, the map contains names for parameters that are common for the body and condition graphs
map_original_name = {}
for idx, pb_node in enumerate(subgraph_proto['input_arg']):
map_original_name[pb_node.name] = body_parameter_names[idx]
# walk through all nodes (non-parameter and non-result nodes) and add into the loop body graph
for pb_node in subgraph_proto['node_def']:
# create an NX node
id = body_graph.unique_id(pb_node.name)
map_original_name[pb_node.name] = id
body_graph.add_node(id, pb=pb_node, kind='op')
if hasattr(body_graph, 'op_names_statistic') and hasattr(pb_node, 'op'):
body_graph.op_names_statistic[pb_node.op] += 1
# add incoming edges based on data_nodes_map
for dst_port, inp in enumerate(pb_node.input):
orig_src_id = inp.split(":")[0]
# TODO: avoid this temporal workaround for TF 2.4 or higher RNN layers:
# skip control flow dependency
if orig_src_id[0] == '^':
continue
src_id = map_original_name[orig_src_id]
src_port = 0 if len(inp.split(":")) == 1 else int(inp.split(":")[-1])
assert (body_graph.has_node(src_id))
body_graph.add_edges_from([create_tf_edge(src_id + ":" + str(src_port), id, dst_port)])
# create Result nodes in the loop body graph
for output in subgraph_proto['output_arg']:
output_name = subgraph_proto['ret'][output.name]
orig_src_id = output_name.split(":")[0]
src_id = map_original_name[orig_src_id]
src_port = 0 if len(output_name.split(":")) == 1 \
else int(output_name.split(":")[-1])
assert body_graph.has_node(src_id), 'The body graph does not contain output with name "{}"'.format(
src_id)
body_results.append(Node(body_graph, add_opoutput(body_graph, src_id, src_port, False)))
return True
def get_graph_proto(external_graph: Graph, graph_id: str, node_with_graph: Node):
graph_name = node_with_graph.pb.attr[graph_id].func.name
node_name = node_with_graph.soft_get('name', node_with_graph.id)
assert 'library' in external_graph.graph, 'The graph does not contain a library that is required ' \
'by node with name "{}".'.format(node_name)
library_graph = external_graph.graph['library']
assert graph_name in library_graph, 'The library does not contain a function with name "{}" ' \
'that is required by node ' \
'with name "{}".'.format(graph_name, node_name)
return library_graph[graph_name]
def create_internal_graph(external_graph: Graph):
internal_graph = Graph()
# fill the body graph
for attr_key in external_graph.graph.keys():
if attr_key != 'library':
internal_graph.graph[attr_key] = copy.deepcopy(external_graph.graph[attr_key])
else:
# it is sufficient to have a link to the library
internal_graph.graph['library'] = external_graph.graph['library']
return internal_graph
def convert_graph_inputs_to_parameters(internal_graph, internal_graph_proto):
# create Parameter nodes for the body graph
body_parameters = []
body_parameter_names = []
for idx, pb_node in enumerate(internal_graph_proto['input_arg']):
param_id = internal_graph.unique_id(pb_node.name)
internal_graph.add_node(param_id, name=param_id, kind='op', op='Parameter', pb=None, shape=None)
parameter_node = Node(internal_graph, pb_node.name)
Parameter.update_node_stat(parameter_node,
{'data_type': tf_dtype_extractor(pb_node.type),
'permute_attrs': PermuteAttrs().update_attrs(attrs=[('shape', 'output:0')])}
)
body_parameters.append(parameter_node)
body_parameter_names.append(param_id)
return body_parameters, body_parameter_names

View File

@ -194,7 +194,8 @@ def prepare_emit_ir(graph: Graph, data_type: str, output_dir: str, output_model_
# do not run the type inference in sub-graphs. It will be called automatically as part of the type inference of # do not run the type inference in sub-graphs. It will be called automatically as part of the type inference of
# the TensorIterator nodes # the TensorIterator nodes
type_infer(graph) type_infer(graph)
RemoveUselessConvert().find_and_replace_pattern(graph)
for_graph_and_each_sub_graph_recursively(graph, RemoveUselessConvert().find_and_replace_pattern)
ResultRename().find_and_replace_pattern(graph) ResultRename().find_and_replace_pattern(graph)

View File

@ -240,34 +240,8 @@ class IREngine(object):
xml_body_child = list(layer.iterfind('body')) xml_body_child = list(layer.iterfind('body'))
assert len(xml_body_child) == 1 assert len(xml_body_child) == 1
body_ir = IREngine(path_to_xml=None, body_ir, input_port_map, output_port_map, input_layers = \
path_to_bin=self.path_to_bin, self.__read_subgraph(layer, layer_attrs, xml_body_child, 'port_map')
xml_tree=ElementTree(xml_body_child[0]))
self.graph.graph['hashes'].update(body_ir.graph.graph['hashes'])
# Find port_map section and take an input_port_map & output_port_map
xml_port_map = list(layer.iterfind('port_map'))
if not len(xml_port_map) == 1:
log.warning("TensorIterator body won\'t be compared due to missing port_map section!")
continue
xml_port_map = xml_port_map[0]
input_layers = []
input_port_map = []
output_port_map = []
for port in xml_port_map:
if port.tag == 'input':
if 'internal_layer_id' not in port.attrib:
log.warning("internal_layer_id attrib not found in input section")
else:
input_layers.append(Node(body_ir.graph, port.attrib['internal_layer_id']))
input_port_map.append(self.__normalize_attrs(port.attrib))
elif port.tag == 'output':
if 'internal_layer_id' not in port.attrib:
log.warning("internal_layer_id attrib not found in output section")
else:
output_port_map.append(self.__normalize_attrs(port.attrib))
body_ir.input_node = input_layers[0] body_ir.input_node = input_layers[0]
layer_attrs.update({'body': body_ir}) layer_attrs.update({'body': body_ir})
@ -287,6 +261,12 @@ class IREngine(object):
layer_attrs.update({'back_edges': back_edges}) layer_attrs.update({'back_edges': back_edges})
elif attr.tag == 'then_body' or attr.tag == 'else_body':
assert layer.attrib['type'] == 'If', "Incorrect IR! The operation {0}" \
" has sub-graphs for If operation"
layer_attrs = self.__read_if(layer, layer_attrs)
continue
return layer_id, layer_attrs return layer_id, layer_attrs
@staticmethod @staticmethod
@ -405,3 +385,54 @@ class IREngine(object):
if not isinstance(other, IREngine): if not isinstance(other, IREngine):
raise AttributeError("IREngine can be compared only with IREngine object type") raise AttributeError("IREngine can be compared only with IREngine object type")
return self.compare(other)[0] return self.compare(other)[0]
def __read_subgraph(self, layer, layer_attrs, body_child, port_map_name):
body_ir = IREngine(path_to_xml=None,
path_to_bin=self.path_to_bin,
xml_tree=ElementTree(body_child[0]))
self.graph.graph['hashes'].update(body_ir.graph.graph['hashes'])
xml_port_map = list(layer.iterfind(port_map_name))
assert not len(xml_port_map) != 1, "If then_body won\'t be compared due to missing {1} section in node {0}! " \
.format(layer_attrs['name'], port_map_name)
xml_port_map = xml_port_map[0]
input_layers = []
input_port_map = []
output_port_map = []
for port in xml_port_map:
if port.tag == 'input':
if 'internal_layer_id' not in port.attrib:
log.warning("internal_layer_id attrib not found in input section")
else:
input_layers.append(Node(body_ir.graph, port.attrib['internal_layer_id']))
input_port_map.append(self.__normalize_attrs(port.attrib))
elif port.tag == 'output':
if 'internal_layer_id' not in port.attrib:
log.warning("internal_layer_id attrib not found in output section")
else:
output_port_map.append(self.__normalize_attrs(port.attrib))
return body_ir, input_port_map, output_port_map, input_layers
def __read_if(self, layer, layer_attrs):
xml_then_body_child = list(layer.iterfind('then_body'))
xml_else_body_child = list(layer.iterfind('else_body'))
assert len(xml_then_body_child) == 1 and len(xml_else_body_child) == 1, "If operation has only one subgraph"
then_body_ir, then_input_port_map, then_output_port_map, _ = \
self.__read_subgraph(layer, layer_attrs, xml_then_body_child, 'then_port_map')
layer_attrs.update({'then_graph': then_body_ir})
layer_attrs.update({'then_input_port_map': then_input_port_map})
layer_attrs.update({'then_output_port_map': then_output_port_map})
else_body_ir, else_input_port_map, else_output_port_map, _ = \
self.__read_subgraph(layer, layer_attrs, xml_else_body_child, 'else_port_map')
layer_attrs.update({'else_graph': else_body_ir})
layer_attrs.update({'else_input_port_map': else_input_port_map})
layer_attrs.update({'else_output_port_map': else_output_port_map})
return layer_attrs

View File

@ -0,0 +1,41 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from mo.utils.graph import Node
from mo.utils.ir_reader.extender import Extender
from mo.utils.ir_reader.layer_to_class import copy_graph_with_ops
class IfExtender(Extender):
op = 'If'
@staticmethod
def set_input_output_id(subgraph, input_port_map, output_port_map):
for node in subgraph.get_op_nodes():
if not node.has_valid('id'):
continue
node_id = int(node.soft_get('id'))
for if_input_mapping_elem in input_port_map:
if node_id == if_input_mapping_elem['internal_layer_id']:
node['input_id'] = if_input_mapping_elem['external_port_id']
for if_out_mapping_elem in output_port_map:
if node_id == if_out_mapping_elem['internal_layer_id']:
node['output_id'] = if_out_mapping_elem['external_port_id']
@staticmethod
def extend(op: Node):
assert op.has('then_graph'), 'There is no "then_body" attribute in the If op {}.'.format(op.name)
assert op.has('else_graph'), 'There is no "else_body" attribute in the If op {}.'.format(op.name)
# Now op.body is an IREngine, we need to replace it with IREngine.graph
op.then_graph.graph.graph['cmd_params'] = op.graph.graph['cmd_params']
op.then_graph.graph.graph['ir_version'] = op.graph.graph['ir_version']
op.then_graph.graph.name = op.name + '/then_body'
op.else_graph.graph.graph['cmd_params'] = op.graph.graph['cmd_params']
op.else_graph.graph.graph['ir_version'] = op.graph.graph['ir_version']
op.else_graph.graph.name = op.name + '/else_body'
op.then_graph = copy_graph_with_ops(op.then_graph.graph)
op.else_graph = copy_graph_with_ops(op.else_graph.graph)
IfExtender.set_input_output_id(op.then_graph, op.then_input_port_map, op.then_output_port_map)
IfExtender.set_input_output_id(op.else_graph, op.else_input_port_map, op.else_output_port_map)

View File

@ -44,7 +44,7 @@ class TFLoaderTest(unittest.TestCase):
# create fake Loop operation # create fake Loop operation
nodes = { nodes = {
**regular_op('input', {'op': 'Parameter'}), **regular_op('input', {'op': 'Parameter'}),
**regular_op('loop', {'op': 'Loop', 'body': body_graph}), **regular_op('loop', {'op': 'Loop', 'body': body_graph, 'sub_graphs': ['body']}),
**result('result'), **result('result'),
} }
edges = [*connect_front('input', '0:loop'), edges = [*connect_front('input', '0:loop'),
@ -65,4 +65,3 @@ class TFLoaderTest(unittest.TestCase):
def test_no_convolution_main_and_sub_graph(self): def test_no_convolution_main_and_sub_graph(self):
self.assertFalse(graph_or_sub_graph_has_nhwc_ops(self.build_loop_graph(self.build_parameter_result_graph()))) self.assertFalse(graph_or_sub_graph_has_nhwc_ops(self.build_loop_graph(self.build_parameter_result_graph())))

View File

@ -0,0 +1,127 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import numpy.testing as npt
import unittest
from extensions.ops.If import If
from mo.ops.shape import Shape
from extensions.ops.elementwise import Add, Mul
from extensions.ops.identity import Identity
from extensions.ops.parameter import Parameter
from mo.front.common.partial_infer.concat import concat_infer
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Node, Graph
from mo.middle.passes.infer import partial_infer
from mo.ops.concat import Concat
from mo.ops.eltwise import eltwise_infer
from mo.ops.result import Result
from unit_tests.utils.graph import build_graph_with_edge_attrs, build_graph
from unit_tests.utils.graph import regular_op_with_empty_data, connect, const, result, valued_const_with_data, \
regular_op, empty_data
class TestIf(unittest.TestCase):
def test_simple_shape_inf(self):
then_graph_nodes = {**regular_op_with_empty_data('param_1', {'type': 'Parameter', 'kind': 'op', 'input_id': 1,
'shape': None, 'infer': Parameter.infer}),
**regular_op_with_empty_data('param_2', {'type': 'Parameter', 'kind': 'op', 'input_id': 2,
'shape': None, 'infer': Parameter.infer}),
**regular_op_with_empty_data('add', {'type': 'Add', 'kind': 'op', 'op': 'Add',
'infer': lambda node: eltwise_infer(node,
Add.operation)}),
**regular_op_with_empty_data('mul', {'type': 'Mul', 'kind': 'op', 'op': 'Mul',
'infer': lambda node: eltwise_infer(node,
Mul.operation)}),
**regular_op_with_empty_data('res1', {'kind': 'op', 'type': 'Result', 'op': 'Result',
'infer': lambda x: 0, 'output_id': 0}),
**regular_op_with_empty_data('res2', {'kind': 'op', 'type': 'Result', 'op': 'Result',
'infer': lambda x: 0, 'output_id': 1})}
then_graph_edges = [*connect('param_1', '0:add'),
*connect('param_2', '1:add'),
*connect('param_1', '1:mul'),
*connect('param_2', '0:mul'),
*connect('add', 'res1'),
*connect('mul', 'res2'),
]
else_graph_nodes = {**regular_op_with_empty_data('param_1', {'type': 'Parameter', 'kind': 'op', 'input_id': 1,
'shape': None, 'infer': Parameter.infer}),
**regular_op_with_empty_data('param_2', {'type': 'Parameter', 'kind': 'op', 'input_id': 3,
'shape': None, 'infer': Parameter.infer}),
**regular_op_with_empty_data('identity',
{'kind': 'op', 'op': 'Identity', 'infer': Identity.infer}),
**regular_op_with_empty_data('identity_1',
{'kind': 'op', 'op': 'Identity', 'infer': Identity.infer}),
**regular_op_with_empty_data('res1', {'kind': 'op', 'type': 'Result', 'op': 'Result',
'infer': lambda x: 0, 'output_id': 0}),
**regular_op_with_empty_data('res2', {'kind': 'op', 'type': 'Result', 'op': 'Result',
'infer': lambda x: 0, 'output_id': 1})}
else_graph_edges = [*connect('param_1', 'identity'),
*connect('param_2', 'identity_1'),
*connect('identity_1', 'res2'),
*connect('identity', 'res1'), ]
then_graph = build_graph_with_edge_attrs(then_graph_nodes, then_graph_edges)
else_graph = build_graph_with_edge_attrs(else_graph_nodes, else_graph_edges)
external_graph_nodes = {
**valued_const_with_data('cond', np.array([True], dtype=np.bool)),
**valued_const_with_data('input_2', int64_array([3, 2, 1])),
**valued_const_with_data('input_1', int64_array([1, 2, 3])),
**valued_const_with_data('input_3', int64_array([8, 4])),
**regular_op('if', {'kind': 'op', 'op': 'If', 'then_graph': then_graph,
'else_graph': else_graph, 'infer': If.infer}),
**empty_data('if_d_1'),
**empty_data('if_d_2'),
**result('res_1'),
**result('res_2')}
external_graph_edges = [*connect('cond', '0:if'),
*connect('input_1', '1:if'),
*connect('input_2', '2:if'),
*connect('input_3', '3:if'),
('if', 'if_d_1', {'out': 0}),
('if', 'if_d_2', {'out': 1}),
('if_d_1', 'res_1'),
('if_d_2', 'res_2')]
graph = build_graph(external_graph_nodes, external_graph_edges)
graph.stage = 'middle'
partial_infer(graph)
res_1 = Node(graph, 'res_1')
res_2 = Node(graph, 'res_2')
npt.assert_array_equal(res_1.in_port(0).data.get_shape(), int64_array([3]))
npt.assert_array_equal(res_2.in_port(0).data.get_shape(), int64_array([3]))
def test_fake_results(self):
then_graph_nodes = {**valued_const_with_data('fake_const', int64_array(0)),
**regular_op_with_empty_data('shapeof',
{'kind': 'op', 'type': 'ShapeOf', 'op': 'ShapeOf', 'infer': Shape.infer,
'output_type': np.int64}),
**regular_op_with_empty_data('res_1', {'kind': 'op', 'type': 'Result', 'op': 'Result',
'infer': lambda x: 0, 'output_id': 0})}
then_graph_edges = [*connect('fake_const', 'shapeof'),
*connect('shapeof', 'res_1'),
]
else_graph_nodes = {**regular_op_with_empty_data('param_1', {'type': 'Parameter', 'kind': 'op', 'input_id': 1,
'shape': None, 'infer': Parameter.infer}),
**regular_op_with_empty_data('res_1', {'kind': 'op', 'type': 'Result', 'op': 'Result',
'infer': lambda x: 0, 'output_id': 0})}
else_graph_edges = [*connect('param_1', 'res_1')]
then_graph = build_graph_with_edge_attrs(then_graph_nodes, then_graph_edges)
else_graph = build_graph_with_edge_attrs(else_graph_nodes, else_graph_edges)
external_graph_nodes = {
**valued_const_with_data('cond', np.array([True], dtype=np.bool)),
**valued_const_with_data('input_1', int64_array([[1, 2, 3], [3, 2, 3]])),
**regular_op_with_empty_data('if', {'kind': 'op', 'op': 'If', 'then_graph': then_graph,
'else_graph': else_graph, 'infer': If.infer}),
**result('res_1')}
external_graph_edges = [*connect('cond', '0:if'),
*connect('input_1', '1:if'),
*connect('if', 'res_1')]
graph = build_graph(external_graph_nodes, external_graph_edges)
graph.stage = 'middle'
partial_infer(graph)
res_1 = Node(graph, 'res_1')
npt.assert_array_equal(res_1.in_port(0).data.get_shape(), int64_array([2,3]))