Files
openvino/model-optimizer/extensions/middle/TensorIteratorMerge.py
Evgeny Lazarev cabf8d8534 ONNX Loop operation support (#2756)
* Generate TensorIterator without back edges from TensorFlow models

* Added a check in the MarkSubgraphsWithCorrectLayout to not fail when port is not connected

* Updated the 'protobuf2nx' to consume the graph protobuf message

* Cleanup TI from the IRv7 specific code

* Do not run some front transformations recursively

* Draft support for the ONNX Loop operation when 'cond' = True

* LoopToTI transformation changes

* Added draft of Loop operation and parser for ONNX Loop operation body

* Updated Loop body parser + added shape and type infer for the Loop operation

* Fixes for ONNX Loop operation parser

* Moved Loop parsing to Loop op extractor. Added generation of external edges for the Loop body ops

* Added support for ThresholdedRelu using decomposition

* Added support for Min ONNX operation

* Draft fixes for port_map generation for the Loop

* Rename transformation file and fix BOM

* Fixed shape inference for Loop scan outputs (axis is not None)

* Fixed shape inference for ONNX Loop operation

* Refactor checks in the TensorIteratorMerge transformation

* Code refactoring. Enabled commented transformations

* Documentation update for ONNX Loop, ThresholdedRelu and Min

* Fixed typo in the Loop front transformation where execution condition input is connected. Other refactorings

* Fixes in the Loop extractor

* Added printing 'internal_layer_id' attribute in the graph dumper

* Updated calculation of iterations number for the Loop

* Added missing code

* Fixed output port shapes and types generation for Loop operation

* Update function names and variable names in the Loop operation

* Fixed type inference for iteration count input

* Added removal of input/output ports of the Loop if they are not used

* Fixed renumbering Loop operations input/output ports to keep mandatory

* Fixed ThresholdedReluDecomposition transformation

* Updated MO IR Reader to know about Loop operation. But it is still not supported by the MO IR Reader

* Added unit test for Slice op shape infer (reverse the sequence of elements)

* Reverted changes in the ONNX loader function call to protobuf2nx

* Enable Reshape0DToSqueeze transformation recursively

* Refactored Loop operation support implementation

* Changed ThresholdedReluDecomposition to generate Const with shape [1] instead of scalar

* Code style and wording fixes

* Restored accidentally removed 'return' statement in the TI shape infer function

* Fixed comments

* Fixed comment

Co-authored-by: Evgeny Lazarev <elazarev.nnov@gmail.com>
2020-10-27 23:04:43 +03:00

383 lines
17 KiB
Python

"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from collections import deque
from copy import deepcopy
import numpy as np
from extensions.ops.tensor_iterator import TensorIterator
from mo.graph.graph import Node, Graph, add_opoutput
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.const import Const
from mo.ops.op import Op
from mo.ops.squeeze import Squeeze
from mo.ops.unsqueeze import Unsqueeze
from mo.utils.graph import sub_graph_between_nodes, invert_sub_graph_between_nodes
# 'op' attribute values of the four TensorIterator* helper operations this module deals with
# (the condition is matched by TensorIteratorMerge.pattern; inputs, outputs and back edges are
# collected in TensorIteratorMerge.replace_pattern).
# NOTE(review): no function in this module reads this list directly -- `reverse_dfs` and `dfs` take their
# own `stop_nodes` argument that shadows it; presumably it is imported by other transformations -- confirm.
stop_nodes = ['TensorIteratorInput', 'TensorIteratorOutput', 'TensorIteratorBackEdge', 'TensorIteratorCondition']
def op_type(graph, node_name: str):
    """
    Return the 'op' attribute of the node called `node_name`.

    :param graph: graph that contains the node
    :param node_name: name (id) of the node to inspect
    :return: value of the 'op' attribute if the node has a valid 'kind' equal to 'op', otherwise None
    """
    candidate = Node(graph, node_name)
    is_operation = candidate.has_valid('kind') and candidate['kind'] == 'op'
    return candidate['op'] if is_operation else None
def update_inputs(graph, inputs: list, node_name: str):
    """
    Register `node_name` in `inputs` if it is a 'TensorIteratorInput' operation that is not listed yet.

    :param graph: graph that contains the node
    :param inputs: list of node names; extended in place
    :param node_name: name (id) of the candidate node
    :return: None
    """
    candidate = Node(graph, node_name)
    # Guard clauses: anything that is not a TensorIteratorInput op node is ignored
    if not candidate.has_valid('kind'):
        return
    if candidate['kind'] != 'op' or candidate['op'] != 'TensorIteratorInput':
        return
    if node_name in inputs:
        return
    inputs.append(node_name)
def reverse_dfs(graph: Graph, node_name: str, stop_nodes: list, inputs: list, visited: set = None):
    """
    Walk the graph against edge direction starting from `node_name`.

    Producers whose op type is listed in `stop_nodes` are not traversed; they are handed to `update_inputs`, which
    records the 'TensorIteratorInput' ones in `inputs`. Nodes are taken from the front of the queue and new ones are
    appended to its back, i.e. the visiting order is first-in-first-out despite the function name.

    :param graph: graph to traverse
    :param node_name: name of the start node
    :param stop_nodes: op types at which the traversal stops
    :param inputs: list of input node names; extended in place
    :param visited: set of already visited node names; extended in place (a fresh set is used when None)
    :return: None
    """
    if visited is None:
        visited = set()
    visited.add(node_name)

    pending = deque([node_name])
    while pending:
        current = pending.popleft()
        for producer, _ in graph.in_edges(current):
            if producer in visited:
                continue
            if op_type(graph, producer) in stop_nodes:
                # Traversal boundary: do not mark as visited, only record it when it is a TI input
                update_inputs(graph, inputs, producer)
                continue
            visited.add(producer)
            pending.append(producer)
def dfs(graph: Graph, node_name: str, stop_nodes: list, visited: set = None):
    """
    Walk the graph along edge direction starting from `node_name` and collect the reached node names.

    Consumers whose op type is listed in `stop_nodes` are neither visited nor traversed. Nodes are taken from the
    front of the queue and new ones are appended to its back, i.e. the visiting order is first-in-first-out despite
    the function name.

    :param graph: graph to traverse
    :param node_name: name of the start node
    :param stop_nodes: op types at which the traversal stops
    :param visited: set of already visited node names; extended in place. When omitted a fresh set is used, so
                    callers that need the result must pass their own set
    :return: None
    """
    # The signature advertises `visited=None`, so create the set here instead of failing on `None.add`
    # (same handling as in reverse_dfs)
    if visited is None:
        visited = set()
    visited.add(node_name)

    d = deque([node_name])
    while d:
        cur_node = d.popleft()
        for _, out_node_name in graph.out_edges(cur_node):
            if out_node_name not in visited and op_type(graph, out_node_name) not in stop_nodes:
                visited.add(out_node_name)
                d.append(out_node_name)
def get_body(graph, inputs, outputs):
    """
    Collect the names of the nodes located between `inputs` and `outputs`.

    When `inputs` is empty the search is delegated to `invert_sub_graph_between_nodes` called with
    (outputs, inputs); otherwise to `sub_graph_between_nodes` called with (inputs, outputs). Both receive a predicate
    that is true for nodes whose 'op' is 'TensorIteratorInput'.

    :param graph: graph to search in
    :param inputs: list of input node names
    :param outputs: list of output node names
    :return: tuple (found nodes without `inputs`, `outputs` and the extra inputs; extra inputs reported by the helper)
    """
    def is_ti_input(node):
        return node.soft_get('op') == 'TensorIteratorInput'

    if inputs:
        found, extra_inputs = sub_graph_between_nodes(graph, inputs, outputs, is_ti_input)
    else:
        found, extra_inputs = invert_sub_graph_between_nodes(graph, outputs, inputs, is_ti_input)

    # Strip the boundary nodes one group at a time
    remaining = set(found)
    for boundary in (inputs, outputs, extra_inputs):
        remaining = remaining - set(boundary)
    return list(remaining), extra_inputs
class TensorIteratorMerge(MiddleReplacementPattern):
    """
    Middle transformation that replaces every 'TensorIteratorCondition' operation, together with the
    TensorIteratorInput/TensorIteratorOutput/TensorIteratorBackEdge operations attached to its outputs and the nodes
    found between them, with a single TensorIterator operation that owns the collected nodes as its body graph.
    """
    enabled = True
    # The transformation is applied only to graphs flagged as cyclic
    graph_condition = [lambda graph: graph.graph['is_cyclic']]

    def run_after(self):
        """Return the transformations this one must run after (none)."""
        return []

    def run_before(self):
        """Return the transformations this one must run before (none)."""
        return []

    @staticmethod
    def pattern():
        """Return the pattern to match: a single op node with op='TensorIteratorCondition'."""
        return dict(
            nodes=[
                ('condition', dict(kind='op', op='TensorIteratorCondition')),
            ],
            edges=[],
        )

    @staticmethod
    def replace_pattern(graph, match: dict):
        """
        Build a TensorIterator operation for the matched condition node.

        Steps performed:
        1. collect the input, output and back edge nodes connected to the condition outputs;
        2. remove the condition, its output data nodes and its input 0 from the graph;
        3. find the body nodes with `get_body` and move them to a new 'body' graph;
        4. assign 'internal_layer_id'/'internal_port_id'/'external_port_id' values and build the back edges,
           input port map and output port map;
        5. create the TensorIterator node and run the TensorIterator body normalization helpers on it.

        :param graph: graph to modify in place
        :param match: dictionary with the matched nodes; the 'condition' key is used
        :return: None
        """
        # Here we find all parts of TI: condition, inputs/outputs, back edges, body and create TensorIterator Op
        # and make all checks needed for TensorIterator work
        condition = match['condition']
        cond_data = condition.out_node(0) if not condition.out_port(0).disconnected() else None
        # Look up output port 1 only when it is really connected. The former check (number of connected outputs >= 1)
        # was also true when only port 0 was connected, which made the lookup of port 1 fail.
        # NOTE(review): relies on Node.out_nodes() of an op node being a dict keyed by output port index - confirm
        time_data = condition.out_node(1) if 1 in condition.out_nodes() else None
        name = condition.name

        back_edges = []
        inputs = []
        outputs = []

        if cond_data is not None:
            for node in cond_data.out_nodes():
                if node['kind'] == 'op' and node['op'] == 'TensorIteratorBackEdge':
                    back_edges.append(node.id)
                elif node['kind'] == 'op' and node['op'] == 'TensorIteratorInput':
                    inputs.append(node.id)
                elif node['kind'] == 'op' and node['op'] == 'TensorIteratorOutput':
                    outputs.append(node.id)

        if time_data is not None:
            for node in time_data.out_nodes():
                if node['kind'] == 'op' and node['op'] == 'TensorIteratorInput':
                    inputs.append(node.id)
                elif node['kind'] == 'op' and node['op'] == 'TensorIteratorOutput':
                    outputs.append(node.id)
                else:
                    # something goes wrong here
                    assert False

        tensor_sequence_length = condition.in_node(0)

        nodes_to_remove = [n.id for n in (condition, cond_data, time_data, tensor_sequence_length) if n is not None]
        graph.remove_nodes_from(nodes_to_remove)

        body_nodes, extra_inputs = get_body(graph, inputs, outputs)

        if cond_data is not None:
            # `body_nodes` holds node ids, so the id (not the Node object) must be subtracted
            body_nodes = list(set(body_nodes) - {cond_data.id})

        inputs += extra_inputs

        assert all([node in graph.nodes() for node in body_nodes])

        inputs = [Node(graph, node) for node in inputs]
        outputs = [Node(graph, node) for node in outputs]
        back_edges = [Node(graph, node) for node in back_edges]

        # For nodes with a valid 'axis' the data is taken from port 1, otherwise from port 0
        external_inputs = [
            {
                'external_data_id': node.in_node(1 if node.has_valid('axis') else 0),
                'internal_data_id': node.out_node(0),
                'axis': node.axis,
                'start': node.start,
                'end': node.end,
                'stride': node.stride,
                'part_size': node.part_size
            } for node in inputs]

        external_outputs = [
            {
                'external_data_id': node.out_node(0),
                'internal_data_id': node.in_node(1 if node.has_valid('axis') else 0),
                'axis': node.axis,
                'start': node.start,
                'end': node.end,
                'stride': node.stride,
                'part_size': node.part_size
            } for node in outputs]

        back_edges_data = [
            {
                'from_data_id': node.in_node(1),
                'to_data_id': node.out_node(0),
                'init_data_id': node.in_node(0),
            } for node in back_edges
        ]

        # Move the body nodes and the edges between them to a separate graph
        body = Graph(name='body')
        body.graph = graph.graph
        body.add_nodes_from([(node, graph.node[node]) for node in body_nodes])
        body.add_edges_from(
            [(u, v, k, d) for u, v, k, d in graph.edges(data=True, keys=True) if u in body_nodes and v in body_nodes])

        graph.remove_nodes_from(
            body_nodes + [match['condition'].id] + [inp.id for inp in inputs] + [out.id for out in outputs])

        # A single running counter hands out all layer ids, port ids and external port ids below, so the order of
        # the following statements determines the concrete values
        internal_id_count = 0
        real_back_edges = []
        for edge in back_edges_data:
            assert edge['from_data_id'].id in body.nodes()
            assert edge['to_data_id'].id in body.nodes()
            assert edge['init_data_id'].id in body.nodes()
            edge['from_data_id'] = Node(body, edge['from_data_id'].id)
            edge['to_data_id'] = Node(body, edge['to_data_id'].id)
            edge['init_data_id'] = Node(body, edge['init_data_id'].id)
            add_opoutput(body, edge['from_data_id'].id, 0, False)

            # Assign/reuse ids for the back-edge start; it comes from from_data_id
            assert len(edge['from_data_id'].in_nodes()) == 1
            # layer id
            if not edge['from_data_id'].in_node().has_valid('internal_layer_id'):
                edge['from_data_id'].in_node()['internal_layer_id'] = internal_id_count
                internal_id_count += 1
            edge['from_layer'] = edge['from_data_id'].in_node()['internal_layer_id']

            # port id
            if 'internal_port_id' not in edge['from_data_id'].in_edge():
                edge['from_data_id'].in_edge()['internal_port_id'] = internal_id_count
                internal_id_count += 1
            edge['from_port'] = edge['from_data_id'].in_edge()['internal_port_id']

            # Look at all consumers for a data that ends a back-edge
            # For each such consumer, there will be a separate back-edge (and input)
            current_real_back_edges = []
            for _, consumer, key, edge_attrs in body.out_edges(edge['to_data_id'].id, data=True, keys=True):

                real_edge = {}
                real_edge.update(edge)  # all real back_edges have the same back-edge start

                consumer = Node(body, consumer)

                if real_edge['to_data_id'].in_node().has_valid('internal_layer_id'):
                    # This case is treated as an error; the assignment below runs only when asserts are disabled
                    assert False
                    real_edge['to_data_id'].out_node()['internal_layer_id'] = \
                        real_edge['to_data_id'].in_node().internal_layer_id
                elif not consumer.has_valid('internal_layer_id'):
                    consumer['internal_layer_id'] = internal_id_count
                    internal_id_count += 1
                real_edge['to_layer'] = consumer['internal_layer_id']

                assert 'internal_port_id' not in edge_attrs
                assert len(real_edge['init_data_id'].out_edges()) == 1
                assert 'internal_port_id' not in real_edge['init_data_id'].out_edge()
                edge_attrs['internal_port_id'] = internal_id_count
                internal_id_count += 1
                real_edge['to_port'] = edge_attrs['internal_port_id']
                real_edge['consumer'] = consumer
                real_edge['consumer_key'] = key

                real_edge['attrs'] = deepcopy(edge_attrs)
                current_real_back_edges.append(real_edge)

            # connect initial data node with each consumer providing actual edge attributes
            body.add_edges_from([
                (
                    real_edge['init_data_id'].id,
                    real_edge['consumer'].id,
                    real_edge['consumer_key'],
                    real_edge['attrs'])
                for real_edge in current_real_back_edges])

            # The data node that ends the back edge and its producer are not needed in the body any more
            body.remove_nodes_from([edge['to_data_id'].id, edge['to_data_id'].in_node().id])
            real_back_edges += current_real_back_edges

        real_external_inputs = []

        for ext_inp in external_inputs:
            assert ext_inp['external_data_id'].id not in body.nodes()
            assert ext_inp['internal_data_id'].id in body.nodes()
            ext_inp['internal_data_id'] = Node(body, ext_inp['internal_data_id'].id)

            if ext_inp['axis'] is not None:
                # Insert squeezing resize at input port that has partitioning
                shape = ext_inp['internal_data_id'].shape.copy()
                assert not ext_inp['internal_data_id'].has_valid('value')
                new_input_data = Op._create_data_node(body, ext_inp['internal_data_id'].name + '/UnsqueezedInput',
                                                      dict(shape=np.insert(shape, ext_inp['axis'], 1)))

                reshape_op = Squeeze(body, dict(name=ext_inp['internal_data_id'].name + '/InputSqueeze'))
                reshape_dim_data = Const(body, {'name': ext_inp['internal_data_id'].name + '/ReshapeDim',
                                                'value': ext_inp['axis']}).create_node_with_data()
                reshape_op.create_node_with_data([new_input_data, reshape_dim_data],
                                                 data_nodes=[ext_inp['internal_data_id']])
                ext_inp['internal_data_id'] = new_input_data

            ext_inp['internal_data_id']['is_input'] = True
            assert len(ext_inp['internal_data_id'].in_nodes()) == 0
            ext_inp['external_port_id'] = internal_id_count
            internal_id_count += 1
            # One port map record is produced per consumer of the body input data node
            for _, consumer, edge_attrs in body.out_edges(ext_inp['internal_data_id'].id, data=True):
                real_ext_inp = {}
                real_ext_inp.update(ext_inp)
                consumer = Node(body, consumer)
                if not consumer.has_valid('internal_layer_id'):
                    consumer['internal_layer_id'] = internal_id_count
                    internal_id_count += 1
                if 'internal_port_id' not in edge_attrs:
                    edge_attrs['internal_port_id'] = internal_id_count
                    internal_id_count += 1
                real_ext_inp['internal_layer_id'] = consumer['internal_layer_id']
                real_ext_inp['internal_port_id'] = edge_attrs['internal_port_id']
                real_external_inputs.append(real_ext_inp)

        for ext_out in external_outputs:
            assert ext_out['external_data_id'].id not in body.nodes()
            assert ext_out['internal_data_id'].id in body.nodes()
            ext_out['internal_data_id'] = Node(body, ext_out['internal_data_id'].id)

            if ext_out['axis'] is not None:
                # Insert unsqueezing resize at output port that has partitioning
                assert not ext_out['internal_data_id'].has_valid('value')
                reshape_op = Unsqueeze(body, dict(name=ext_out['internal_data_id'].name + '/OutputUnsqueeze'))
                reshape_dim_data = Const(body, {'name': ext_out['internal_data_id'].name + '/ReshapeDim',
                                                'value': ext_out['axis']}).create_node_with_data()
                ext_out['internal_data_id'] = reshape_op.create_node_with_data([ext_out['internal_data_id'],
                                                                                reshape_dim_data])

            # TODO: add here working with simple outputs

            # Add an output marker only if the data node is not consumed by a 'Result' op already
            if not any([out_node.soft_get('op', None) == 'Result'
                        for out_node in ext_out['internal_data_id'].out_nodes()]):
                add_opoutput(body, ext_out['internal_data_id'].id, 0, False)

            # assert len(ext_out['internal_data_id'].out_nodes()) == 0
            assert len(ext_out['internal_data_id'].in_nodes()) == 1
            if 'internal_layer_id' not in ext_out['internal_data_id'].in_node():
                ext_out['internal_data_id'].in_node()['internal_layer_id'] = internal_id_count
                internal_id_count += 1
            if 'internal_port_id' not in ext_out['internal_data_id'].in_edge():
                ext_out['internal_data_id'].in_edge()['internal_port_id'] = internal_id_count
                internal_id_count += 1
            ext_out['internal_layer_id'] = ext_out['internal_data_id'].in_node()['internal_layer_id']
            ext_out['internal_port_id'] = ext_out['internal_data_id'].in_edge()['internal_port_id']
            ext_out['external_port_id'] = internal_id_count
            internal_id_count += 1

        ti_op = TensorIterator(graph, {
            'name': name + '/TensorIterator',
            'body': body,
            'in_ports_count': len(external_inputs),
            'out_ports_count': len(external_outputs),

            'input_port_map': [
                {field: external_input[field] for field in
                 ['external_port_id', 'internal_layer_id', 'internal_port_id', 'axis', 'stride', 'part_size', 'start',
                  'end']}
                for external_input in real_external_inputs],

            'output_port_map': [
                {field: external_output[field] for field in
                 ['external_port_id', 'internal_layer_id', 'internal_port_id', 'axis', 'stride', 'part_size', 'start',
                  'end']}
                for external_output in external_outputs],
            'back_edges': [
                {field: edge[field] for field in ['from_layer', 'from_port', 'to_layer', 'to_port']}
                for edge in real_back_edges],
        })

        ti_outs = ti_op.create_node_with_data(
            inputs=[inp['external_data_id'] for inp in external_inputs],
            edge_attrs=[{'external_port_id': inp['external_port_id']} for inp in external_inputs],
            data_nodes=[out['external_data_id'] for out in external_outputs]
        )

        # create_node_with_data may return a single node instead of a list
        if not isinstance(ti_outs, list):
            ti_outs = [ti_outs]

        for i, out in enumerate(ti_outs):
            out.in_edge()['external_port_id'] = external_outputs[i]['external_port_id']

        ti = ti_outs[0].in_node()
        TensorIterator.cover_body_input_data_nodes_with_parameter_ops(ti)
        TensorIterator.cover_body_constant_data_nodes_with_const_ops(ti)
        TensorIterator.normalize_internal_ids(ti)