Files
openvino/model-optimizer/extensions/middle/quantize_fuses.py
Evgenya Stepyreva cd391389ce [ MO ] Complete weights layout permutation (#2299)
* MO TF: FQPerChannel extractor

* [ MO ] Complete weights layout permutation

* removed deleted file out of BOM

* Bring back stashed changes

* Skip if no weights permutation

* Conditional permutation

* Comments
2020-09-18 14:42:16 +03:00

127 lines
5.7 KiB
Python

"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
from extensions.middle.BinarizeWeightsM1P1 import BinarizeWeightsM1P1
from extensions.middle.DeleteControlFlowEdges import DeleteControlFlowEdges
from extensions.middle.EltwiseChecker import EltwiseChecker
from mo.graph.graph import Graph
from mo.middle.passes.fusing.helpers import get_value_in_port
from mo.middle.replacement import MiddleReplacementPattern
class MarkNodesToFuseUpToFakeQuantize(MiddleReplacementPattern):
    """
    Marks special nodes that could be pulled through Quantize operation.
    Sets `fuse_up_to_quantize_ports` parameter to list of indexes of input ports of Quantize operation
    where specified node should appear.
    """
    enabled = True

    def run_after(self):
        """Return the transformations that must be executed before this one."""
        return [DeleteControlFlowEdges]

    def run_before(self):
        """Return the transformations that must be executed after this one (none)."""
        return []

    @staticmethod
    def mark_fusable_muls_on_weights(graph: Graph):
        """
        Sets `can_be_fused` on Mul nodes that feed exactly one Convolution, Deconvolution or MatMul
        and whose value input has exactly one non-unit dimension.

        :param graph: graph whose Mul operations are inspected and marked in place
        """
        for node in graph.get_op_nodes(op='Mul'):
            children = node.out_port(0).get_destinations()
            # Exactly one consumer is required. A Mul without any consumer is skipped as well,
            # otherwise indexing `children[0]` below would raise IndexError.
            if len(children) != 1:
                continue
            if children[0].node.soft_get('type') not in ['Convolution', 'Deconvolution', 'MatMul']:
                continue

            # presumably the input port that carries a constant value; None when there is none
            value_in_port = get_value_in_port(node)
            if value_in_port is None:
                continue

            value_shape = value_in_port.data.get_shape()
            # only values that vary along a single axis qualify
            non_one_axis = np.argwhere(value_shape != 1)
            if non_one_axis.size != 1:
                continue
            non_one_axis = non_one_axis.item(0)

            node['can_be_fused'] = True
            EltwiseChecker().mark_eltwise_node(node, non_one_axis)

    def find_and_replace_pattern(self, graph: Graph):
        """
        Sets `fuse_up_to_quantize_ports` on Mul/Sub/Add nodes marked with `can_be_fused` that have
        a known value on one of their first two inputs, and on every Slice node.

        :param graph: graph to mark in place
        """
        # to prevent fusing of non per channel lin ops, we run EltwiseChecker to mark nodes with can_be_fused attribute
        EltwiseChecker().find_and_replace_pattern(graph)
        self.mark_fusable_muls_on_weights(graph)

        eltwise_nodes = graph.get_op_nodes(op='Mul', can_be_fused=True) + \
            graph.get_op_nodes(op='Sub', can_be_fused=True) + \
            graph.get_op_nodes(op='Add', can_be_fused=True)
        for elt in eltwise_nodes:
            if elt.in_port(0).data.get_value() is not None or elt.in_port(1).data.get_value() is not None:
                # Quantize input ports 3 and 4
                # NOTE(review): presumably output_low / output_high - confirm against the FakeQuantize spec
                elt['fuse_up_to_quantize_ports'] = [3, 4]

        # Slice nodes go to Quantize input port 0; iterating the query result directly avoids
        # binding a local called `slice`, which would shadow the builtin
        for slice_node in graph.get_op_nodes(op='Slice'):
            slice_node['fuse_up_to_quantize_ports'] = [0]
class FakeQuantizeFuse(MiddleReplacementPattern):
    """
    Moves every node carrying the `fuse_up_to_quantize_ports` attribute (node to fuse) from the output
    of a Quantize operation to the Quantize input ports listed in that attribute.

    The first port of the list receives the node to fuse itself: its edges are reconnected.
    Every further port of the list receives a copy of the node to fuse, and the remaining inputs of
    the copy are connected to the same sources as the corresponding inputs of the node to fuse.
    """
    enabled = True

    def run_after(self):
        return [MarkNodesToFuseUpToFakeQuantize]

    def run_before(self):
        return [BinarizeWeightsM1P1]

    def find_and_replace_pattern(self, graph: Graph):
        for fq in graph.get_op_nodes(op='FakeQuantize'):
            # keep pulling nodes through while Quantize has a single, marked consumer
            while True:
                consumers = fq.out_port(0).get_destinations()
                if len(consumers) != 1:
                    break
                consumer_port = consumers[0]
                movable = consumer_port.node
                if not movable.has_valid('fuse_up_to_quantize_ports'):
                    break
                # index of the input port through which the node to fuse consumed Quantize output
                entry_idx = consumer_port.idx

                # everything that followed the node to fuse is now fed by Quantize directly
                movable.out_port(0).get_connection().set_source(fq.out_port(0))
                # detach the node to fuse from Quantize output
                movable.in_port(entry_idx).disconnect()

                for position, fq_in_idx in enumerate(movable['fuse_up_to_quantize_ports']):
                    # the first listed port gets the node itself, the following ones get copies
                    is_copy = position > 0
                    if is_copy:
                        placed = movable.copy_node(
                            {'in_ports_count': len(movable.in_ports()),
                             'out_ports_count': len(movable.out_ports())})
                    else:
                        placed = movable

                    # insert the node between the producer of this Quantize input and Quantize
                    fq.in_port(fq_in_idx).get_connection().set_destination(placed.in_port(entry_idx))
                    placed.out_port(0).connect(fq.in_port(fq_in_idx))

                    if is_copy:
                        # share the remaining inputs of the node to fuse with its copy
                        for idx, port in movable.in_ports().items():
                            if idx != entry_idx:
                                port.get_source().connect(placed.in_port(idx))

                    placed.infer(placed)