[ MO TF ] TF FakeQuantize normalization fix (#3180)
This commit is contained in:
committed by
GitHub
parent
ea0f6fce5a
commit
7b85eef15f
@@ -16,15 +16,25 @@
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from extensions.ops.elementwise import Sub, Div, Less, Greater, Round, Mul
|
||||
from extensions.ops.elementwise import Sub, Div, Less, Round, Mul, Add, Greater
|
||||
from extensions.ops.fakequantize import FakeQuantize
|
||||
from extensions.ops.select import Select
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.common.replacement import FrontReplacementOp
|
||||
from mo.front.tf.graph_utils import create_op_node_with_second_input
|
||||
from mo.graph.graph import Graph, Node
|
||||
from mo.ops.const import Const
|
||||
|
||||
|
||||
class FakeQuantWithMinMaxVarsToQuantize(FrontReplacementOp):
|
||||
"""
|
||||
Performs FakeQuantize limits adjustment for min <= max according to the following rules:
|
||||
If 0 < min < max: min_adj = 0 and max_adj = max - min.
|
||||
If min < max < 0: min_adj = min - max and max_adj = 0.
|
||||
If min <= 0 <= max:
|
||||
scale = (max - min) / (2^num_bits - 1 - narrow_range),
|
||||
min_adj = scale * round(min / scale) and max_adj = max + min_adj - min.
|
||||
"""
|
||||
op = "FakeQuantWithMinMaxVars"
|
||||
enabled = True
|
||||
|
||||
@@ -32,82 +42,52 @@ class FakeQuantWithMinMaxVarsToQuantize(FrontReplacementOp):
|
||||
node = match['op']
|
||||
name = node.name
|
||||
|
||||
# Zero Point Nudging : Scale counting
|
||||
f_min = node.in_port(1).get_source()
|
||||
min_port_tuple = (node.in_port(1).get_source().node, node.in_port(1).get_source().idx)
|
||||
max_port_tuple = (node.in_port(2).get_source().node, node.in_port(2).get_source().idx)
|
||||
|
||||
node.in_port(1).disconnect()
|
||||
f_max = node.in_port(2).get_source()
|
||||
node.in_port(2).disconnect()
|
||||
|
||||
f_diff = Sub(graph, {'name': name + '/float_range'}).create_node()
|
||||
f_max.connect(f_diff.in_port(0))
|
||||
f_min.connect(f_diff.in_port(1))
|
||||
# make sure min < max
|
||||
min_less_max = Less(graph, {'name': name + '/if_min_less_max'}).create_node([min_port_tuple, max_port_tuple])
|
||||
minimum = Select(graph, {'name': name + '/minimum'}).create_node([min_less_max, min_port_tuple, max_port_tuple])
|
||||
maximum = Select(graph, {'name': name + '/maximum'}).create_node([min_less_max, max_port_tuple, min_port_tuple])
|
||||
|
||||
quant_min_value = int(node.narrow_range)
|
||||
quant_max_value = 2 ** node.num_bits - 1
|
||||
i_diff = Const(graph, dict(name=name + '/int_range', value=quant_max_value - quant_min_value)).create_node()
|
||||
# to create a zero of the same data type as the limits, we multiply the minimum by integer zero
|
||||
zero = create_op_node_with_second_input(graph, Mul, int64_array(0), {'name': name + '/zero'}, input_node=minimum)
|
||||
|
||||
scale = Div(graph, {'name': name + '/scale'}).create_node()
|
||||
f_diff.out_port(0).connect(scale.in_port(0))
|
||||
i_diff.out_port(0).connect(scale.in_port(1))
|
||||
# if 0 < min < max: min_adj = 0 and max_adj = max - min
|
||||
min_greater_zero = Greater(graph, {'name': name + '/if_minimum_greater_zero'}).create_node([minimum, zero])
|
||||
max_minus_min = Sub(graph, {'name': name + '/max_minus_min'}).create_node([maximum, minimum])
|
||||
minimum = Select(graph, {'name': name + '/first_adj_min'}).create_node([min_greater_zero, zero, minimum])
|
||||
maximum = Select(graph, {'name': name + '/first_adj_max'}).create_node([min_greater_zero, max_minus_min, maximum])
|
||||
|
||||
# Zero Point Nudging : ZP from min counting
|
||||
descaled_min = Div(graph, {'name': name + '/descaled_min'}).create_node()
|
||||
f_min.connect(descaled_min.in_port(0))
|
||||
scale.out_port(0).connect(descaled_min.in_port(1))
|
||||
# if min < max < 0: min_adj = min - max and max_adj = 0
|
||||
max_less_zero = Less(graph, {'name': name + '/if_max_less_zero'}).create_node([maximum, zero])
|
||||
min_minus_max = Sub(graph, {'name': name + '/min_minus_max'}).create_node([minimum, maximum])
|
||||
minimum = Select(graph, {'name': name + '/second_adj_min'}).create_node([max_less_zero, min_minus_max, minimum])
|
||||
maximum = Select(graph, {'name': name + '/second_adj_max'}).create_node([max_less_zero, zero, maximum])
|
||||
|
||||
zero_point_from_min = Sub(graph, {'name': name + '/zero_point_from_min'}).create_node()
|
||||
quant_min = Const(graph, {'value': quant_min_value, 'name': name + '/quant_min'}).create_node()
|
||||
quant_min.out_port(0).connect(zero_point_from_min.in_port(0))
|
||||
descaled_min.out_port(0).connect(zero_point_from_min.in_port(1))
|
||||
|
||||
# Zero Point Nudging : Nudged Zero Point counting
|
||||
zp_lesser_q_mi = Less(graph, {'name': name + '/zero_point_from_min_less_quant_min'}).create_node()
|
||||
zero_point_from_min.out_port(0).connect(zp_lesser_q_mi.in_port(0))
|
||||
quant_min.out_port(0).connect(zp_lesser_q_mi.in_port(1))
|
||||
|
||||
zp_greater_q_ma = Greater(graph, {'name': name + '/zero_point_from_min_greater_quant_max'}).create_node()
|
||||
zero_point_from_min.out_port(0).connect(zp_greater_q_ma.in_port(0))
|
||||
quant_max = Const(graph, {'value': quant_max_value, 'name': name + '/quant_max'}).create_node()
|
||||
quant_max.out_port(0).connect(zp_greater_q_ma.in_port(1))
|
||||
|
||||
rounded_zero_point_from_min = Round(graph, {'name': name + '/zero_point_from_min_rounding'}).create_node()
|
||||
zero_point_from_min.out_port(0).connect(rounded_zero_point_from_min.in_port(0))
|
||||
|
||||
nudged_zero_point = Select(graph, {'name': name + '/nudging_zp_1_select_less_condition'}).create_node()
|
||||
greater_condition = Select(graph, {'name': name + '/nudging_zp_2_select_greater_condition'}).create_node()
|
||||
|
||||
greater_condition.in_port(0).connect(zp_greater_q_ma.out_port(0))
|
||||
greater_condition.in_port(1).connect(quant_max.out_port(0))
|
||||
greater_condition.in_port(2).connect(rounded_zero_point_from_min.out_port(0))
|
||||
|
||||
nudged_zero_point.in_port(0).connect(zp_lesser_q_mi.out_port(0))
|
||||
nudged_zero_point.in_port(1).connect(quant_max.out_port(0))
|
||||
nudged_zero_point.in_port(2).connect(greater_condition.out_port(0))
|
||||
|
||||
nudged_i_min = Sub(graph, {'name': name + '/nudged_i_min'}).create_node()
|
||||
quant_min.out_port(0).connect(nudged_i_min.in_port(0))
|
||||
nudged_zero_point.out_port(0).connect(nudged_i_min.in_port(1))
|
||||
|
||||
nudged_i_max = Sub(graph, {'name': name + '/nudged_i_max'}).create_node()
|
||||
quant_max.out_port(0).connect(nudged_i_max.in_port(0))
|
||||
nudged_zero_point.out_port(0).connect(nudged_i_max.in_port(1))
|
||||
|
||||
nudged_min = Mul(graph, {'name': name + '/nudged_min'}).create_node()
|
||||
nudged_i_min.out_port(0).connect(nudged_min.in_port(0))
|
||||
scale.out_port(0).connect(nudged_min.in_port(1))
|
||||
|
||||
nudged_max = Mul(graph, {'name': name + '/nudged_max'}).create_node()
|
||||
nudged_i_max.out_port(0).connect(nudged_max.in_port(0))
|
||||
scale.out_port(0).connect(nudged_max.in_port(1))
|
||||
|
||||
nudged_min.out_port(0).connect(node.in_port(1))
|
||||
nudged_max.out_port(0).connect(node.in_port(2))
|
||||
# scale = (max - min) / (2 ^ num_bits - 1 - narrow_range)
|
||||
float_range = Sub(graph, {'name': name + '/float_range'}).create_node([maximum, minimum])
|
||||
quant_min_value, quant_max_value = int(node.narrow_range), 2 ** node.num_bits - 1
|
||||
int_range = Const(graph, dict(name=name + '/int_range', value=quant_max_value - quant_min_value)).create_node()
|
||||
scale = Div(graph, {'name': name + '/scale'}).create_node([float_range, int_range])
|
||||
# min_adj = scale * round(min / scale)
|
||||
descaled_min = Div(graph, {'name': name + '/descaled_min'}).create_node([minimum, scale])
|
||||
rounded_descaled_min = Round(graph, {'name': name + '/rounded_descaled_min'}).create_node([descaled_min])
|
||||
min_adj = Mul(graph, {'name': name + '/min_adj'}).create_node([scale, rounded_descaled_min])
|
||||
# max_adj = max + min_adj - min.
|
||||
adjustment = Sub(graph, {'name': name + '/limits_adjustment'}).create_node([min_adj, minimum])
|
||||
max_adj = Add(graph, {'name': name + '/max_adj'}).create_node([maximum, adjustment])
|
||||
|
||||
# FakeQuantize operation has 5 inputs, whereas the TensorFlow operation has only 3
|
||||
node.add_input_port(3, skip_if_exist=True)
|
||||
node.add_input_port(4, skip_if_exist=True)
|
||||
|
||||
node.in_port(3).connect(nudged_min.out_port(0))
|
||||
node.in_port(4).connect(nudged_max.out_port(0))
|
||||
node.in_port(1).connect(min_adj.out_port(0))
|
||||
node.in_port(2).connect(max_adj.out_port(0))
|
||||
node.in_port(3).connect(min_adj.out_port(0))
|
||||
node.in_port(4).connect(max_adj.out_port(0))
|
||||
|
||||
FakeQuantize.update_node_stat(node, {'levels': node['levels']})
|
||||
|
||||
Reference in New Issue
Block a user