168 lines
7.6 KiB
Python
168 lines
7.6 KiB
Python
"""
|
|
Copyright (C) 2018-2020 Intel Corporation
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""
|
|
import logging as log
|
|
|
|
import numpy as np
|
|
|
|
from extensions.middle.ApplyNHWCtoNCHWpermutation import ApplyNHWCtoNCHWpermutation
|
|
from extensions.middle.InsertLayoutPropagationTransposes import is_input_data_in_correct_layout, \
|
|
is_output_data_in_correct_layout
|
|
from extensions.middle.LayoutChangeForConstantShapePaths import LayoutChangeForConstantShapePaths
|
|
from extensions.middle.pass_separator import PostMiddleStart
|
|
from mo.front.common.partial_infer.utils import int64_array
|
|
from mo.graph.graph import Graph, Node
|
|
from mo.graph.perm_inputs import get_node_with_permutation
|
|
from mo.graph.port import Port
|
|
from mo.middle.replacement import MiddleReplacementPattern
|
|
from mo.utils.error import Error
|
|
|
|
|
|
class ApplyPermutation(MiddleReplacementPattern):
    """
    Middle-stage transformation that applies the permutations recorded on graph edges.

    The pass merges edge permutations onto data nodes, permutes data node shapes/values and
    operation attributes, re-infers the ShapeOf sub-graphs, permutes input data of operations
    and finally sets the graph layout attribute to 'NCHW'.
    """
    enabled = True
    force_clean_up = True
    # can't be turned on for Kaldi until permutation logic will be aligned
    graph_condition = [lambda graph: graph.graph['fw'] != 'kaldi']

    def run_after(self):
        """Return the transformations that must be executed before this one."""
        return [ApplyNHWCtoNCHWpermutation, PostMiddleStart]

    def run_before(self):
        """Return the transformations that must be executed after this one (none)."""
        return []

    def find_and_replace_pattern(self, graph: Graph):
        """
        Run all permutation steps on the graph and mark its layout as 'NCHW'.

        The order of the steps matters: permutations are first merged onto data nodes, then
        applied to data and operation nodes, and only after that the shape sub-graphs are
        re-inferred and input data is permuted.

        :param graph: graph to transform in place
        """
        self.merge_nodes_permutations(graph)
        self.permute_data_nodes_attrs(graph)
        self.permute_op_nodes_attrs(graph)
        self.shape_of_sub_graph_reinference(graph)
        self.permute_input_data(graph)
        graph.graph['layout'] = 'NCHW'

    @staticmethod
    def merge_nodes_permutations(graph: Graph):
        """
        Collect the 'permutation' attributes from all edges of every data node and store the
        common permutation in the 'permutation' attribute of the data node.

        :param graph: graph to process in place
        :raises Error: if the permutations requested for a data node are not equal
        """
        # Iterate over all data nodes and check all permutations for similarity
        # In case of equal permutations, this permutation will be set as attribute for data node
        # otherwise exception will be raised
        for node in graph.nodes():
            node = Node(graph, node)
            if node.kind != 'data':
                continue

            permutations = []

            # Get all permutations from in edges
            for in_node in node.in_nodes():
                edge_attrs = node.graph.get_edge_data(in_node.id, node.id)[0]
                if 'permutation' in edge_attrs:
                    permutations.append(edge_attrs['permutation'])

            # Get all permutations from out edges
            for out_node in node.out_nodes():
                edge_attrs = node.graph.get_edge_data(node.id, out_node.id)[0]
                if 'permutation' in edge_attrs:
                    permutations.append(edge_attrs['permutation'])

            # Normalize the permutations before comparison: an edge with an explicit None
            # permutation is treated as the identity order of the node shape
            final_permutations = [
                p.perm if p is not None else int64_array(np.arange(node.shape.size))
                for p in permutations
            ]

            if not final_permutations:
                continue

            # Check that all permutations are equal
            if not all(np.array_equal(final_permutations[0], perm) for perm in final_permutations):
                # Report the normalized permutations: the raw list may contain None entries,
                # for which accessing ".perm" would fail and hide the real error
                raise Error('Permutations requested for {} data node are not equal! List of permutations: {}'
                            ''.format(node.name, final_permutations))

            assert not node.has_valid('permutation') or np.array_equal(node.permutation, permutations[0])
            node['permutation'] = permutations[0]

    @staticmethod
    def permute_data_nodes_attrs(graph: Graph):
        """
        Apply the permutation stored in a data node to its shape and, if present, to its value.

        :param graph: graph to process in place
        """
        # Iterate over all data nodes and apply permutation if exists
        for node in graph.get_data_nodes():
            # Skip nodes without a permutation and nodes whose every out edge carries
            # 'input_permutation' (this also skips nodes that have no out edges at all)
            if not node.has_valid('permutation') or \
                    all(attrs.get('input_permutation', False)
                        for u, v, attrs in graph.out_edges(node.id, data=True)):
                continue

            if len(
                    node.in_nodes()) != 0:  # there are data nodes without input operation node inside the tensor iterator
                edge_attrs = graph.get_edge_data(node.in_node(0).id, node.id)[0]
                if is_output_data_in_correct_layout(node.in_node(0), edge_attrs['out']):
                    log.debug('Do not permute data node attrs for node "{}" output port "{}"'.format(node.in_node(0).id,
                                                                                                     edge_attrs['out']))
                    continue

            # Apply permutation for shape and value if exists
            if len(node.permutation.perm) == 0:
                continue
            node.shape = np.array(node.shape)[node.permutation.perm]
            if node.has_valid('value'):
                assert len(node.value.shape) == len(node.permutation.perm), \
                    'Node {} has shape {} and permutation {} that does not match. Their lengths should be equal' \
                    ''.format(node.name, node.value.shape, node.permutation.perm)
                node.value = np.array(node.value.transpose(node.permutation.perm))

    @staticmethod
    def permute_op_nodes_attrs(graph: Graph):
        """
        Permute attributes of every operation node that has 'permute_attrs' and is not marked
        with 'nchw_layout'.

        :param graph: graph to process in place
        :raises Error: if the attributes permutation of a node fails; the original exception
            is attached as the cause
        """
        for node in graph.get_op_nodes():
            if node.has_valid('permute_attrs') and not node.has_and_set('nchw_layout'):
                try:
                    node.permute_attrs.permute_attrs(node)
                except Exception as e:
                    raise Error('Can\'t permute attrs for node {}. Error message: {}'.format(node.id, e)) from e

    @staticmethod
    def permute_input_data(graph: Graph):
        """
        Call the permutation functions stored in the 'input_permutation' attribute of the
        input edges of operation nodes.

        The attribute value is unpacked as a pair (permutation function, port info), where the
        port info is a string of the form '<direction>:<port index>'. The function is called
        only when the data node selected by the port info has a permutation, the input data is
        not in the correct layout already and the checked port shape has at least 4 dimensions.

        :param graph: graph to process in place
        """
        for node in graph.get_op_nodes():
            input_permutations = [(in_port, edge_attrs['input_permutation']) for in_port, edge_attrs in
                                  node.in_edges().items() if edge_attrs.get('input_permutation') is not None]
            for in_port, input_perm in input_permutations:
                permutation, port_info = input_perm
                direction, port = port_info.split(':')
                port = int(port)
                port_to_check = node.in_port(port) if direction == 'input' else node.out_port(port)
                permutation_data_node = get_node_with_permutation(node, port_info)

                if permutation_data_node.has_and_set('permutation') and \
                        not is_input_data_in_correct_layout(node, in_port) and \
                        len(port_to_check.data.get_shape()) >= 4:
                    permutation(node, port_info, in_port)
            # Re-run the node inference if it was requested and reset the request flag
            if node.has_and_set('need_shape_inference'):
                node.infer(node)
                node.need_shape_inference = False

    @staticmethod
    def shape_of_sub_graph_reinference(graph: Graph):
        """
        After layout permutation (shape change in data nodes) shape sub-graphs contain values in the old layout
        To change that we execute full partial inference on the shape-of sub-graphs

        :param graph: graph to process in place
        """
        shape_ops = graph.get_op_nodes(op='ShapeOf')
        for shape in shape_ops:
            shape.infer(shape)

        def reinfer_once(in_port: Port):
            # The 'reinferred' mark guarantees that each node is inferred at most once
            node = in_port.node
            if not node.soft_get('reinferred', False):
                node.infer(node)
                node['reinferred'] = True

        LayoutChangeForConstantShapePaths().find_shape_subgraph_endpoints(
            out_ports=[shape.out_port(0) for shape in shape_ops], action=reinfer_once)