Files
openvino/model-optimizer/extensions/middle/ConcatOptimization.py
Evgeny Lazarev c7bcbb576c Updated ConcatOptimization to support Concat with 0D input of one dimension (#2012)
* Updated ConcatOptimization transformation to work when one dimension of input to Concat is 0D

* Fixed ConcatOptimization transformation to reconnect input edges to Concat

* Completely re-written ConcatOptimization

* Updated Concat0D optimization transformation

* Fixed order of traversing Concat input ports

* Refactored ConcatOptimization transformation to use `delete_input_port` function

* Delete trailing unconnected ports in the ConcatOptimization.py

* Cleaner implementation of ConcatOptimization + unit test
2020-09-02 10:21:23 +03:00

130 lines
5.0 KiB
Python

"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging as log
from extensions.middle.fusings import Fusing
from extensions.middle.pass_separator import PostMiddleStart
from mo.graph.graph import Node, Graph
from mo.middle.replacement import MiddleReplacementPattern
class ConcatOptimization(MiddleReplacementPattern):
    """
    Reduces the number of edges between Concat operations, which can significantly reduce memory consumption.

    If the inputs of one Concat ("inner") form a contiguous sub-sequence of at least two inputs of another
    Concat ("outer") and both Concats use the same 'axis', the edges from that sub-sequence to the outer Concat
    are replaced with a single edge from the output data node of the inner Concat.

    The transformation runs only when the `enable_concat_optimization` command line parameter is set.
    """
    enabled = True
    graph_condition = [lambda graph: graph.graph['cmd_params'].enable_concat_optimization]

    def run_after(self):
        return [Fusing]

    def run_before(self):
        return [PostMiddleStart]

    def find_and_replace_pattern(self, graph: Graph):
        inputs_to_concat = self._collect_concats(graph)

        for key, (concat_id, _) in inputs_to_concat.items():
            concat_node = Node(graph, concat_id)
            axis = concat_node.soft_get('axis')
            covered_inputs = set()  # input data node ids of this Concat that have already been substituted
            replaced = False

            for sub_inputs in self._sub_concat_candidates(key, inputs_to_concat):
                # substitutions must not overlap: each original input may be replaced only once
                if not covered_inputs.isdisjoint(sub_inputs):
                    continue

                inner_id, inner_out_id = inputs_to_concat[sub_inputs]
                # concatenating along a different axis produces a different tensor, so the inner Concat
                # output is not a valid substitute for these inputs
                if Node(graph, inner_id).soft_get('axis') != axis:
                    continue

                covered_inputs.update(sub_inputs)

                # the new edge re-uses the attributes (including the 'in' port) of the edge coming from the
                # first input of the sub-sequence, so it takes that position among the Concat inputs
                edge_attrs = graph.get_edge_data(sub_inputs[0], concat_id)[0]
                for in_node_id in sub_inputs:
                    graph.remove_edge(in_node_id, concat_id)

                # NOTE(review): 'out' is set to the number of consumers of the inner output + 1;
                # confirm whether this attribute is meaningful on a data->op edge
                edge_attrs['out'] = len(Node(graph, inner_out_id).out_nodes()) + 1
                graph.add_edge(inner_out_id, concat_id, **edge_attrs)
                replaced = True

            if replaced:
                self._renumber_in_ports(concat_node)

    @staticmethod
    def _collect_concats(graph: Graph) -> dict:
        """
        Map the tuple of input data node ids of every Concat to (Concat node id, output data node id).

        Concats that consume the same data node more than once are skipped: the edge bookkeeping below
        addresses edges by (source, destination) pair, which is ambiguous for parallel edges.
        If several Concats have exactly the same inputs, only the first one is kept.
        """
        inputs_to_concat = {}
        for node in graph.get_op_nodes(type='Concat'):
            in_nodes = tuple(node.in_node(idx).id for idx in range(len(node.in_nodes())))
            if len(set(in_nodes)) != len(in_nodes):
                continue
            out_node = (node.id, node.out_node().id)
            if in_nodes in inputs_to_concat:
                log.warning("Something is weird! {} and {}".format(node.id, inputs_to_concat[in_nodes]))
            else:
                inputs_to_concat[in_nodes] = out_node
        return inputs_to_concat

    @staticmethod
    def _sub_concat_candidates(key: tuple, inputs_to_concat: dict) -> list:
        """
        Return the input tuples of other Concats that are contiguous sub-sequences (of length >= 2) of `key`.

        Longer sub-sequences come first; ties are broken by comparing the tuples, keeping the order deterministic.
        """
        candidates = []
        for start in range(len(key)):
            for end in range(start + 2, len(key) + 1):
                sub_inputs = key[start:end]
                if sub_inputs != key and sub_inputs in inputs_to_concat:
                    candidates.append((len(sub_inputs), sub_inputs))
        candidates.sort(reverse=True)
        return [sub_inputs for _, sub_inputs in candidates]

    @staticmethod
    def _renumber_in_ports(concat_node: Node):
        """Renumber the 'in' attributes of the Concat input edges to 0, 1, 2, ... preserving their order."""
        # edge attribute dicts are modified directly so that every edge is renumbered, even when the same
        # source node is connected to the Concat by more than one edge
        in_edges = concat_node.in_edges()
        for new_port, old_port in enumerate(sorted(in_edges)):
            in_edges[old_port]['in'] = new_port
class ConcatOdInputEraserAndPortsReconnect(MiddleReplacementPattern):
    """
    The transformation performs two actions with Concat operations:
    1. Deletes empty inputs (input tensors that have at least one dimension equal to 0)
    2. Renumbers the remaining connected Concat inputs to be 0, 1, 2, ... and deletes all other input ports
    """
    enabled = True
    force_clean_up = True

    def find_and_replace_pattern(self, graph: Graph):
        for concat in graph.get_op_nodes(type='Concat'):
            self._delete_empty_inputs(concat)

            assert any(not port.disconnected() for port in concat.in_ports().values()), \
                'Concat "{}" have no inputs after removing inputs with 0 dimensions' \
                ''.format(concat.soft_get('name', concat.id))

            self._reconnect_ports_sequentially(concat)

    @staticmethod
    def _delete_empty_inputs(concat: Node):
        """Delete every connected input port of `concat` whose tensor shape contains a 0 dimension."""
        # iterate over a snapshot because ports are deleted inside the loop
        for in_port in list(concat.in_ports().values()):
            if in_port.disconnected():
                continue
            shape = in_port.data.get_shape()
            assert shape is not None, 'Shape of the input port {} of Concat "{}" is not defined' \
                                      ''.format(in_port.idx, concat.soft_get('name', concat.id))
            if 0 in shape:
                concat.delete_input_port(in_port.idx)

    @staticmethod
    def _reconnect_ports_sequentially(concat: Node):
        """
        Move the connected input ports of `concat` to indices 0, 1, 2, ... keeping their relative order,
        and delete every input port that ends up disconnected.
        """
        max_port_index = max(concat.in_ports().keys())
        port_idx_to_connect = 0
        for port_idx in range(max_port_index + 1):
            if concat.is_in_port_connected(port_idx):
                if port_idx != port_idx_to_connect:
                    concat.add_input_port(port_idx_to_connect, skip_if_exist=True)
                    concat.in_port(port_idx).get_connection().set_destination(concat.in_port(port_idx_to_connect))
                    # the moved-from port is now disconnected: delete it so no stale port is left behind
                    concat.delete_input_port(port_idx)
                port_idx_to_connect += 1
            elif port_idx in concat.in_ports():
                concat.delete_input_port(port_idx)