Files
openvino/model-optimizer/extensions/middle/ConvertLayoutDependentOperations.py
Anton Chetverikov 56916ace61 Fix const node non-deterministic names (part 2) (#1081)
* Fix non-deterministic node names generation in the Model Optimizer (part 2)
2020-07-07 09:37:48 +03:00

100 lines
5.1 KiB
Python

"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from extensions.ops.transpose import Transpose
from mo.front.common.layout import indices_mapping
from mo.graph.graph import Node, Graph
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.const import Const
from mo.ops.op import Op, PermuteAttrs
class ConvertLayoutDependentOperations(MiddleReplacementPattern):
"""
This pass finds all convolutions and in case if layout of convolution differs from graph layout
we insert permutes before and after convolution and convert convolution attributes
"""
enabled = True
def run_after(self):
from extensions.middle.pass_separator import MiddleStart
return [MiddleStart]
def find_and_replace_pattern(self, graph: Graph):
for node in list(graph.nodes()):
node = Node(graph, node)
node_name = node.soft_get('name', node.id)
# Check that node layout mismatch with graph layout
# For example: NHWC and NCHW or NCDHW and NDHWC
if node.kind == 'op' and node.has_valid('layout') and node.layout != indices_mapping[len(node.layout)][
graph.graph['layout']]:
input = node.in_node()
output = node.out_node()
# Calculate permutation for further Transpose operations
if graph.graph['layout'] == 'NCHW':
# if Node has NCHW and graph has NHWC layout
permutation = PermuteAttrs.get_nhwc_to_nchw_permutation(len(node.layout))
else:
# if Node has NHWC and graph has NCHW layout
permutation = PermuteAttrs.get_nchw_to_nhwc_permutation(len(node.layout))
# Schematic representation of transformation below
#
# \ NCHW NCHW
# NHWC -- \ | permutation permutation |
# data-->Convolution(example)-->data -- / | | NCHW | |
# / data->Transpose->data->Convolution->data->Transpose->data
# 1. Insert input Transpose
# This Transpose will permute input from original input layout to operation layout
edge_attrs = graph.get_edge_data(input.id, node.id)[0]
graph.remove_edge(input.id, node.id)
input_permute_name = node_name + '/input_transpose'
input_order_const = Const(graph, {'name': input_permute_name + '/order',
'value': permutation.perm}).create_node_with_data()
input_permute_op = Transpose(graph, {'name': input_permute_name})
input_permute_data_node = input_permute_op.create_node_with_data([input, input_order_const])
graph.add_edge(input_permute_data_node.id, node.id, **edge_attrs)
# 2. Insert output Transpose
# This Transpose will permute output from operation layout to original input layout
edge_attrs = graph.get_edge_data(node.id, output.id)[0]
graph.remove_edge(node.id, output.id)
input_data_node = Op.create_data_node(graph, node, {'shape': output.shape[permutation.perm]},
edge_attrs)
output_permute_name = node_name + '/output_transpose'
output_order_const = Const(graph, {'name': output_permute_name + '/order',
'value': permutation.inv}).create_node_with_data()
output_permute_op = Transpose(graph, {'name': output_permute_name}
).create_node_with_data([input_data_node, output_order_const],
data_nodes=output)
# 3. Add permutations for Node
# Here we use permutation mechanism where data nodes takes permutation attribute.
# And then we call permute_attrs method that permutes node attributes according to permutations on
# data nodes.
node.in_node()['permutation'] = permutation
node.out_node()['permutation'] = permutation
node.permute_attrs.permute_attrs(node)
node.in_node()['permutation'] = None
node.out_node()['permutation'] = None