[POT] Update OverflowCorrection algo for nodes without bias (#10687)

* Update OverflowCorrection algo for nodes without bias

* Pylint line fix

* Update OC with the last add name

* Pylint fix
This commit is contained in:
Nikita Malinin 2022-03-04 14:50:44 +03:00 committed by GitHub
parent 32edd596e3
commit 69ad9e80e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -44,12 +44,11 @@ class OverflowCorrection(Algorithm):
weighted_nodes = [n for n in weighted_nodes if nu.node_with_quantized_weights(n)]
for weighted_node in weighted_nodes:
bias_node = nu.get_bias_for_node(weighted_node)
if bias_node is None:
continue
add_node = nu.get_node_output(bias_node, 0)[0]
add_node_name = add_node.fullname
if add_node_name not in activation_statistics \
or 'max_per_tensor' not in activation_statistics[add_node_name]:
output_node = weighted_node if bias_node is None else nu.get_node_output(bias_node, 0)[0]
output_node_name = output_node['orig_node_name'] if 'orig_node_name' in output_node \
else output_node.fullname
if output_node_name not in activation_statistics \
or 'max_per_tensor' not in activation_statistics[output_node_name]:
logger.debug('Skipping {}'.format(weighted_node.fullname))
continue
logger.debug('Processing {}'.format(weighted_node.fullname))
@@ -57,7 +56,8 @@ class OverflowCorrection(Algorithm):
if weight_fq.levels <= np.iinfo(np.uint8).max:
logger.debug('Skipping {} due to INT8 weights quantization'.format(weighted_node.fullname))
continue
rescale_value = correct_node_overflow(weighted_node, activation_statistics[add_node_name]['max_per_tensor'])
rescale_value = correct_node_overflow(weighted_node,
activation_statistics[output_node_name]['max_per_tensor'])
if rescale_value:
logger.debug('Weights and scales for node {} '
'updated with scale coefficient: {}'.format(weighted_node.fullname, rescale_value))
@@ -69,10 +69,8 @@ class OverflowCorrection(Algorithm):
stats_layout = {}
for conv_node in conv_nodes:
bias_node = nu.get_bias_for_node(conv_node)
if bias_node is None:
continue
add_node = nu.get_node_output(bias_node, 0)[0]
stats_layout[add_node.fullname] = {'max_per_tensor': acf.abs_max_per_tensor}
output_node = conv_node if bias_node is None else nu.get_node_output(bias_node, 0)[0]
stats_layout[output_node.fullname] = {'max_per_tensor': acf.abs_max_per_tensor}
quantized_model = deepcopy(model)
fqut.insert_fake_quantize_nodes(self._config, quantized_model)
layers_mapping = fqut.create_renamed_layers_mapping(quantized_model, stats_layout)