Files
openvino/model-optimizer/mo/graph/perm_inputs.py
Evgenya Stepyreva cd391389ce [ MO ] Complete weights layout permutation (#2299)
* MO TF: FQPerChannel extractor

* [ MO ] Complete weights layout permutation

* removed deleted file out of BOM

* Bring back stashed changes

* Skip if no weights permutation

* Conditional permutation

* Comments
2020-09-18 14:42:16 +03:00

197 lines
9.7 KiB
Python

"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import networkx as nx
from extensions.ops.gather import Gather
from extensions.ops.transpose import Transpose
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Node
from mo.ops.const import Const
def get_node_with_permutation(node: Node, port_info: str):
    """
    Resolves a '<direction>:<port index>' description into a neighbour of `node`.

    :param node: operation node to look the neighbour up from
    :param port_info: string such as 'input:0' or 'output:1'. The direction 'input' selects
        node.in_node(<port index>); any other direction selects node.out_node(<port index>)
    :return: the node connected to the described port
    """
    direction, raw_port = port_info.split(':')
    port_idx = int(raw_port)
    accessor = node.in_node if direction == 'input' else node.out_node
    return accessor(port_idx)
def axis(op_node: Node, port_info: str, input_port: int):
    """
    Performs layout change related transformation of the data on the input_port port of op_node.
    Translates shape indexes from one layout to another according to inverse permutation
    Transformation inserts Gather operation with
        permutation as 0-port input data and
        actual data to translate as 1-port input indexes of Gather
    For example:
        NHWC Reduce operation has 0-port input with data of shape [1, 2, 3, 4] and
        1-port input with axis indices [0, 1].
        After translating such operation to NCHW layout:
            0-port input shape = [1, 4, 2, 3]
            1-port input axis indices = [0, 2]

    :param op_node: operation node whose input is translated
    :param port_info: '<input|output>:<port index>' pointing at the data node that carries the
        `permutation` attribute (resolved with get_node_with_permutation)
    :param input_port: index of the op_node input port holding the axis indices to translate
    :raises AssertionError: if the data node referenced by port_info has no permutation set
    """
    graph = op_node.graph
    permutation_data_node = get_node_with_permutation(op_node, port_info)
    assert permutation_data_node.has_and_set('permutation'), \
        'Data node "{}" does not have permutation for node {}, port_info "{}".'.format(
            permutation_data_node.id, op_node.id, port_info)
    permutation = permutation_data_node.permutation
    # An empty permutation means there is nothing to translate. len() is used on purpose:
    # perm may be an array, for which plain truthiness is ambiguous.
    if len(permutation.perm) == 0:
        return

    data_node = op_node.in_node(input_port)

    gather_name = op_node.soft_get('name', op_node.id) + '/AxisGather'
    const = Const(graph, {'value': permutation.inv, 'name': gather_name + '/const',
                          'need_shape_inference': True}).create_node_with_data()
    axis_const = Const(graph, {'value': int64_array(0), 'name': gather_name + '/axis'}).create_node_with_data()
    # Gather(data=inverse permutation, indices=original axes) looks up the new position of every axis
    gather = Gather(graph, {'name': gather_name, 'need_shape_inference': True}).create_node_with_data(
        [const, data_node, axis_const])

    # Re-route the edge that actually feeds `input_port`. The same data node may be connected to several
    # ports of op_node. In that case edge key 0 is not necessarily this port's edge, and remove_edge()
    # without a key would drop the most recently added parallel edge rather than the one copied here.
    parallel_edges = graph.get_edge_data(data_node.id, op_node.id)
    edge_key = next(key for key, attrs in parallel_edges.items() if attrs.get('in') == input_port)
    edge_attrs = dict(parallel_edges[edge_key])
    graph.add_edge(gather.id, op_node.id, **edge_attrs)
    graph.remove_edge(data_node.id, op_node.id, key=edge_key)
    op_node['need_shape_inference'] = True
def order(op_node: Node, port_info: str, input_port: int):
    """
    Performs layout change related transformation of the data on the input_port port of op_node.
    Translates ordered shape indexes from one layout to another according to permutation
    Transformation inserts two Gather operations
        1 Gather reorders data to new layout according to direct permutation:
            actual data to translate as 0-port input data and
            permutation as 1-port input indexes of Gather
        2 Gather translates shape indexes from one layout to another according to inverse permutation
            permutation as 0-port input data and
            actual data to translate as 1-port input indexes of Gather
    For example:
        NHWC Transpose operation has 0-port input with data of shape [1, 2, 3, 4] and
        1-port input with new order indices [0, 1, 3, 2].
        After translating such operation to NCHW layout:
            0-port input shape = [1, 4, 2, 3]
            1 phase (after first Gather insertion):
                1-port input order indices = [0, 2, 1, 3]
            2 phase (after second Gather insertion):
                1-port input order indices = [0, 3, 2, 1]

    :param op_node: operation node whose input is translated
    :param port_info: '<input|output>:<port index>' pointing at the data node that carries the
        `permutation` attribute (resolved with get_node_with_permutation)
    :param input_port: index of the op_node input port holding the order to translate
    :raises AssertionError: if the data node referenced by port_info has no permutation set
    """
    graph = op_node.graph
    permutation_data_node = get_node_with_permutation(op_node, port_info)
    assert permutation_data_node.has_and_set('permutation'), \
        'Data node "{}" does not have permutation for node {}, port_info "{}".'.format(
            permutation_data_node.id, op_node.id, port_info)
    permutation = permutation_data_node.permutation
    # An empty permutation means there is nothing to translate. len() is used on purpose:
    # perm may be an array, for which plain truthiness is ambiguous.
    if len(permutation.perm) == 0:
        return

    data_node = op_node.in_node(input_port)

    # 1st Gather: Gather(data=order, indices=perm) moves the order entries into the new layout positions
    gather_name = op_node.soft_get('name', op_node.id) + '/OrderGather_1'
    const = Const(graph, {'value': permutation.perm, 'name': gather_name + '/const',
                          'need_shape_inference': True}).create_node_with_data()
    axis_const = Const(graph, {'value': int64_array(0), 'name': gather_name + '/axis'}).create_node_with_data()
    gather = Gather(graph, {'name': gather_name,
                            'need_shape_inference': True}).create_node_with_data([data_node, const, axis_const])

    # 2nd Gather: Gather(data=inverse perm, indices=reordered order) maps every entry (an axis index)
    # to its index in the new layout
    gather_1_name = op_node.soft_get('name', op_node.id) + '/OrderGather_2'
    const_1 = Const(graph, {'value': permutation.inv, 'name': gather_1_name + '/const',
                            'need_shape_inference': True}).create_node_with_data()
    axis_const_1 = Const(graph, {'value': int64_array(0), 'name': gather_1_name + '/axis'}).create_node_with_data()
    gather_1 = Gather(graph, {'name': gather_1_name,
                              'need_shape_inference': True}).create_node_with_data([const_1, gather, axis_const_1])

    # Re-route the edge that actually feeds `input_port`. The same data node may be connected to several
    # ports of op_node. In that case edge key 0 is not necessarily this port's edge, and remove_edge()
    # without a key would drop the most recently added parallel edge rather than the one copied here.
    parallel_edges = graph.get_edge_data(data_node.id, op_node.id)
    edge_key = next(key for key, attrs in parallel_edges.items() if attrs.get('in') == input_port)
    edge_attrs = dict(parallel_edges[edge_key])
    graph.add_edge(gather_1.id, op_node.id, **edge_attrs)
    graph.remove_edge(data_node.id, op_node.id, key=edge_key)
    op_node['need_shape_inference'] = True
def shape(op_node: Node, port_info: str, input_port: int):
    """
    Performs layout change related transformation of the data on the input_port port of op_node.
    Reorders the input values to the new layout according to the direct permutation
    Transformation inserts Gather operation with
        actual data to reorder as 0-port input data and
        permutation as 1-port input indexes of Gather
    For example:
        with permutation [0, 3, 1, 2] the input value [1, 2, 3, 4] becomes [1, 4, 2, 3]
    Unlike `axis` and `order`, op_node is re-inferred immediately instead of only being marked
    with 'need_shape_inference'.

    :param op_node: operation node whose input is translated
    :param port_info: '<input|output>:<port index>' pointing at the data node that carries the
        `permutation` attribute (resolved with get_node_with_permutation)
    :param input_port: index of the op_node input port holding the values to reorder
        (presumably a 1D shape-like vector — TODO confirm with callers)
    :raises AssertionError: if the data node referenced by port_info has no permutation set
    """
    graph = op_node.graph
    permutation_data_node = get_node_with_permutation(op_node, port_info)
    assert permutation_data_node.has_and_set('permutation'), \
        'Data node "{}" does not have permutation for node {}, port_info "{}".'.format(
            permutation_data_node.id, op_node.id, port_info)
    permutation = permutation_data_node.permutation
    # An empty permutation means there is nothing to translate. len() is used on purpose:
    # perm may be an array, for which plain truthiness is ambiguous.
    if len(permutation.perm) == 0:
        return

    data_node = op_node.in_node(input_port)

    gather_name = op_node.soft_get('name', op_node.id) + '/ShapeGather'
    const = Const(graph, {'value': permutation.perm, 'name': gather_name + '/const',
                          'need_shape_inference': True}).create_node_with_data()
    axis_const = Const(graph, {'value': int64_array(0), 'name': gather_name + '/axis'}).create_node_with_data()
    # Gather(data=values, indices=perm) moves the values into the new layout positions
    gather = Gather(graph, {'name': gather_name,
                            'need_shape_inference': True}).create_node_with_data([data_node, const, axis_const])

    # Re-route the edge that actually feeds `input_port`. The same data node may be connected to several
    # ports of op_node. In that case edge key 0 is not necessarily this port's edge, and remove_edge()
    # without a key would drop the most recently added parallel edge rather than the one copied here.
    parallel_edges = graph.get_edge_data(data_node.id, op_node.id)
    edge_key = next(key for key, attrs in parallel_edges.items() if attrs.get('in') == input_port)
    edge_attrs = dict(parallel_edges[edge_key])
    graph.add_edge(gather.id, op_node.id, **edge_attrs)
    graph.remove_edge(data_node.id, op_node.id, key=edge_key)
    # need to run manually to override output shape value to resolve shape collision for nodes with
    # 'correct_data_layout' output port attrs
    op_node.infer(op_node)
def transpose(op_node: Node, port_info: str, input_port: int):
    """
    Inserts a Transpose by the direct permutation in front of the input_port input of op_node.

    The Transpose gets its order as a constant 1-port input, is named '<op_node name>/Transpose'
    and is created with 'override_output_shape': True. Nothing is inserted when the permutation is empty.

    :param op_node: operation node whose input is transposed
    :param port_info: '<input|output>:<port index>' pointing at the data node that carries the
        `permutation` attribute (resolved with get_node_with_permutation)
    :param input_port: index of the op_node input port to put the Transpose on
    :raises AssertionError: if the data node referenced by port_info has no permutation set
    """
    perm_holder = get_node_with_permutation(op_node, port_info)
    assert perm_holder.has_and_set('permutation'), \
        'Data node "{}" does not have permutation for node {}, port_info "{}".'.format(
            perm_holder.id, op_node.id, port_info)

    direct_perm = perm_holder.permutation.perm
    if len(direct_perm) == 0:
        return

    node_name = op_node.soft_get('name', op_node.id)
    # imported here rather than at module level to avoid a recursive import
    from mo.front.tf.graph_utils import create_op_with_const_inputs

    transpose_attrs = {'name': node_name + '/Transpose', 'override_output_shape': True}
    new_transpose = create_op_with_const_inputs(op_node.graph, Transpose, {1: direct_perm}, transpose_attrs)
    op_node.in_port(input_port).get_connection().insert_node(new_transpose)
class PermuteInputs:
    # Rule that translates axis indices through the inverse permutation (delegates to `axis`).
    common_inv_permutation = lambda node, port_info, input_port: axis(node, port_info, input_port)

    # Rule name -> callable invoked as rule(node, port_info, input_port).
    # The lambdas look the module-level functions up at call time.
    input_permutes = {
        'axis': common_inv_permutation,
        'order': lambda node, port_info, input_port: order(node, port_info, input_port),
        'shape': lambda node, port_info, input_port: shape(node, port_info, input_port),
        'transpose': lambda node, port_info, input_port: transpose(node, port_info, input_port),
    }

    def set_input_permutation(self, node1: Node, node2: Node, port_info: str, permutation_rule: str):
        """
        Attaches an input permutation to the node1 -> node2 edge (edge key 0).

        The edge attribute 'input_permutation' is set to a tuple (rule, port_info):
        rule is the callable registered in input_permutes under permutation_rule, and
        port_info is 'input:<port>' or 'output:<port>', pointing at the node with the
        permutation attribute that this input depends on.

        :raises AssertionError: if permutation_rule is not registered in input_permutes
        """
        rules = self.input_permutes
        assert permutation_rule in rules, \
            'No `{}` permutation rule in {}'.format(permutation_rule, __class__.__name__)
        edge = (node1.id, node2.id, 0)
        attribute_value = (rules[permutation_rule], port_info)
        nx.set_edge_attributes(G=node1.graph, values={edge: attribute_value}, name='input_permutation')