MO implementation for If with TF extractor (#6662)
* Add tf2.x impl for If
* Fix ir_engine
* Fix opset
* Fix BOM file
* Added new test
* Fix comments
* Add subgraph_utils
* Fix comments
* Fix transform
* code refactoring
* Fix description
* rewrite support for empty tensor in if
* added onnx extractor
* delete onnx_if
* fix bug with fake_outputs
* Fix test
* Fix control_flow and fix commentaries
* create method results_mapping_and_finding_fake_outputs(output_nodes_in_subgraph,
Parent: 506148cff6
Commit: 38022c4cd6
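
For context, a minimal TF 2.x sketch (not part of the commit) of the kind of graph the new extractor targets: a `tf.cond` inside a `tf.function` is serialized as an `If`/`StatelessIf` node whose `then_branch`/`else_branch` attributes name functions stored in the GraphDef library, which `get_graph_proto` then looks up.

    # Illustrative only: produce a GraphDef containing an If/StatelessIf node.
    import tensorflow as tf

    @tf.function(input_signature=[tf.TensorSpec([], tf.bool), tf.TensorSpec([3], tf.float32)])
    def model(cond, x):
        # tf.cond becomes an If/StatelessIf node; the two lambdas become
        # then_branch/else_branch functions stored in the graph library.
        return tf.cond(cond, lambda: x + 1.0, lambda: x * 2.0)

    graph_def = model.get_concrete_function().graph.as_graph_def()
    print([n.op for n in graph_def.node])                           # expect 'If' or 'StatelessIf' among the ops
    print([f.signature.name for f in graph_def.library.function])   # the branch functions
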
@@ -436,6 +436,7 @@ extensions/front/tf/GatherTree_ext.py
 extensions/front/tf/GNMT_DynamicSequenceLengths.py
 extensions/front/tf/identity_ext.py
 extensions/front/tf/identityN_to_identity.py
+extensions/front/tf/if_ext.py
 extensions/front/tf/InterpolateTransposes.py
 extensions/front/tf/IteratorGetNext_ext.py
 extensions/front/tf/log_softmax_ext.py
@@ -701,6 +702,7 @@ extensions/ops/GRU.py
 extensions/ops/GRUCell.py
 extensions/ops/hard_sigmoid.py
 extensions/ops/identity.py
+extensions/ops/If.py
 extensions/ops/instance_normalization.py
 extensions/ops/interp.py
 extensions/ops/interpolate.py
@@ -927,6 +929,7 @@ mo/front/tf/extractors/native_tf.py
 mo/front/tf/extractors/pack.py
 mo/front/tf/extractors/random_uniform.py
 mo/front/tf/extractors/strided_slice.py
+mo/front/tf/extractors/subgraph_utils.py
 mo/front/tf/extractors/utils.py
 mo/front/tf/graph_utils.py
 mo/front/tf/loader.py
@@ -1050,6 +1053,7 @@ mo/utils/ir_reader/extenders/experimental_extender.py
 mo/utils/ir_reader/extenders/ExtractImagePatches_extender.py
 mo/utils/ir_reader/extenders/fakequantize_extender.py
 mo/utils/ir_reader/extenders/GRUCell_extender.py
+mo/utils/ir_reader/extenders/if_extender.py
 mo/utils/ir_reader/extenders/interpolate_extender.py
 mo/utils/ir_reader/extenders/loop_extender.py
 mo/utils/ir_reader/extenders/LSTMCell_extender.py

model-optimizer/extensions/front/tf/if_ext.py (new file, 69 lines)
@@ -0,0 +1,69 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from extensions.ops.If import If
from extensions.ops.parameter import Parameter
from mo.front.common.register_custom_ops import check_for_duplicates
from mo.front.extractor import FrontExtractorOp, extract_node_attrs
from mo.front.tf.extractor import tf_op_extractor, tf_op_extractors
from mo.front.tf.extractors.subgraph_utils import update_body_graph, convert_graph_inputs_to_parameters, \
    get_graph_proto, create_internal_graph
from mo.graph.graph import Node, Graph


def extract_if(cls, if_node: Node):
    If.update_node_stat(if_node, {})

    # check that required body and condition functions exist in the graph library
    main_graph = if_node.graph
    then_graph_proto = get_graph_proto(main_graph, 'then_branch', if_node)
    else_graph_proto = get_graph_proto(main_graph, 'else_branch', if_node)

    then_graph = create_internal_graph(main_graph)
    if_node['then_graph'] = then_graph

    else_graph = create_internal_graph(main_graph)
    if_node['else_graph'] = else_graph

    # create Parameter nodes for the then/else graphs
    for input_index, (body_graph, body_graph_proto) in enumerate(zip((then_graph, else_graph), (then_graph_proto,
                                                                                                else_graph_proto))):

        body_parameters, body_parameter_names = convert_graph_inputs_to_parameters(body_graph, body_graph_proto)

        # update the If body graph with the body function graph
        body_results = []
        update_body_graph(body_graph, body_graph_proto, body_parameter_names, body_results)

        body_graph.stage = 'front'

        # connect external input ports with body parameter nodes except input with condition
        for idx in range(0, len(body_parameters)):
            If.connect_body_input(if_node, not input_index, idx + 1, body_parameters[idx])

        # connect body outputs with If operation output ports
        for idx in range(len(body_results)):
            If.connect_body_output(if_node, not input_index, idx, body_results[idx])

        # run function to parse body nodes attributes similar to the main graph
        extract_node_attrs(body_graph, lambda node: tf_op_extractor(node, check_for_duplicates(tf_op_extractors)))

    return cls.enabled


class IfExtractor(FrontExtractorOp):
    op = 'If'
    enabled = True

    @classmethod
    def extract(cls, if_node: Node):
        return extract_if(cls, if_node)


class StatelessIfExtractor(FrontExtractorOp):
    op = 'StatelessIf'
    enabled = True

    @classmethod
    def extract(cls, if_node: Node):
        return extract_if(cls, if_node)
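
A note on the branch indexing in extract_if above (my reading of the code, not stated in the commit): the zip iterates the then graph first, so input_index 0 maps to the then branch and 1 to the else branch; `not input_index` turns that into the boolean `condition` argument of `If.connect_body_input`/`If.connect_body_output`, and `idx + 1` skips If input port 0, which carries the condition tensor. A minimal sketch:

    # Sketch of the indexing convention used in extract_if (illustration only):
    #   input_index 0 -> then graph, condition = not 0 = True
    #   input_index 1 -> else graph, condition = not 1 = False
    for input_index, branch in enumerate(('then_graph', 'else_graph')):
        condition = not input_index
        first_data_port = 1  # port 0 of the If node is the condition input
        print(branch, condition, first_data_port)
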
@@ -1,68 +1,14 @@
 # Copyright (C) 2018-2021 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
-import copy
-
 from extensions.ops.loop import Loop
 from extensions.ops.parameter import Parameter
 from mo.front.common.register_custom_ops import check_for_duplicates
 from mo.front.extractor import extract_node_attrs, FrontExtractorOp
 from mo.front.tf.extractor import tf_op_extractor, tf_op_extractors, create_tf_edge
-from mo.front.tf.extractors.utils import tf_dtype_extractor
+from mo.front.tf.extractors.subgraph_utils import update_body_graph, convert_graph_inputs_to_parameters, \
+    get_graph_proto, create_internal_graph
 from mo.graph.graph import add_opoutput, Graph, Node
-from mo.ops.op import PermuteAttrs
-
-
-def update_body_graph(body_graph: Graph, subgraph_proto: dict,
-                      body_parameter_names: list, body_results: list):
-    """
-    Updates the loop body graph with a sub-graph (for body or condition functions)
-    :param body_graph: a loop body graph to be updated
-    :param subgraph_proto: a sub-graph in a protobuf format to be added into the loop body graph
-    :param body_parameter_names: a (unchanged) list of parameters in the loop body graph
-    :param body_results: a list of Result nodes that is extended with a list from a sub-graph
-    """
-    # create a map from a node name in original model to a name in a loop body graph assuming
-    # that names in the original model are unique
-    # initially, the map contains names for parameters that are common for the body and condition graphs
-    map_original_name = {}
-    for idx, pb_node in enumerate(subgraph_proto['input_arg']):
-        map_original_name[pb_node.name] = body_parameter_names[idx]
-
-    # walk through all nodes (non-parameter and non-result nodes) and add into the loop body graph
-    for pb_node in subgraph_proto['node_def']:
-        # create an NX node
-        id = body_graph.unique_id(pb_node.name)
-        map_original_name[pb_node.name] = id
-        body_graph.add_node(id, pb=pb_node, kind='op')
-        if hasattr(body_graph, 'op_names_statistic') and hasattr(pb_node, 'op'):
-            body_graph.op_names_statistic[pb_node.op] += 1
-
-        # add incoming edges based on data_nodes_map
-        for dst_port, inp in enumerate(pb_node.input):
-            orig_src_id = inp.split(":")[0]
-
-            # TODO: avoid this temporal workaround for TF 2.4 or higher RNN layers:
-            # skip control flow dependency
-            if orig_src_id[0] == '^':
-                continue
-
-            src_id = map_original_name[orig_src_id]
-            src_port = 0 if len(inp.split(":")) == 1 else int(inp.split(":")[-1])
-            assert (body_graph.has_node(src_id))
-
-            body_graph.add_edges_from([create_tf_edge(src_id + ":" + str(src_port), id, dst_port)])
-
-    # create Result nodes in the loop body graph
-    for output in subgraph_proto['output_arg']:
-        output_name = subgraph_proto['ret'][output.name]
-        orig_src_id = output_name.split(":")[0]
-        src_id = map_original_name[orig_src_id]
-        src_port = 0 if len(output_name.split(":")) == 1 \
-            else int(output_name.split(":")[-1])
-        assert body_graph.has_node(src_id), 'The body graph does not contain output with name "{}"'.format(
-            src_id)
-        body_results.append(Node(body_graph, add_opoutput(body_graph, src_id, src_port, False)))
-
-
 class WhileExtractor(FrontExtractorOp):
@@ -77,49 +23,16 @@ class WhileExtractor(FrontExtractorOp):
     @classmethod
     def extract(cls, loop_node):
         Loop.update_node_stat(loop_node, {})
-        loop_name = loop_node.soft_get('name', loop_node.id)
 
         # check that required body and condition functions exist in the graph library
         main_graph = loop_node.graph
-        body_graph_name = loop_node.pb.attr['body'].func.name
-        cond_graph_name = loop_node.pb.attr['cond'].func.name
-        assert 'library' in main_graph.graph, 'The graph does not contain a library that is required ' \
-                                              'by node with name "{}".'.format(loop_name)
-        library_graph = main_graph.graph['library']
-
-        assert body_graph_name in library_graph, 'The library does not contain a function with name "{}" ' \
-                                                 'that is required by node ' \
-                                                 'with name "{}".'.format(body_graph_name, loop_name)
-        body_graph_proto = library_graph[body_graph_name]
-
-        assert cond_graph_name in library_graph, 'The library does not contain a function with name "{}" ' \
-                                                 'that is required by node ' \
-                                                 'with name "{}".'.format(cond_graph_name, loop_name)
-        cond_graph_proto = library_graph[cond_graph_name]
-
-        body_graph = Graph()
-        # fill the body graph
-        for attr_key in main_graph.graph.keys():
-            if attr_key != 'library':
-                body_graph.graph[attr_key] = copy.deepcopy(main_graph.graph[attr_key])
-            else:
-                # it is sufficient to have a link to the library
-                body_graph.graph['library'] = main_graph.graph['library']
+        body_graph_proto = get_graph_proto(main_graph, 'body', loop_node)
+        cond_graph_proto = get_graph_proto(main_graph, 'cond', loop_node)
+
+        body_graph = create_internal_graph(main_graph)
         loop_node['body'] = body_graph
 
         # create Parameter nodes for the body graph
-        body_parameters = []
-        body_parameter_names = []
-        for idx, pb_node in enumerate(body_graph_proto['input_arg']):
-            param_id = body_graph.unique_id(pb_node.name)
-            body_graph.add_node(param_id, name=param_id, kind='op', op='Parameter', pb=None, shape=None)
-            parameter_node = Node(body_graph, pb_node.name)
-            Parameter.update_node_stat(parameter_node,
-                                       {'data_type': tf_dtype_extractor(pb_node.type),
-                                        'permute_attrs': PermuteAttrs().update_attrs(attrs=[('shape', 'output:0')])}
-                                       )
-            body_parameters.append(parameter_node)
-            body_parameter_names.append(param_id)
+        body_parameters, body_parameter_names = convert_graph_inputs_to_parameters(body_graph, body_graph_proto)
 
         # update the loop body graph with the body function graph
         body_results = []
@@ -172,7 +85,7 @@ class WhileExtractor(FrontExtractorOp):
             Loop.add_back_edge(loop_node, body_parameters[idx], body_results[idx])
 
         # connect body outputs with Loop operation output ports except the execution condition result
-        for idx in range(len(body_results)-1):
+        for idx in range(len(body_results) - 1):
             Loop.connect_body_output(loop_node, idx, body_results[idx])
 
         # run function to parse body nodes attributes similar to the main graph
@@ -138,8 +138,8 @@ def graph_or_sub_graph_has_nhwc_ops(graph: Graph):
             NHWC_conv_detected = True
             break
 
-        # for the Loop node we need to check that the body does not contain marker ops as well
-        if node.op == 'Loop':
-            NHWC_conv_detected |= graph_or_sub_graph_has_nhwc_ops(node.body)
-        # TODO check for If op when it is implemented
+        if node.has('sub_graphs'):
+            for sub_graph_name in node['sub_graphs']:
+                NHWC_conv_detected |= graph_or_sub_graph_has_nhwc_ops(node.soft_get(sub_graph_name))
+
     return NHWC_conv_detected

model-optimizer/extensions/ops/If.py (new file, 328 lines)
@@ -0,0 +1,328 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import logging as log
import numpy as np

from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Node, Graph
from mo.middle.passes.infer import partial_infer
from mo.ops.op import Op


class If(Op):
    """
    The If operation has a condition input that defines which sub-graph, "then" or "else", is executed.
    """
    op = 'If'
    enabled = False

    def __init__(self, graph: Graph, attrs: dict):
        base_attrs = {
            'type': self.op,
            'op': self.op,
            'then_graph': None,  # a Graph object with the "then" body sub-graph (condition is True)
            'else_graph': None,  # a Graph object with the "else" body sub-graph (condition is False)
            'sub_graphs': ['then_graph', 'else_graph'],  # built-in attribute with all sub-graphs
            'version': 'opset8',
            'infer': self.infer,
            'type_infer': self.type_infer,
        }
        base_attrs.update(attrs)
        super().__init__(graph, base_attrs, attrs)

    def port_map_attrs(self):
        return [
            'external_port_id',
            'internal_layer_id'
        ]

    @staticmethod
    def connect_body_input(if_node: Node, condition: bool, if_input_port_idx: int, body_parameter: Node):
        """
        Update the specified body parameter and connect it with the If input.

        :param if_node: the If node
        :param condition: the boolean selecting the then (True) or else (False) graph to which the body is connected
        :param if_input_port_idx: the input port index to connect
        :param body_parameter: the body parameter node to connect
        :return: None
        """
        assert if_node.soft_get('op') == 'If'
        assert body_parameter.soft_get('op') == 'Parameter'
        sub_graph = if_node.then_graph if condition else if_node.else_graph
        assert body_parameter.id in sub_graph
        body_parameter['input_id'] = if_input_port_idx

    @staticmethod
    def connect_body_output(if_node: Node, condition: bool, if_output_port_idx: int, internal_result: Node):
        """
        Update the specified output port and connect it with the If output.

        :param if_node: the If node
        :param condition: the boolean selecting the then (True) or else (False) graph to which the body is connected
        :param if_output_port_idx: the output port index to connect
        :param internal_result: the body Result node to connect
        :return: None
        """
        assert if_node.soft_get('op') == 'If'
        assert internal_result.soft_get('op') == 'Result'
        sub_graph = if_node.then_graph if condition else if_node.else_graph
        assert internal_result.id in sub_graph
        internal_result['output_id'] = if_output_port_idx

    @staticmethod
    def update_body_parameters_type(if_node: Node, condition: bool):
        """
        Update the data type for If body Parameter nodes based on the data type of the outer graph nodes
        producing data for them.

        :param if_node: the If node
        :param condition: the boolean selecting the then (True) or else (False) graph
        :return: None
        """
        assert if_node.soft_get('type') == 'If'

        subgraph = if_node.then_graph if condition else if_node.else_graph
        for node in subgraph.get_op_nodes():
            if node.has('input_id'):
                assert node.soft_get('type') == 'Parameter'
                input_port_id = node['input_id']
                input_type = if_node.in_port(input_port_id).get_data_type()
                node.data_type = input_type
                log.debug('Updated data type for the body node with name "{}" with value {}'
                          .format(node.name, node.data_type))

    @staticmethod
    def update_body_parameters_shape(if_node: Node, condition: bool):
        """
        Update shapes for If body parameters.

        :param if_node: the If node
        :param condition: the boolean selecting the then (True) or else (False) graph
        :return: None
        """
        subgraph = if_node.then_graph if condition else if_node.else_graph
        for node in subgraph.get_op_nodes():
            if node.has('input_id'):
                assert node.soft_get('type') == 'Parameter'
                input_port_id = node['input_id']
                input_shape = if_node.in_port(input_port_id).data.get_shape()
                if node.soft_get('shape', None) is None:
                    node['shape'] = None
                node.shape = input_shape.copy()
                log.debug('Updated shape for the body node with name "{}" with value {}'
                          .format(node.soft_get('name', node.soft_get('id')), node.shape))

    @staticmethod
    def results_mapping_and_finding_fake_outputs(output_nodes_in_subgraph, branch_name, outputs_mapping):
        """
        This method checks the Result nodes of a subgraph and fills the mapping between If operation outputs and
        internal subgraph results. It also returns True if the internal graph has only fake (empty tensor) results.

        :param output_nodes_in_subgraph: Result nodes with the 'output_id' attribute
        :param branch_name: name of the subgraph ('then_graph' or 'else_graph')
        :param outputs_mapping: map between If operation output IDs and subgraph results

        :return: True if all results of the subgraph are empty tensors
        """
        graph_contain_fake_outputs = True

        for output_node in output_nodes_in_subgraph:
            assert output_node.soft_get('type') == 'Result'
            port_id = output_node['output_id']
            assert port_id in outputs_mapping.keys(), 'Incorrect mapping of output "{0}"! ' \
                                                      'Can\'t find port with ID {1} in the If operation.' \
                .format(output_node.name, port_id)
            outputs_mapping[port_id][branch_name] = output_node
            out_node_shape = output_node.in_port(0).data.get_shape()
            graph_contain_fake_outputs = graph_contain_fake_outputs and np.any(out_node_shape == 0)
        return graph_contain_fake_outputs

    @staticmethod
    def update_if_output_ports_shape(if_node: Node):
        """
        Update shapes and values for If output ports.

        :param if_node: the If node to update output ports and shapes for
        :return: None
        """
        then_outputs = [node for node in if_node.then_graph.get_op_nodes() if node.has('output_id')]
        else_outputs = [node for node in if_node.else_graph.get_op_nodes() if node.has('output_id')]
        outputs_mapping = {}
        outputs_number = len(if_node.out_ports())

        if outputs_number == 0 and len(if_node.out_ports(control_flow=True)) != 0:
            # Some models have an If with control flow outputs.
            # This is shape inference for such Ifs.
            # TODO: need to rethink and redo support for control flow edges in the If operation
            for node in if_node.out_nodes(control_flow=True).values():
                node.shape = int64_array([])
            return

        for port_id in if_node.out_ports().keys():
            outputs_mapping[port_id] = {}

        # then_contains_fake_outputs/else_contains_fake_outputs are True
        # if all outputs from then_body/else_body have shape [0]. It means then_body/else_body does not return data
        # and further shape inference for this branch is not possible.
        # TODO: exclude support of fake_outputs from this code when shape inference with empty tensors is supported

        then_contains_fake_outputs = \
            If.results_mapping_and_finding_fake_outputs(then_outputs, 'then_graph', outputs_mapping)
        else_contains_fake_outputs = \
            If.results_mapping_and_finding_fake_outputs(else_outputs, 'else_graph', outputs_mapping)

        # use_then_shape is True when the else_body, or both bodies, do not return data. If use_then_shape is True,
        # If's outputs will have the same shapes as the then_body results
        use_then_shape = else_contains_fake_outputs or not then_contains_fake_outputs

        for port_id in outputs_mapping:
            then_else_nodes = outputs_mapping[port_id]
            assert 'then_graph' in then_else_nodes.keys(), 'then_graph does not connect with If.out_port[{0}] ' \
                                                           'in {1} node!'.format(port_id, if_node.name)
            assert 'else_graph' in then_else_nodes.keys(), 'else_graph does not connect with If.out_port[{0}] ' \
                                                           'in {1} node!'.format(port_id, if_node.name)

            then_shape = then_else_nodes['then_graph'].in_port(0).data.get_shape()
            else_shape = then_else_nodes['else_graph'].in_port(0).data.get_shape()

            if not (then_shape == else_shape).all():
                log.debug("If node {0} has dynamic output [{1}] because output shape from then_graph is {2} and "
                          "else_graph {3}".format(if_node.name, port_id, then_shape, else_shape))
            if_node.out_port(port_id).data.set_shape(then_shape if use_then_shape else else_shape)

    @staticmethod
    def update_if_output_ports_type(if_node: Node):
        """
        Update types for If output ports.

        :param if_node: the If node to update output ports and types for
        :return: None
        """
        then_outputs = [node for node in if_node.then_graph.get_op_nodes() if node.has('output_id')]
        else_outputs = [node for node in if_node.else_graph.get_op_nodes() if node.has('output_id')]
        outputs_mapping = {}
        outputs_number = len(if_node.out_ports())
        assert outputs_number == len(then_outputs), 'Incorrect number of outputs in then_graph of If with ' \
                                                    'name {0}! then_graph must have {1} outputs' \
            .format(if_node.name, outputs_number)
        assert outputs_number == len(else_outputs), 'Incorrect number of outputs in else_graph of If with ' \
                                                    'name {0}! else_graph must have {1} outputs' \
            .format(if_node.name, outputs_number)
        for port_id in if_node.out_ports().keys():
            outputs_mapping[port_id] = {}
        port_ids = outputs_mapping.keys()
        for then_output_node in then_outputs:
            assert then_output_node.soft_get('type') == 'Result'
            port_id = then_output_node['output_id']
            assert port_id in port_ids, 'Incorrect mapping of then_graph output "{0}"! ' \
                                        'Can\'t find port with ID {1} in the If operation.' \
                .format(then_output_node.name, port_id)
            outputs_mapping[port_id]['then_graph'] = then_output_node

        for else_output_node in else_outputs:
            assert else_output_node.soft_get('type') == 'Result'
            port_id = else_output_node['output_id']
            assert port_id in port_ids, 'Incorrect mapping of else_graph output "{0}"! ' \
                                        'Can\'t find port with ID {1} in the If operation.' \
                .format(else_output_node.name, port_id)
            outputs_mapping[port_id]['else_graph'] = else_output_node

        for port_id in outputs_mapping:
            then_else_nodes = outputs_mapping[port_id]
            assert 'then_graph' in then_else_nodes.keys(), 'then_graph does not connect with If.out_port[{0}] ' \
                                                           'in {1} node!'.format(port_id, if_node.name)
            assert 'else_graph' in then_else_nodes.keys(), 'else_graph does not connect with If.out_port[{0}] ' \
                                                           'in {1} node!'.format(port_id, if_node.name)
            then_type = then_else_nodes['then_graph'].in_port(0).get_data_type()
            else_type = then_else_nodes['else_graph'].in_port(0).get_data_type()
            assert then_type == else_type, 'Cannot get type for If.out_port[{0}]! ' \
                                           'Types in then_graph and else_graph are not equal!'.format(port_id)
            if_node.out_port(port_id).set_data_type(then_type)

    @staticmethod
    def re_numerate_internal_id_and_get_if_id(if_node):
        """
        This method is called before IR generation; it sets internal_layer_id in both bodies.

        :param if_node: the If node whose bodies need internal_layer_id set
        :return: if_node
        """
        then_graph_nodes = if_node.then_graph.nodes()
        for idx in range(len(if_node.then_graph.get_op_nodes())):
            then_graph_nodes[idx]['internal_layer_id'] = idx
        else_graph_nodes = if_node.else_graph.nodes()
        for idx in range(len(if_node.else_graph.get_op_nodes())):
            else_graph_nodes[idx]['internal_layer_id'] = idx
        return if_node.node

    def substitute_ie_attrs(self, new_attrs: dict):
        """
        Replace the standard list of attributes in layer/data with the attributes
        delivered by backend_attrs.
        """

        port_map_attrs = self.port_map_attrs()
        new_attrs.update({
            'IE': [(
                'layer',
                [('id', lambda node: self.re_numerate_internal_id_and_get_if_id(node)), 'name', 'type', 'version'],
                [
                    '@ports',
                    ('then_port_map', [], [
                        ('@list', lambda node: self.generate_port_map(node, True, 'in'),
                         ('input', port_map_attrs, [])),
                        ('@list', lambda node: self.generate_port_map(node, True, 'out'),
                         ('output', port_map_attrs, [])),
                    ]),
                    ('else_port_map', [], [
                        ('@list', lambda node: self.generate_port_map(node, False, 'in'),
                         ('input', port_map_attrs, [])),
                        ('@list', lambda node: self.generate_port_map(node, False, 'out'),
                         ('output', port_map_attrs, [])),
                    ]),
                    ('then_body', [], [('@network', 'then_graph')]),
                    ('else_body', [], [('@network', 'else_graph')]),
                ])]
        })

    @staticmethod
    def generate_port_map(if_node: Node, condition: bool, dir: str):
        """
        Extract port_map attributes from if_node and its subgraph attributes.

        :param if_node: the If node
        :param condition: the boolean selecting the then (True) or else (False) graph
        :param dir: the str value defining the port_map type ('in' for inputs, 'out' for outputs)
        :return: port_map - a list of dictionaries with two values (external_port_id and internal_layer_id)
        """
        port_map = []
        subgraph = if_node.then_graph if condition else if_node.else_graph
        name_of_connection = 'input_id' if dir == 'in' else 'output_id'

        for internal_node in subgraph.get_op_nodes():
            if internal_node.has(name_of_connection):
                port_map.append({'external_port_id': internal_node[name_of_connection],
                                 'internal_layer_id': internal_node['internal_layer_id']})

        return port_map

    @staticmethod
    def infer(if_node: Node):
        If.update_body_parameters_shape(if_node, True)
        If.update_body_parameters_shape(if_node, False)
        partial_infer(if_node.then_graph)
        partial_infer(if_node.else_graph)
        If.update_if_output_ports_shape(if_node)

    @staticmethod
    def type_infer(if_node: Node):
        from mo.middle.passes.infer import type_infer
        If.update_body_parameters_type(if_node, True)
        If.update_body_parameters_type(if_node, False)
        type_infer(if_node.then_graph)
        type_infer(if_node.else_graph)
        If.update_if_output_ports_type(if_node)
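
A small illustration (mine, not from the commit) of the empty-tensor rule implemented by results_mapping_and_finding_fake_outputs and update_if_output_ports_shape: a branch counts as "fake" when every one of its Results receives a tensor with a zero dimension, and the If output shapes are then taken from the other branch.

    import numpy as np

    # Illustrative sketch of the shape-selection rule; the shapes below are invented
    # stand-ins for the then/else Result input shapes of one If output port.
    def branch_is_fake(result_shapes):
        # mirrors results_mapping_and_finding_fake_outputs: every result has a zero dim
        return all(np.any(np.array(shape) == 0) for shape in result_shapes)

    then_shapes = [np.array([0])]        # then branch returns an empty tensor
    else_shapes = [np.array([2, 3])]     # else branch returns real data

    use_then_shape = branch_is_fake(else_shapes) or not branch_is_fake(then_shapes)
    print(use_then_shape)  # False -> the If output takes the else branch shape [2, 3]
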

model-optimizer/mo/front/tf/extractors/subgraph_utils.py (new file, 108 lines)
@@ -0,0 +1,108 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import copy

from extensions.ops.parameter import Parameter
from mo.front.tf.extractor import tf_op_extractor, tf_op_extractors, create_tf_edge
from mo.front.tf.extractors.utils import tf_dtype_extractor
from mo.graph.graph import Graph, Node, add_opoutput
from mo.ops.op import PermuteAttrs


def update_body_graph(body_graph: Graph, subgraph_proto: dict,
                      body_parameter_names: list, body_results: list):
    """
    Updates the loop body graph with a sub-graph (for body or condition functions)
    :param body_graph: a loop body graph to be updated
    :param subgraph_proto: a sub-graph in a protobuf format to be added into the loop body graph
    :param body_parameter_names: a (unchanged) list of parameters in the loop body graph
    :param body_results: a list of Result nodes that is extended with a list from a sub-graph
    """
    # create a map from a node name in original model to a name in a loop body graph assuming
    # that names in the original model are unique
    # initially, the map contains names for parameters that are common for the body and condition graphs
    map_original_name = {}
    for idx, pb_node in enumerate(subgraph_proto['input_arg']):
        map_original_name[pb_node.name] = body_parameter_names[idx]

    # walk through all nodes (non-parameter and non-result nodes) and add into the loop body graph
    for pb_node in subgraph_proto['node_def']:
        # create an NX node
        id = body_graph.unique_id(pb_node.name)
        map_original_name[pb_node.name] = id
        body_graph.add_node(id, pb=pb_node, kind='op')
        if hasattr(body_graph, 'op_names_statistic') and hasattr(pb_node, 'op'):
            body_graph.op_names_statistic[pb_node.op] += 1

        # add incoming edges based on data_nodes_map
        for dst_port, inp in enumerate(pb_node.input):
            orig_src_id = inp.split(":")[0]

            # TODO: avoid this temporal workaround for TF 2.4 or higher RNN layers:
            # skip control flow dependency
            if orig_src_id[0] == '^':
                continue

            src_id = map_original_name[orig_src_id]
            src_port = 0 if len(inp.split(":")) == 1 else int(inp.split(":")[-1])
            assert (body_graph.has_node(src_id))

            body_graph.add_edges_from([create_tf_edge(src_id + ":" + str(src_port), id, dst_port)])

    # create Result nodes in the loop body graph
    for output in subgraph_proto['output_arg']:
        output_name = subgraph_proto['ret'][output.name]
        orig_src_id = output_name.split(":")[0]
        src_id = map_original_name[orig_src_id]
        src_port = 0 if len(output_name.split(":")) == 1 \
            else int(output_name.split(":")[-1])
        assert body_graph.has_node(src_id), 'The body graph does not contain output with name "{}"'.format(
            src_id)
        body_results.append(Node(body_graph, add_opoutput(body_graph, src_id, src_port, False)))

    return True


def get_graph_proto(external_graph: Graph, graph_id: str, node_with_graph: Node):
    graph_name = node_with_graph.pb.attr[graph_id].func.name
    node_name = node_with_graph.soft_get('name', node_with_graph.id)

    assert 'library' in external_graph.graph, 'The graph does not contain a library that is required ' \
                                              'by node with name "{}".'.format(node_name)

    library_graph = external_graph.graph['library']

    assert graph_name in library_graph, 'The library does not contain a function with name "{}" ' \
                                        'that is required by node ' \
                                        'with name "{}".'.format(graph_name, node_name)
    return library_graph[graph_name]


def create_internal_graph(external_graph: Graph):
    internal_graph = Graph()
    # fill the body graph
    for attr_key in external_graph.graph.keys():
        if attr_key != 'library':
            internal_graph.graph[attr_key] = copy.deepcopy(external_graph.graph[attr_key])
        else:
            # it is sufficient to have a link to the library
            internal_graph.graph['library'] = external_graph.graph['library']
    return internal_graph


def convert_graph_inputs_to_parameters(internal_graph, internal_graph_proto):
    # create Parameter nodes for the body graph
    body_parameters = []
    body_parameter_names = []
    for idx, pb_node in enumerate(internal_graph_proto['input_arg']):
        param_id = internal_graph.unique_id(pb_node.name)
        internal_graph.add_node(param_id, name=param_id, kind='op', op='Parameter', pb=None, shape=None)
        parameter_node = Node(internal_graph, pb_node.name)
        Parameter.update_node_stat(parameter_node,
                                   {'data_type': tf_dtype_extractor(pb_node.type),
                                    'permute_attrs': PermuteAttrs().update_attrs(attrs=[('shape', 'output:0')])}
                                   )
        body_parameters.append(parameter_node)
        body_parameter_names.append(param_id)
    return body_parameters, body_parameter_names
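
For reference, the name:port splitting used in update_body_graph follows the TF convention that a tensor reference without an explicit index means output port 0; a quick illustrative sketch (not part of the commit):

    # Illustrative: how update_body_graph splits TF tensor references into node id and port.
    def split_tensor_ref(ref: str):
        parts = ref.split(":")
        src_id = parts[0]
        src_port = 0 if len(parts) == 1 else int(parts[-1])
        return src_id, src_port

    print(split_tensor_ref("add"))      # ('add', 0)
    print(split_tensor_ref("while:2"))  # ('while', 2)
    print(split_tensor_ref("^ctrl"))    # ('^ctrl', 0) -- control inputs are skipped upstream
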
@@ -194,7 +194,8 @@ def prepare_emit_ir(graph: Graph, data_type: str, output_dir: str, output_model_
     # do not run the type inference in sub-graphs. It will be called automatically as part of the type inference of
     # the TensorIterator nodes
     type_infer(graph)
-    RemoveUselessConvert().find_and_replace_pattern(graph)
+    for_graph_and_each_sub_graph_recursively(graph, RemoveUselessConvert().find_and_replace_pattern)
+
     ResultRename().find_and_replace_pattern(graph)
 
@@ -240,34 +240,8 @@ class IREngine(object):
                 xml_body_child = list(layer.iterfind('body'))
                 assert len(xml_body_child) == 1
 
-                body_ir = IREngine(path_to_xml=None,
-                                   path_to_bin=self.path_to_bin,
-                                   xml_tree=ElementTree(xml_body_child[0]))
-                self.graph.graph['hashes'].update(body_ir.graph.graph['hashes'])
-
-                # Find port_map section and take an input_port_map & output_port_map
-                xml_port_map = list(layer.iterfind('port_map'))
-                if not len(xml_port_map) == 1:
-                    log.warning("TensorIterator body won\'t be compared due to missing port_map section!")
-                    continue
-                xml_port_map = xml_port_map[0]
-
-                input_layers = []
-                input_port_map = []
-                output_port_map = []
-
-                for port in xml_port_map:
-                    if port.tag == 'input':
-                        if 'internal_layer_id' not in port.attrib:
-                            log.warning("internal_layer_id attrib not found in input section")
-                        else:
-                            input_layers.append(Node(body_ir.graph, port.attrib['internal_layer_id']))
-                            input_port_map.append(self.__normalize_attrs(port.attrib))
-                    elif port.tag == 'output':
-                        if 'internal_layer_id' not in port.attrib:
-                            log.warning("internal_layer_id attrib not found in output section")
-                        else:
-                            output_port_map.append(self.__normalize_attrs(port.attrib))
-
+                body_ir, input_port_map, output_port_map, input_layers = \
+                    self.__read_subgraph(layer, layer_attrs, xml_body_child, 'port_map')
                 body_ir.input_node = input_layers[0]
                 layer_attrs.update({'body': body_ir})
@@ -287,6 +261,12 @@ class IREngine(object):
 
                 layer_attrs.update({'back_edges': back_edges})
 
+            elif attr.tag == 'then_body' or attr.tag == 'else_body':
+                assert layer.attrib['type'] == 'If', "Incorrect IR! The operation {0}" \
+                                                     " has sub-graphs for If operation"
+                layer_attrs = self.__read_if(layer, layer_attrs)
+                continue
+
         return layer_id, layer_attrs
 
     @staticmethod
@@ -405,3 +385,54 @@ class IREngine(object):
         if not isinstance(other, IREngine):
             raise AttributeError("IREngine can be compared only with IREngine object type")
         return self.compare(other)[0]
+
+    def __read_subgraph(self, layer, layer_attrs, body_child, port_map_name):
+        body_ir = IREngine(path_to_xml=None,
+                           path_to_bin=self.path_to_bin,
+                           xml_tree=ElementTree(body_child[0]))
+
+        self.graph.graph['hashes'].update(body_ir.graph.graph['hashes'])
+
+        xml_port_map = list(layer.iterfind(port_map_name))
+        assert not len(xml_port_map) != 1, "If then_body won\'t be compared due to missing {1} section in node {0}! " \
+            .format(layer_attrs['name'], port_map_name)
+        xml_port_map = xml_port_map[0]
+
+        input_layers = []
+        input_port_map = []
+        output_port_map = []
+
+        for port in xml_port_map:
+            if port.tag == 'input':
+                if 'internal_layer_id' not in port.attrib:
+                    log.warning("internal_layer_id attrib not found in input section")
+                else:
+                    input_layers.append(Node(body_ir.graph, port.attrib['internal_layer_id']))
+                    input_port_map.append(self.__normalize_attrs(port.attrib))
+            elif port.tag == 'output':
+                if 'internal_layer_id' not in port.attrib:
+                    log.warning("internal_layer_id attrib not found in output section")
+                else:
+                    output_port_map.append(self.__normalize_attrs(port.attrib))
+
+        return body_ir, input_port_map, output_port_map, input_layers
+
+    def __read_if(self, layer, layer_attrs):
+
+        xml_then_body_child = list(layer.iterfind('then_body'))
+        xml_else_body_child = list(layer.iterfind('else_body'))
+        assert len(xml_then_body_child) == 1 and len(xml_else_body_child) == 1, "If operation has only one subgraph"
+
+        then_body_ir, then_input_port_map, then_output_port_map, _ = \
+            self.__read_subgraph(layer, layer_attrs, xml_then_body_child, 'then_port_map')
+        layer_attrs.update({'then_graph': then_body_ir})
+        layer_attrs.update({'then_input_port_map': then_input_port_map})
+        layer_attrs.update({'then_output_port_map': then_output_port_map})
+
+        else_body_ir, else_input_port_map, else_output_port_map, _ = \
+            self.__read_subgraph(layer, layer_attrs, xml_else_body_child, 'else_port_map')
+        layer_attrs.update({'else_graph': else_body_ir})
+        layer_attrs.update({'else_input_port_map': else_input_port_map})
+        layer_attrs.update({'else_output_port_map': else_output_port_map})
+
+        return layer_attrs
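
A rough sketch (mine; the element and attribute names are taken from the reader code above, the values are invented) of the If layer layout __read_if expects in the IR XML:

    import xml.etree.ElementTree as ElementTree

    # Hand-written sketch of an If layer: then_body/else_body hold the nested sub-graph
    # networks, then_port_map/else_port_map map If ports to internal layer ids.
    layer_xml = """
    <layer id="3" name="if" type="If" version="opset8">
        <then_port_map>
            <input external_port_id="1" internal_layer_id="0"/>
            <output external_port_id="0" internal_layer_id="4"/>
        </then_port_map>
        <else_port_map>
            <input external_port_id="1" internal_layer_id="0"/>
            <output external_port_id="0" internal_layer_id="2"/>
        </else_port_map>
        <then_body><!-- nested IR network --></then_body>
        <else_body><!-- nested IR network --></else_body>
    </layer>
    """
    layer = ElementTree.fromstring(layer_xml)
    print([child.tag for child in layer])  # then_port_map, else_port_map, then_body, else_body
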

model-optimizer/mo/utils/ir_reader/extenders/if_extender.py (new file, 41 lines)
@@ -0,0 +1,41 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from mo.utils.graph import Node
from mo.utils.ir_reader.extender import Extender
from mo.utils.ir_reader.layer_to_class import copy_graph_with_ops


class IfExtender(Extender):
    op = 'If'

    @staticmethod
    def set_input_output_id(subgraph, input_port_map, output_port_map):
        for node in subgraph.get_op_nodes():
            if not node.has_valid('id'):
                continue
            node_id = int(node.soft_get('id'))
            for if_input_mapping_elem in input_port_map:
                if node_id == if_input_mapping_elem['internal_layer_id']:
                    node['input_id'] = if_input_mapping_elem['external_port_id']
            for if_out_mapping_elem in output_port_map:
                if node_id == if_out_mapping_elem['internal_layer_id']:
                    node['output_id'] = if_out_mapping_elem['external_port_id']

    @staticmethod
    def extend(op: Node):
        assert op.has('then_graph'), 'There is no "then_body" attribute in the If op {}.'.format(op.name)
        assert op.has('else_graph'), 'There is no "else_body" attribute in the If op {}.'.format(op.name)
        # At this point op.then_graph/op.else_graph are IREngine objects; replace them with IREngine.graph
        op.then_graph.graph.graph['cmd_params'] = op.graph.graph['cmd_params']
        op.then_graph.graph.graph['ir_version'] = op.graph.graph['ir_version']
        op.then_graph.graph.name = op.name + '/then_body'

        op.else_graph.graph.graph['cmd_params'] = op.graph.graph['cmd_params']
        op.else_graph.graph.graph['ir_version'] = op.graph.graph['ir_version']
        op.else_graph.graph.name = op.name + '/else_body'

        op.then_graph = copy_graph_with_ops(op.then_graph.graph)
        op.else_graph = copy_graph_with_ops(op.else_graph.graph)

        IfExtender.set_input_output_id(op.then_graph, op.then_input_port_map, op.then_output_port_map)
        IfExtender.set_input_output_id(op.else_graph, op.else_input_port_map, op.else_output_port_map)
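
For reference, a sketch (values invented) of the port-map records that If.generate_port_map emits on the writer side and set_input_output_id consumes here:

    # Hypothetical port_map entries, matching the shape produced by If.generate_port_map:
    then_input_port_map = [
        {'external_port_id': 1, 'internal_layer_id': 0},  # If input port 1 -> body Parameter with id 0
        {'external_port_id': 2, 'internal_layer_id': 1},
    ]
    then_output_port_map = [
        {'external_port_id': 0, 'internal_layer_id': 5},  # If output port 0 <- body Result with id 5
    ]
    # set_input_output_id walks the sub-graph and stores these ids back on the
    # Parameter/Result nodes as 'input_id' / 'output_id'.
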
@@ -44,7 +44,7 @@ class TFLoaderTest(unittest.TestCase):
         # create fake Loop operation
         nodes = {
             **regular_op('input', {'op': 'Parameter'}),
-            **regular_op('loop', {'op': 'Loop', 'body': body_graph}),
+            **regular_op('loop', {'op': 'Loop', 'body': body_graph, 'sub_graphs': ['body']}),
             **result('result'),
         }
         edges = [*connect_front('input', '0:loop'),
@@ -65,4 +65,3 @@ class TFLoaderTest(unittest.TestCase):
 
     def test_no_convolution_main_and_sub_graph(self):
         self.assertFalse(graph_or_sub_graph_has_nhwc_ops(self.build_loop_graph(self.build_parameter_result_graph())))
-

model-optimizer/unit_tests/extensions/ops/If_test.py (new file, 127 lines)
@@ -0,0 +1,127 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import numpy.testing as npt
import unittest

from extensions.ops.If import If
from mo.ops.shape import Shape
from extensions.ops.elementwise import Add, Mul
from extensions.ops.identity import Identity
from extensions.ops.parameter import Parameter
from mo.front.common.partial_infer.concat import concat_infer
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Node, Graph
from mo.middle.passes.infer import partial_infer
from mo.ops.concat import Concat
from mo.ops.eltwise import eltwise_infer
from mo.ops.result import Result
from unit_tests.utils.graph import build_graph_with_edge_attrs, build_graph
from unit_tests.utils.graph import regular_op_with_empty_data, connect, const, result, valued_const_with_data, \
    regular_op, empty_data


class TestIf(unittest.TestCase):
    def test_simple_shape_inf(self):
        then_graph_nodes = {**regular_op_with_empty_data('param_1', {'type': 'Parameter', 'kind': 'op', 'input_id': 1,
                                                                     'shape': None, 'infer': Parameter.infer}),
                            **regular_op_with_empty_data('param_2', {'type': 'Parameter', 'kind': 'op', 'input_id': 2,
                                                                     'shape': None, 'infer': Parameter.infer}),
                            **regular_op_with_empty_data('add', {'type': 'Add', 'kind': 'op', 'op': 'Add',
                                                                 'infer': lambda node: eltwise_infer(node,
                                                                                                     Add.operation)}),
                            **regular_op_with_empty_data('mul', {'type': 'Mul', 'kind': 'op', 'op': 'Mul',
                                                                 'infer': lambda node: eltwise_infer(node,
                                                                                                     Mul.operation)}),
                            **regular_op_with_empty_data('res1', {'kind': 'op', 'type': 'Result', 'op': 'Result',
                                                                  'infer': lambda x: 0, 'output_id': 0}),
                            **regular_op_with_empty_data('res2', {'kind': 'op', 'type': 'Result', 'op': 'Result',
                                                                  'infer': lambda x: 0, 'output_id': 1})}
        then_graph_edges = [*connect('param_1', '0:add'),
                            *connect('param_2', '1:add'),
                            *connect('param_1', '1:mul'),
                            *connect('param_2', '0:mul'),
                            *connect('add', 'res1'),
                            *connect('mul', 'res2'),
                            ]

        else_graph_nodes = {**regular_op_with_empty_data('param_1', {'type': 'Parameter', 'kind': 'op', 'input_id': 1,
                                                                     'shape': None, 'infer': Parameter.infer}),
                            **regular_op_with_empty_data('param_2', {'type': 'Parameter', 'kind': 'op', 'input_id': 3,
                                                                     'shape': None, 'infer': Parameter.infer}),
                            **regular_op_with_empty_data('identity',
                                                         {'kind': 'op', 'op': 'Identity', 'infer': Identity.infer}),
                            **regular_op_with_empty_data('identity_1',
                                                         {'kind': 'op', 'op': 'Identity', 'infer': Identity.infer}),
                            **regular_op_with_empty_data('res1', {'kind': 'op', 'type': 'Result', 'op': 'Result',
                                                                  'infer': lambda x: 0, 'output_id': 0}),
                            **regular_op_with_empty_data('res2', {'kind': 'op', 'type': 'Result', 'op': 'Result',
                                                                  'infer': lambda x: 0, 'output_id': 1})}
        else_graph_edges = [*connect('param_1', 'identity'),
                            *connect('param_2', 'identity_1'),
                            *connect('identity_1', 'res2'),
                            *connect('identity', 'res1'), ]
        then_graph = build_graph_with_edge_attrs(then_graph_nodes, then_graph_edges)
        else_graph = build_graph_with_edge_attrs(else_graph_nodes, else_graph_edges)
        external_graph_nodes = {
            **valued_const_with_data('cond', np.array([True], dtype=np.bool)),
            **valued_const_with_data('input_2', int64_array([3, 2, 1])),
            **valued_const_with_data('input_1', int64_array([1, 2, 3])),
            **valued_const_with_data('input_3', int64_array([8, 4])),
            **regular_op('if', {'kind': 'op', 'op': 'If', 'then_graph': then_graph,
                                'else_graph': else_graph, 'infer': If.infer}),
            **empty_data('if_d_1'),
            **empty_data('if_d_2'),
            **result('res_1'),
            **result('res_2')}
        external_graph_edges = [*connect('cond', '0:if'),
                                *connect('input_1', '1:if'),
                                *connect('input_2', '2:if'),
                                *connect('input_3', '3:if'),
                                ('if', 'if_d_1', {'out': 0}),
                                ('if', 'if_d_2', {'out': 1}),
                                ('if_d_1', 'res_1'),
                                ('if_d_2', 'res_2')]

        graph = build_graph(external_graph_nodes, external_graph_edges)
        graph.stage = 'middle'
        partial_infer(graph)
        res_1 = Node(graph, 'res_1')
        res_2 = Node(graph, 'res_2')
        npt.assert_array_equal(res_1.in_port(0).data.get_shape(), int64_array([3]))
        npt.assert_array_equal(res_2.in_port(0).data.get_shape(), int64_array([3]))

    def test_fake_results(self):
        then_graph_nodes = {**valued_const_with_data('fake_const', int64_array(0)),
                            **regular_op_with_empty_data('shapeof',
                                                         {'kind': 'op', 'type': 'ShapeOf', 'op': 'ShapeOf',
                                                          'infer': Shape.infer, 'output_type': np.int64}),
                            **regular_op_with_empty_data('res_1', {'kind': 'op', 'type': 'Result', 'op': 'Result',
                                                                   'infer': lambda x: 0, 'output_id': 0})}
        then_graph_edges = [*connect('fake_const', 'shapeof'),
                            *connect('shapeof', 'res_1'),
                            ]

        else_graph_nodes = {**regular_op_with_empty_data('param_1', {'type': 'Parameter', 'kind': 'op', 'input_id': 1,
                                                                     'shape': None, 'infer': Parameter.infer}),
                            **regular_op_with_empty_data('res_1', {'kind': 'op', 'type': 'Result', 'op': 'Result',
                                                                   'infer': lambda x: 0, 'output_id': 0})}
        else_graph_edges = [*connect('param_1', 'res_1')]
        then_graph = build_graph_with_edge_attrs(then_graph_nodes, then_graph_edges)
        else_graph = build_graph_with_edge_attrs(else_graph_nodes, else_graph_edges)
        external_graph_nodes = {
            **valued_const_with_data('cond', np.array([True], dtype=np.bool)),
            **valued_const_with_data('input_1', int64_array([[1, 2, 3], [3, 2, 3]])),
            **regular_op_with_empty_data('if', {'kind': 'op', 'op': 'If', 'then_graph': then_graph,
                                                'else_graph': else_graph, 'infer': If.infer}),
            **result('res_1')}
        external_graph_edges = [*connect('cond', '0:if'),
                                *connect('input_1', '1:if'),
                                *connect('if', 'res_1')]

        graph = build_graph(external_graph_nodes, external_graph_edges)
        graph.stage = 'middle'
        partial_infer(graph)
        res_1 = Node(graph, 'res_1')
        npt.assert_array_equal(res_1.in_port(0).data.get_shape(), int64_array([2, 3]))