This reverts commit 1901087677.
This commit is contained in:
parent
803a927e70
commit
3e89e7fc86
@ -4,12 +4,12 @@
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .utils import correct_node_overflow, correct_elt_overflow
|
from .utils import correct_node_overflow
|
||||||
from ...algorithm_selector import COMPRESSION_ALGORITHMS
|
from ...algorithm_selector import COMPRESSION_ALGORITHMS
|
||||||
from ...quantization import fake_quantize as fqut
|
from ...quantization import fake_quantize as fqut
|
||||||
from ....algorithms.algorithm import Algorithm
|
from ....algorithms.algorithm import Algorithm
|
||||||
from ....graph import model_utils as mu, node_utils as nu
|
from ....graph import model_utils as mu, node_utils as nu
|
||||||
from ....graph.special_operations import OPERATIONS_WITH_WEIGHTS, ELTWISE_ADD_SUB
|
from ....graph.special_operations import OPERATIONS_WITH_WEIGHTS
|
||||||
from ....samplers.creator import create_sampler
|
from ....samplers.creator import create_sampler
|
||||||
from ....statistics.functions import activations as acf
|
from ....statistics.functions import activations as acf
|
||||||
from ....utils.logger import get_logger
|
from ....utils.logger import get_logger
|
||||||
@ -43,16 +43,13 @@ class OverflowCorrection(Algorithm):
|
|||||||
"""
|
"""
|
||||||
activation_statistics = self._stats_collector.get_statistics_for_algorithm(self.name)
|
activation_statistics = self._stats_collector.get_statistics_for_algorithm(self.name)
|
||||||
|
|
||||||
def get_node_orig_full_name(node):
|
|
||||||
return node['orig_node_name'] if 'orig_node_name' in node else node.fullname
|
|
||||||
|
|
||||||
# overflow correction for input_scale * weight_scale
|
|
||||||
weighted_nodes = mu.get_nodes_by_type(model, [n['type'] for n in OPERATIONS_WITH_WEIGHTS])
|
weighted_nodes = mu.get_nodes_by_type(model, [n['type'] for n in OPERATIONS_WITH_WEIGHTS])
|
||||||
weighted_nodes = [n for n in weighted_nodes if nu.node_with_quantized_weights(n)]
|
weighted_nodes = [n for n in weighted_nodes if nu.node_with_quantized_weights(n)]
|
||||||
for weighted_node in weighted_nodes:
|
for weighted_node in weighted_nodes:
|
||||||
bias_node = nu.get_bias_for_node(weighted_node)
|
bias_node = nu.get_bias_for_node(weighted_node)
|
||||||
output_node = weighted_node if bias_node is None else nu.get_node_output(bias_node, 0)[0]
|
output_node = weighted_node if bias_node is None else nu.get_node_output(bias_node, 0)[0]
|
||||||
output_node_name = get_node_orig_full_name(output_node)
|
output_node_name = output_node['orig_node_name'] if 'orig_node_name' in output_node \
|
||||||
|
else output_node.fullname
|
||||||
if output_node_name not in activation_statistics \
|
if output_node_name not in activation_statistics \
|
||||||
or 'max_per_tensor' not in activation_statistics[output_node_name]:
|
or 'max_per_tensor' not in activation_statistics[output_node_name]:
|
||||||
logger.debug('Skipping {}'.format(weighted_node.fullname))
|
logger.debug('Skipping {}'.format(weighted_node.fullname))
|
||||||
@ -67,27 +64,12 @@ class OverflowCorrection(Algorithm):
|
|||||||
if rescale_value:
|
if rescale_value:
|
||||||
logger.debug('Weights and scales for node {} '
|
logger.debug('Weights and scales for node {} '
|
||||||
'updated with scale coefficient: {}'.format(weighted_node.fullname, rescale_value))
|
'updated with scale coefficient: {}'.format(weighted_node.fullname, rescale_value))
|
||||||
|
|
||||||
# overflow correction for bias_scale / input_scale
|
|
||||||
elt_nodes = mu.get_nodes_by_type(model, [n['type'] for n in ELTWISE_ADD_SUB])
|
|
||||||
elt_nodes = [n for n in elt_nodes if nu.node_with_quantized_input_and_bias(n)]
|
|
||||||
for elt_node in elt_nodes:
|
|
||||||
elt_node_name = get_node_orig_full_name(elt_node)
|
|
||||||
logger.debug('Processing {}'.format(elt_node.fullname))
|
|
||||||
|
|
||||||
input_rescale, input_fq, bias_rescale, bias_fq = correct_elt_overflow(elt_node)
|
|
||||||
for (rescale, fq) in [(input_rescale, input_fq), (bias_rescale, bias_fq)]:
|
|
||||||
if rescale:
|
|
||||||
logger.debug('Weights and scales for node {} '
|
|
||||||
'updated with scale coefficient: {}'.format(fq.fullname, rescale))
|
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def register_statistics(self, model, stats_collector):
|
def register_statistics(self, model, stats_collector):
|
||||||
self._stats_collector = stats_collector
|
self._stats_collector = stats_collector
|
||||||
conv_nodes = mu.get_nodes_by_type(model, [n['type'] for n in OPERATIONS_WITH_WEIGHTS])
|
conv_nodes = mu.get_nodes_by_type(model, [n['type'] for n in OPERATIONS_WITH_WEIGHTS])
|
||||||
stats_layout = {}
|
stats_layout = {}
|
||||||
|
|
||||||
for conv_node in conv_nodes:
|
for conv_node in conv_nodes:
|
||||||
bias_node = nu.get_bias_for_node(conv_node)
|
bias_node = nu.get_bias_for_node(conv_node)
|
||||||
output_node = conv_node if bias_node is None else nu.get_node_output(bias_node, 0)[0]
|
output_node = conv_node if bias_node is None else nu.get_node_output(bias_node, 0)[0]
|
||||||
|
@ -69,40 +69,3 @@ def correct_node_overflow(weighted_node, node_statistics):
|
|||||||
nu.set_node_value(w_output_high, w_output_high_value)
|
nu.set_node_value(w_output_high, w_output_high_value)
|
||||||
|
|
||||||
return rescale_value
|
return rescale_value
|
||||||
|
|
||||||
|
|
||||||
def correct_elt_overflow(elt_node):
|
|
||||||
input_fq = nu.get_node_input(elt_node, 0)
|
|
||||||
bias_fq = nu.get_node_input(elt_node, 1)
|
|
||||||
|
|
||||||
input_scale = compute_scale(input_fq)
|
|
||||||
bias_scale = compute_scale(bias_fq)
|
|
||||||
|
|
||||||
int16_type_max = np.iinfo(np.int16).max
|
|
||||||
min_scale_factor = 1. / (int16_type_max * 64)
|
|
||||||
|
|
||||||
input_rescale = None
|
|
||||||
# input_scale is too small
|
|
||||||
if bias_scale / input_scale > int16_type_max and input_scale < min_scale_factor:
|
|
||||||
input_rescale = min_scale_factor
|
|
||||||
|
|
||||||
i_output_low = nu.get_node_input(input_fq, 3)
|
|
||||||
i_output_high = nu.get_node_input(input_fq, 4)
|
|
||||||
|
|
||||||
nu.set_node_value(i_output_low, -input_rescale * (input_fq.levels - 1) / 2)
|
|
||||||
nu.set_node_value(i_output_high, input_rescale * (input_fq.levels - 1) / 2)
|
|
||||||
|
|
||||||
input_scale = input_rescale
|
|
||||||
|
|
||||||
bias_rescale = None
|
|
||||||
# bias_scale is too big
|
|
||||||
if bias_scale / input_scale > int16_type_max:
|
|
||||||
bias_rescale = int16_type_max * input_scale
|
|
||||||
|
|
||||||
b_output_low = nu.get_node_input(bias_fq, 3)
|
|
||||||
b_output_high = nu.get_node_input(bias_fq, 4)
|
|
||||||
|
|
||||||
nu.set_node_value(b_output_low, -bias_rescale * (bias_fq.levels - 1) / 2)
|
|
||||||
nu.set_node_value(b_output_high, bias_rescale * (bias_fq.levels - 1) / 2)
|
|
||||||
|
|
||||||
return input_rescale, input_fq, bias_rescale, bias_fq
|
|
||||||
|
@ -217,20 +217,6 @@ def node_with_quantized_weights(node):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def node_with_quantized_input_and_bias(node):
|
|
||||||
"""
|
|
||||||
Check that node has two quantized inputs (inputs on ports 0 and 1).
|
|
||||||
:param node: operation node
|
|
||||||
:return: True if node has quantized inputs and False otherwise
|
|
||||||
"""
|
|
||||||
input_node = get_node_input(node, 0)
|
|
||||||
bias_node = get_node_input(node, 1)
|
|
||||||
if input_node is not None and input_node.type == 'FakeQuantize' and bias_node is not None and bias_node.type == 'FakeQuantize':
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def get_input_data_value(node: Node, port: int):
|
def get_input_data_value(node: Node, port: int):
|
||||||
"""
|
"""
|
||||||
Return value of data node for needed node at port
|
Return value of data node for needed node at port
|
||||||
|
@ -71,11 +71,6 @@ TYPES_TO_QUANTIZABLE_PORTS = {'LSTMSequence': [0, 1, 4, 5], 'GRUSequence': [0, 1
|
|||||||
ELTWISE_TYPES = ['Add', 'Multiply', 'Subtract', 'Divide', 'Less', 'LessEqual', 'Greater', 'GreaterEqual',
|
ELTWISE_TYPES = ['Add', 'Multiply', 'Subtract', 'Divide', 'Less', 'LessEqual', 'Greater', 'GreaterEqual',
|
||||||
'Equal', 'NotEqual', 'FloorMod', 'LogicalOr', 'LogicalXor', 'LogicalAnd', 'Maximum', 'Minimum']
|
'Equal', 'NotEqual', 'FloorMod', 'LogicalOr', 'LogicalXor', 'LogicalAnd', 'Maximum', 'Minimum']
|
||||||
|
|
||||||
ELTWISE_ADD_SUB = [
|
|
||||||
{'type': 'Add'},
|
|
||||||
{'type': 'Subtract'}
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def is_eltwise(node):
|
def is_eltwise(node):
|
||||||
return node.type in ELTWISE_TYPES
|
return node.type in ELTWISE_TYPES
|
||||||
|
Loading…
Reference in New Issue
Block a user