This reverts commit 1901087677.
This commit is contained in:
parent
803a927e70
commit
3e89e7fc86
@ -4,12 +4,12 @@
|
||||
from copy import deepcopy
|
||||
import numpy as np
|
||||
|
||||
from .utils import correct_node_overflow, correct_elt_overflow
|
||||
from .utils import correct_node_overflow
|
||||
from ...algorithm_selector import COMPRESSION_ALGORITHMS
|
||||
from ...quantization import fake_quantize as fqut
|
||||
from ....algorithms.algorithm import Algorithm
|
||||
from ....graph import model_utils as mu, node_utils as nu
|
||||
from ....graph.special_operations import OPERATIONS_WITH_WEIGHTS, ELTWISE_ADD_SUB
|
||||
from ....graph.special_operations import OPERATIONS_WITH_WEIGHTS
|
||||
from ....samplers.creator import create_sampler
|
||||
from ....statistics.functions import activations as acf
|
||||
from ....utils.logger import get_logger
|
||||
@ -43,16 +43,13 @@ class OverflowCorrection(Algorithm):
|
||||
"""
|
||||
activation_statistics = self._stats_collector.get_statistics_for_algorithm(self.name)
|
||||
|
||||
def get_node_orig_full_name(node):
    """
    Return the name used to look a node up in the collected statistics.
    :param node: node supporting ``in`` / item access and a ``fullname`` attribute
    :return: value stored under 'orig_node_name' if that key is present,
    otherwise the node's ``fullname``
    """
    if 'orig_node_name' in node:
        return node['orig_node_name']
    return node.fullname
|
||||
|
||||
# overflow correction for input_scale * weight_scale
|
||||
weighted_nodes = mu.get_nodes_by_type(model, [n['type'] for n in OPERATIONS_WITH_WEIGHTS])
|
||||
weighted_nodes = [n for n in weighted_nodes if nu.node_with_quantized_weights(n)]
|
||||
for weighted_node in weighted_nodes:
|
||||
bias_node = nu.get_bias_for_node(weighted_node)
|
||||
output_node = weighted_node if bias_node is None else nu.get_node_output(bias_node, 0)[0]
|
||||
output_node_name = get_node_orig_full_name(output_node)
|
||||
output_node_name = output_node['orig_node_name'] if 'orig_node_name' in output_node \
|
||||
else output_node.fullname
|
||||
if output_node_name not in activation_statistics \
|
||||
or 'max_per_tensor' not in activation_statistics[output_node_name]:
|
||||
logger.debug('Skipping {}'.format(weighted_node.fullname))
|
||||
@ -67,27 +64,12 @@ class OverflowCorrection(Algorithm):
|
||||
if rescale_value:
|
||||
logger.debug('Weights and scales for node {} '
|
||||
'updated with scale coefficient: {}'.format(weighted_node.fullname, rescale_value))
|
||||
|
||||
# overflow correction for bias_scale / input_scale
|
||||
elt_nodes = mu.get_nodes_by_type(model, [n['type'] for n in ELTWISE_ADD_SUB])
|
||||
elt_nodes = [n for n in elt_nodes if nu.node_with_quantized_input_and_bias(n)]
|
||||
for elt_node in elt_nodes:
|
||||
elt_node_name = get_node_orig_full_name(elt_node)
|
||||
logger.debug('Processing {}'.format(elt_node.fullname))
|
||||
|
||||
input_rescale, input_fq, bias_rescale, bias_fq = correct_elt_overflow(elt_node)
|
||||
for (rescale, fq) in [(input_rescale, input_fq), (bias_rescale, bias_fq)]:
|
||||
if rescale:
|
||||
logger.debug('Weights and scales for node {} '
|
||||
'updated with scale coefficient: {}'.format(fq.fullname, rescale))
|
||||
|
||||
return model
|
||||
|
||||
def register_statistics(self, model, stats_collector):
|
||||
self._stats_collector = stats_collector
|
||||
conv_nodes = mu.get_nodes_by_type(model, [n['type'] for n in OPERATIONS_WITH_WEIGHTS])
|
||||
stats_layout = {}
|
||||
|
||||
for conv_node in conv_nodes:
|
||||
bias_node = nu.get_bias_for_node(conv_node)
|
||||
output_node = conv_node if bias_node is None else nu.get_node_output(bias_node, 0)[0]
|
||||
|
@ -69,40 +69,3 @@ def correct_node_overflow(weighted_node, node_statistics):
|
||||
nu.set_node_value(w_output_high, w_output_high_value)
|
||||
|
||||
return rescale_value
|
||||
|
||||
|
||||
def correct_elt_overflow(elt_node):
    """
    Limit the ratio between the bias scale and the input scale of an eltwise
    node to the int16 maximum by rewriting the ranges of the nodes on its
    ports 0 (input) and 1 (bias).

    Two corrections are applied in order:
    1. If the ratio exceeds the int16 maximum and the input scale is below
       1 / (int16_max * 64), the input scale is raised to that minimum.
    2. If the ratio (recomputed with the possibly updated input scale) still
       exceeds the int16 maximum, the bias scale is lowered to
       int16_max * input_scale.

    :param elt_node: eltwise node whose inputs on ports 0 and 1 are inspected
    (presumably FakeQuantize nodes, since ``levels`` is read from them)
    :return: tuple (input_rescale, input_fq, bias_rescale, bias_fq); each
    rescale value is None when the corresponding correction was not applied
    """
    def _write_symmetric_range(fq_node, scale):
        # Ports 3 and 4 hold the range bounds that get overwritten
        # (the low bound first, then the high bound).
        low_node = nu.get_node_input(fq_node, 3)
        high_node = nu.get_node_input(fq_node, 4)
        nu.set_node_value(low_node, -scale * (fq_node.levels - 1) / 2)
        nu.set_node_value(high_node, scale * (fq_node.levels - 1) / 2)

    input_fq = nu.get_node_input(elt_node, 0)
    bias_fq = nu.get_node_input(elt_node, 1)

    input_scale = compute_scale(input_fq)
    bias_scale = compute_scale(bias_fq)

    int16_max = np.iinfo(np.int16).max
    smallest_scale = 1. / (int16_max * 64)

    input_rescale = None
    bias_rescale = None

    # Step 1: the input scale is too small.
    if bias_scale / input_scale > int16_max and input_scale < smallest_scale:
        input_rescale = smallest_scale
        _write_symmetric_range(input_fq, input_rescale)
        # The second check below must see the corrected input scale.
        input_scale = input_rescale

    # Step 2: the bias scale is too big.
    if bias_scale / input_scale > int16_max:
        bias_rescale = int16_max * input_scale
        _write_symmetric_range(bias_fq, bias_rescale)

    return input_rescale, input_fq, bias_rescale, bias_fq
|
||||
|
@ -217,20 +217,6 @@ def node_with_quantized_weights(node):
|
||||
return False
|
||||
|
||||
|
||||
def node_with_quantized_input_and_bias(node):
    """
    Check that the node has two quantized inputs (on ports 0 and 1).
    :param node: operation node
    :return: True if both inputs exist and are of type 'FakeQuantize',
    False otherwise
    """
    # Both inputs are fetched up front, before any of them is examined.
    producers = [get_node_input(node, port) for port in (0, 1)]
    return all(
        producer is not None and producer.type == 'FakeQuantize'
        for producer in producers
    )
|
||||
|
||||
|
||||
def get_input_data_value(node: Node, port: int):
|
||||
"""
|
||||
Return value of data node for needed node at port
|
||||
|
@ -71,11 +71,6 @@ TYPES_TO_QUANTIZABLE_PORTS = {'LSTMSequence': [0, 1, 4, 5], 'GRUSequence': [0, 1
|
||||
# Operation types that is_eltwise() recognises as element-wise.
ELTWISE_TYPES = [
    'Add', 'Multiply', 'Subtract', 'Divide',
    'Less', 'LessEqual', 'Greater', 'GreaterEqual',
    'Equal', 'NotEqual', 'FloorMod',
    'LogicalOr', 'LogicalXor', 'LogicalAnd',
    'Maximum', 'Minimum',
]

# Type descriptors ({'type': ...} dicts) for the Add and Subtract operations.
ELTWISE_ADD_SUB = [{'type': 'Add'}, {'type': 'Subtract'}]
|
||||
|
||||
|
||||
def is_eltwise(node):
    """
    Tell whether a node is an element-wise operation.
    :param node: node exposing a ``type`` attribute
    :return: True if the node's type is listed in ELTWISE_TYPES, False otherwise
    """
    node_type = node.type
    return node_type in ELTWISE_TYPES
|
||||
|
Loading…
Reference in New Issue
Block a user