214 lines
8.1 KiB
Python
214 lines
8.1 KiB
Python
# Copyright (C) 2018-2021 Intel Corporation
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import numpy as np
|
|
|
|
from extensions.back.TransposeReduceFusing import TransposeReduce
|
|
from extensions.ops.transpose import Transpose
|
|
from mo.back.replacement import BackReplacementPattern
|
|
from mo.front.caffe.extractors.utils import get_canonical_axis_index
|
|
from mo.front.common.partial_infer.utils import int64_array
|
|
from mo.front.tf.graph_utils import create_op_node_with_second_input
|
|
from mo.graph.graph import Graph, Node
|
|
from mo.ops.const import Const
|
|
from mo.ops.shape import Shape
|
|
from mo.ops.unsqueeze import Unsqueeze
|
|
from mo.utils.shape import node_to_get_shape_value_of_indices, new_shape_node_from_shape_nodes
|
|
|
|
|
|
class MatMulConstTransposesExtraction(BackReplacementPattern):
    """
    Normalizes MatMul nodes whose second input is a constant (or whose producer is flagged
    with `stop_value_propagation`) to `transpose_b=True`.

    An explicit Transpose that swaps the two innermost axes is inserted on input 1 and the
    `transpose_b` attribute is set, so the product itself is unchanged. The inserted Transpose
    is expected to be const-folded during graph clean-up (hence `force_clean_up = True`).
    Only `transpose_b` is handled; input 0 is never touched.
    """
    enabled = True
    force_clean_up = True

    @staticmethod
    def pattern():
        return dict(
            nodes=[('matmul', dict(kind='op', op='MatMul'))],
            edges=[]
        )

    @staticmethod
    def insert_transpose(node, in_port_idx):
        """
        Inserts a Transpose swapping the two innermost axes between input port `in_port_idx`
        of `node` and that port's current producer.

        :param node: node whose input is transposed
        :param in_port_idx: index of a connected input port whose shape is known (rank >= 2)
        """
        graph = node.graph
        name = node.soft_get('name', node.id)

        assert in_port_idx in node.in_ports() and not node.in_port(in_port_idx).disconnected(), \
            'Input port with index {} should be connected for node {}'.format(in_port_idx, name)

        in_port = node.in_port(in_port_idx)
        port_shape = in_port.data.get_shape()
        assert port_shape is not None, \
            'Shape is unknown for input port with index {} for node {}'.format(in_port_idx, name)

        # identity permutation with the last two axes swapped, e.g. [0, 1, 3, 2] for rank 4
        transpose_order = list(range(port_shape.size))
        transpose_order[-1], transpose_order[-2] = transpose_order[-2], transpose_order[-1]

        transpose = create_op_node_with_second_input(graph, Transpose, int64_array(transpose_order),
                                                     {'name': name + '/{}_port_transpose'.format(in_port_idx)})

        # re-route: producer -> Transpose -> node
        port_source = in_port.get_source()
        in_port.get_connection().set_source(transpose.out_port(0))
        transpose.in_port(0).connect(port_source)

        transpose['override_output_shape'] = True

    @staticmethod
    def replace_pattern(graph: Graph, match: dict):
        node = match['matmul']

        if node.has_and_set('transpose_b'):
            # already in the normalized form
            return

        B_shape = node.in_port(1).data.get_shape()
        # An unknown shape cannot be analysed. Transpose flags do not affect 1D operands,
        # and insert_transpose cannot swap two trailing axes of a rank < 2 tensor.
        if B_shape is None or len(B_shape) < 2:
            return

        B_value = node.in_port(1).data.get_value()
        # NOTE(review): presumably set for FakeQuantize-d weights, per the variable name — confirm
        FQ_on_weights = node.in_port(1).get_source().node.has_and_set('stop_value_propagation')

        # only weights with at most two non-unit dimensions are normalized
        if (B_value is not None or FQ_on_weights) and B_shape[B_shape != 1].size <= 2:
            MatMulConstTransposesExtraction.insert_transpose(node, 1)
            node['transpose_b'] = True
|
|
|
|
|
|
class PullTransposeThroughFQUp(BackReplacementPattern):
    r"""
    Moves a Transpose that directly follows a FakeQuantize above it by placing a copy of the
    Transpose on every FakeQuantize input:

        BEFORE                          AFTER
                                    T  T  T  T  T
         \ \ | / /                   \ \ | / /
        FakeQuantize                FakeQuantize
            |                            |
        Transpose                     next_op
            |
         next_op

    `T` is Transpose for short. An FQ input whose rank is lower than the rank of the
    Transpose input is first unsqueezed at its leading axes, so the copied Transpose order
    applies to it.
    """
    enabled = True
    force_clean_up = True

    def run_after(self):
        # in case FQ->Transpose->Reduce we should first try to optimize out Transpose
        return [MatMulConstTransposesExtraction, TransposeReduce]

    @staticmethod
    def pattern():
        nodes = [
            ('fq', dict(kind='op', type='FakeQuantize')),
            ('data', dict()),
            ('transpose', dict(kind='op', type='Transpose')),
        ]
        edges = [
            ('fq', 'data'),
            ('data', 'transpose'),
        ]
        return dict(nodes=nodes, edges=edges)

    @staticmethod
    def replace_pattern(graph: Graph, match: dict):
        fq = match['fq']

        # FQ should have only one child -- Transpose for optimization
        if len(fq.out_port(0).get_destinations()) > 1:
            return

        transpose = match['transpose']
        fq_name = fq.soft_get('name', fq.id)
        transpose_in_shape = transpose.in_port(0).data.get_shape()

        # Detach the matched Transpose: its consumers are now fed directly by its producer.
        transpose_producer = transpose.in_port(0).get_connection().get_source()
        transpose.out_port(0).get_connection().set_source(transpose_producer)
        transpose.in_port(0).disconnect()

        for in_idx, fq_in_port in fq.in_ports().items():
            # One Transpose copy per FQ input, all sharing the original order input.
            in_transpose = transpose.copy_node({'override_output_shape': True})
            transpose.in_port(1).get_source().connect(in_transpose.in_port(1))

            # A lower-rank input is unsqueezed at its leading axes before being transposed.
            leading_axes = np.arange(len(transpose_in_shape) - len(fq_in_port.data.get_shape()))
            if leading_axes.size == 0:
                entry_port = in_transpose.in_port(0)
            else:
                axis = Const(graph, {'name': fq_name + '/in_{}_unsqueeze_axis'.format(in_idx),
                                     'value': int64_array(leading_axes)}).create_node()
                unsqueeze = Unsqueeze(graph, {'name': fq_name + '/in_{}_unsqueeze'.format(in_idx)}).create_node()
                axis.out_port(0).connect(unsqueeze.in_port(1))
                unsqueeze.out_port(0).connect(in_transpose.in_port(0))
                entry_port = unsqueeze.in_port(0)

            # Re-route: original producer -> [Unsqueeze ->] Transpose copy -> FQ input.
            producer = fq_in_port.get_source()
            fq_in_port.get_connection().set_source(in_transpose.out_port(0))
            producer.connect(entry_port)
|
|
|
|
|
|
class SmartReshape_HC_Reshape_MatMul(BackReplacementPattern):
    r"""
    Relaxes hard-coded input of Reshape in such sub-graphs:

    input_1     Constant
        \       /
        Reshape    input_2
              \   /
              MatMul
                |

    When the Constant holds a 2D target shape, it is replaced by a sub-graph that reads the
    contracted (K) dimension from the shape of the other MatMul operand at run time. The
    remaining dimension is set to -1, giving [-1, K] or [K, -1] depending on which MatMul
    input the Reshape feeds and on the MatMul transpose flags.
    """
    enabled = True
    force_clean_up = True

    def run_after(self):
        return [MatMulConstTransposesExtraction]

    def pattern(self):
        return dict(
            nodes=[
                ('output_shape', dict(type='Const')),
                ('output_shape_d', dict()),
                ('reshape', dict(type='Reshape')),
                ('reshape_d', dict()),
                ('other_input', dict(type=lambda t: t not in ['Reshape', 'Transpose'])),
                ('other_input_d', dict()),
                ('matmul', dict(type='MatMul')),
            ],
            edges=[
                ('output_shape', 'output_shape_d'),
                ('output_shape_d', 'reshape', {'in': 1}),
                ('reshape', 'reshape_d'),
                ('reshape_d', 'matmul'),
                ('other_input', 'other_input_d'),
                ('other_input_d', 'matmul'),
            ]
        )

    def replace_pattern(self, graph: Graph, match: dict):
        matmul = match['matmul']
        reshape = match['reshape']

        initial_reshape_pattern = reshape.in_port(1).data.get_value()
        # only hard-coded 2D target shapes are relaxed
        if initial_reshape_pattern is None or len(initial_reshape_pattern) != 2:
            return

        transpose_a = matmul.has_and_set('transpose_a')
        transpose_b = matmul.has_and_set('transpose_b')

        reshape_is_A_input = matmul.in_port(0).get_source().node.id == reshape.id
        # the pattern guarantees the two MatMul inputs come from Reshape and other_input
        other_input_port_idx = 1 if reshape_is_A_input else 0

        other_shape = matmul.in_port(other_input_port_idx).data.get_shape()
        if other_shape is None or len(other_shape) == 0:
            # cannot locate the contracted dimension without a known, non-scalar shape
            return

        # position of the contracted (K) dimension inside the *other* operand
        if len(other_shape) == 1:
            # MatMul transpose flags do not affect 1D operands; their only dimension is K
            idx = 0
        else:
            if reshape_is_A_input:
                idx = -1 if transpose_b else -2
            else:
                idx = -2 if transpose_a else -1
            # Canonicalize against the other operand's rank, not against the 2-element
            # Reshape pattern; the two differ for operands of rank > 2.
            idx = get_canonical_axis_index(other_shape, idx)

        shape_source = matmul.in_port(other_input_port_idx).get_source()
        shape_name = shape_source.node.soft_get('name', shape_source.node.id)
        shape = Shape(graph, {'name': shape_name + '/Shape'}).create_node()
        shape.in_port(0).connect(shape_source)
        k_dim = node_to_get_shape_value_of_indices(shape, [idx])
        minus_one = Const(graph, {'name': shape_name + '/MinusOne', 'value': int64_array([-1])}).create_node()

        if reshape_is_A_input:
            # A is [M, K], or [K, M] when transposed
            reshape_pattern = [k_dim, minus_one] if transpose_a else [minus_one, k_dim]
        else:
            # B is [K, N], or [N, K] when transposed
            reshape_pattern = [minus_one, k_dim] if transpose_b else [k_dim, minus_one]
        new_reshape_pattern = new_shape_node_from_shape_nodes(reshape_pattern)
        reshape.in_port(1).get_connection().set_source(new_reshape_pattern.out_port(0))
|
|
|