[ MO TF ] TF FakeQuantize normalization fix (#3180)

This commit is contained in:
Evgenya Stepyreva
2020-11-19 18:38:20 +03:00
committed by GitHub
parent ea0f6fce5a
commit 7b85eef15f

View File

@@ -16,15 +16,25 @@
from typing import Dict
from extensions.ops.elementwise import Sub, Div, Less, Greater, Round, Mul
from extensions.ops.elementwise import Sub, Div, Less, Round, Mul, Add, Greater
from extensions.ops.fakequantize import FakeQuantize
from extensions.ops.select import Select
from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.replacement import FrontReplacementOp
from mo.front.tf.graph_utils import create_op_node_with_second_input
from mo.graph.graph import Graph, Node
from mo.ops.const import Const
class FakeQuantWithMinMaxVarsToQuantize(FrontReplacementOp):
"""
Performs FakeQuantize limits adjustment for min <= max following rules:
If 0 < min < max: min_adj = 0 and max_adj = max - min.
If min < max < 0: min_adj = min - max and max_adj = 0.
If min <= 0 <= max:
scale = (max - min) / (2^num_bits - 1),
min_adj = scale * round(min / scale) and max_adj = max + min_adj - min.
"""
op = "FakeQuantWithMinMaxVars"
enabled = True
@@ -32,82 +42,52 @@ class FakeQuantWithMinMaxVarsToQuantize(FrontReplacementOp):
node = match['op']
name = node.name
# Zero Point Nudging : Scale counting
f_min = node.in_port(1).get_source()
min_port_tuple = (node.in_port(1).get_source().node, node.in_port(1).get_source().idx)
max_port_tuple = (node.in_port(2).get_source().node, node.in_port(2).get_source().idx)
node.in_port(1).disconnect()
f_max = node.in_port(2).get_source()
node.in_port(2).disconnect()
f_diff = Sub(graph, {'name': name + '/float_range'}).create_node()
f_max.connect(f_diff.in_port(0))
f_min.connect(f_diff.in_port(1))
# make sure min < max
min_less_max = Less(graph, {'name': name + '/if_min_less_max'}).create_node([min_port_tuple, max_port_tuple])
minimum = Select(graph, {'name': name + '/minimum'}).create_node([min_less_max, min_port_tuple, max_port_tuple])
maximum = Select(graph, {'name': name + '/maximum'}).create_node([min_less_max, max_port_tuple, min_port_tuple])
quant_min_value = int(node.narrow_range)
quant_max_value = 2 ** node.num_bits - 1
i_diff = Const(graph, dict(name=name + '/int_range', value=quant_max_value - quant_min_value)).create_node()
# to create zero of limits data type, we multiply it by integer zero
zero = create_op_node_with_second_input(graph, Mul, int64_array(0), {'name': name + '/zero'}, input_node=minimum)
scale = Div(graph, {'name': name + '/scale'}).create_node()
f_diff.out_port(0).connect(scale.in_port(0))
i_diff.out_port(0).connect(scale.in_port(1))
# if 0 < min < max: min_adj = 0 and max_adj = max - min
min_greater_zero = Greater(graph, {'name': name + '/if_minimum_greater_zero'}).create_node([minimum, zero])
max_minus_min = Sub(graph, {'name': name + '/max_minus_min'}).create_node([maximum, minimum])
minimum = Select(graph, {'name': name + '/first_adj_min'}).create_node([min_greater_zero, zero, minimum])
maximum = Select(graph, {'name': name + '/first_adj_max'}).create_node([min_greater_zero, max_minus_min, maximum])
# Zero Point Nudging : ZP from min counting
descaled_min = Div(graph, {'name': name + '/descaled_min'}).create_node()
f_min.connect(descaled_min.in_port(0))
scale.out_port(0).connect(descaled_min.in_port(1))
# if min < max < 0: min_adj = min - max and max_adj = 0
max_less_zero = Less(graph, {'name': name + '/if_max_less_zero'}).create_node([maximum, zero])
min_minus_max = Sub(graph, {'name': name + '/min_minus_max'}).create_node([minimum, maximum])
minimum = Select(graph, {'name': name + '/second_adj_min'}).create_node([max_less_zero, min_minus_max, minimum])
maximum = Select(graph, {'name': name + '/second_adj_max'}).create_node([max_less_zero, zero, maximum])
zero_point_from_min = Sub(graph, {'name': name + '/zero_point_from_min'}).create_node()
quant_min = Const(graph, {'value': quant_min_value, 'name': name + '/quant_min'}).create_node()
quant_min.out_port(0).connect(zero_point_from_min.in_port(0))
descaled_min.out_port(0).connect(zero_point_from_min.in_port(1))
# Zero Point Nudging : Nudged Zero Point counting
zp_lesser_q_mi = Less(graph, {'name': name + '/zero_point_from_min_less_quant_min'}).create_node()
zero_point_from_min.out_port(0).connect(zp_lesser_q_mi.in_port(0))
quant_min.out_port(0).connect(zp_lesser_q_mi.in_port(1))
zp_greater_q_ma = Greater(graph, {'name': name + '/zero_point_from_min_greater_quant_max'}).create_node()
zero_point_from_min.out_port(0).connect(zp_greater_q_ma.in_port(0))
quant_max = Const(graph, {'value': quant_max_value, 'name': name + '/quant_max'}).create_node()
quant_max.out_port(0).connect(zp_greater_q_ma.in_port(1))
rounded_zero_point_from_min = Round(graph, {'name': name + '/zero_point_from_min_rounding'}).create_node()
zero_point_from_min.out_port(0).connect(rounded_zero_point_from_min.in_port(0))
nudged_zero_point = Select(graph, {'name': name + '/nudging_zp_1_select_less_condition'}).create_node()
greater_condition = Select(graph, {'name': name + '/nudging_zp_2_select_greater_condition'}).create_node()
greater_condition.in_port(0).connect(zp_greater_q_ma.out_port(0))
greater_condition.in_port(1).connect(quant_max.out_port(0))
greater_condition.in_port(2).connect(rounded_zero_point_from_min.out_port(0))
nudged_zero_point.in_port(0).connect(zp_lesser_q_mi.out_port(0))
nudged_zero_point.in_port(1).connect(quant_max.out_port(0))
nudged_zero_point.in_port(2).connect(greater_condition.out_port(0))
nudged_i_min = Sub(graph, {'name': name + '/nudged_i_min'}).create_node()
quant_min.out_port(0).connect(nudged_i_min.in_port(0))
nudged_zero_point.out_port(0).connect(nudged_i_min.in_port(1))
nudged_i_max = Sub(graph, {'name': name + '/nudged_i_max'}).create_node()
quant_max.out_port(0).connect(nudged_i_max.in_port(0))
nudged_zero_point.out_port(0).connect(nudged_i_max.in_port(1))
nudged_min = Mul(graph, {'name': name + '/nudged_min'}).create_node()
nudged_i_min.out_port(0).connect(nudged_min.in_port(0))
scale.out_port(0).connect(nudged_min.in_port(1))
nudged_max = Mul(graph, {'name': name + '/nudged_max'}).create_node()
nudged_i_max.out_port(0).connect(nudged_max.in_port(0))
scale.out_port(0).connect(nudged_max.in_port(1))
nudged_min.out_port(0).connect(node.in_port(1))
nudged_max.out_port(0).connect(node.in_port(2))
# scale = (max - min) / (2 ^ num_bits - 1),
float_range = Sub(graph, {'name': name + '/float_range'}).create_node([maximum, minimum])
quant_min_value, quant_max_value = int(node.narrow_range), 2 ** node.num_bits - 1
int_range = Const(graph, dict(name=name + '/int_range', value=quant_max_value - quant_min_value)).create_node()
scale = Div(graph, {'name': name + '/scale'}).create_node([float_range, int_range])
# min_adj = scale * round(min / scale)
descaled_min = Div(graph, {'name': name + '/descaled_min'}).create_node([minimum, scale])
rounded_descaled_min = Round(graph, {'name': name + '/rounded_descaled_min'}).create_node([descaled_min])
min_adj = Mul(graph, {'name': name + '/min_adj'}).create_node([scale, rounded_descaled_min])
# max_adj = max + min_adj - min.
adjustment = Sub(graph, {'name': name + '/limits_adjustment'}).create_node([min_adj, minimum])
max_adj = Add(graph, {'name': name + '/max_adj'}).create_node([maximum, adjustment])
# FakeQuantize operation has 5 inputs instead of 3 inputs in TensorFlow
node.add_input_port(3, skip_if_exist=True)
node.add_input_port(4, skip_if_exist=True)
node.in_port(3).connect(nudged_min.out_port(0))
node.in_port(4).connect(nudged_max.out_port(0))
node.in_port(1).connect(min_adj.out_port(0))
node.in_port(2).connect(max_adj.out_port(0))
node.in_port(3).connect(min_adj.out_port(0))
node.in_port(4).connect(max_adj.out_port(0))
FakeQuantize.update_node_stat(node, {'levels': node['levels']})