Zero point optimization (#6628)
* Zero point optimization * Expand the equality to zero criteria
This commit is contained in:
committed by
GitHub
parent
205c23b382
commit
9acc3dfe68
@@ -6,10 +6,12 @@ from typing import Dict
|
||||
import numpy as np
|
||||
|
||||
from extensions.ops.Cast import Cast
|
||||
from extensions.ops.elementwise import Sub, Div, Mul, Negative
|
||||
from extensions.ops.elementwise import Sub, Div, Mul, Negative, Equal
|
||||
from extensions.ops.select import Select
|
||||
from mo.back.replacement import BackReplacementPattern
|
||||
from mo.graph.graph import Graph, Node
|
||||
from mo.middle.passes.convert_data_type import data_type_str_to_np, np_data_type_to_destination_type, packed_I4
|
||||
from mo.middle.pattern_match import apply_pattern
|
||||
from mo.ops.const import Const
|
||||
|
||||
|
||||
@@ -70,15 +72,7 @@ class CompressQuantizeWeights(BackReplacementPattern):
|
||||
scale = (output_high - output_low) / (input_high - input_low)
|
||||
WARNING: division by zero imposes restriction -- input_high can not be equal to input_low
|
||||
zero_point = input_low - output_low / scale
|
||||
|
||||
TODO: steps 5 and 6 are NOT IMPLEMENTED YET
|
||||
TODO: DOES LPT NEED IT???
|
||||
Step 5: Having zero_point == 0 is really beneficial for performance, so we try to fuse Subtract up to the Constant.
|
||||
It is not always possible because of the quantized_dtype possible range of values.
|
||||
|
||||
Step 6: (Optional) From the nature of Subtract and Multiply operations they may be optimized out in cases:
|
||||
zero_point == 0
|
||||
scale == 1
|
||||
NOTE: if scale == 0 then zero_point is equal to zero too (achieved through Select operation)
|
||||
|
||||
BENEFITS:
|
||||
Such constant data packing reduces IR size (.bin file size)
|
||||
@@ -186,14 +180,24 @@ class CompressQuantizeWeights(BackReplacementPattern):
|
||||
descaled_output_low.in_port(0).connect(out_low)
|
||||
descaled_output_low.in_port(1).connect(scale.out_port(0))
|
||||
|
||||
shift = Sub(graph, {'name': name + '/zero_point'}).create_node()
|
||||
shift = Sub(graph, {'name': name + '/shift'}).create_node()
|
||||
shift.in_port(0).connect(in_low)
|
||||
shift.in_port(1).connect(descaled_output_low.out_port(0))
|
||||
|
||||
zero = Const(graph, {'name': name + '/zero', 'value': np.array(0, dtype=dst_type)}).create_node()
|
||||
scale_eq_zero = Equal(graph, {'name': name + '/scale_eq_zero'}).create_node()
|
||||
scale_eq_zero.in_port(0).connect(scale.out_port(0))
|
||||
scale_eq_zero.in_port(1).connect(zero.out_port(0))
|
||||
|
||||
if_scale_is_zero = Select(graph, {'name': name + '/zero_point'}).create_node()
|
||||
if_scale_is_zero.in_port(0).connect(scale_eq_zero.out_port(0))
|
||||
if_scale_is_zero.in_port(1).connect(zero.out_port(0))
|
||||
if_scale_is_zero.in_port(2).connect(shift.out_port(0))
|
||||
|
||||
# DeQuantize(x) == Mul(Sub(x, zero_point), scale)
|
||||
sub_zp = Sub(graph, {'name': name + '/minus_zp'}).create_node()
|
||||
sub_zp.in_port(0).connect(dequantizing_cast.out_port(0))
|
||||
sub_zp.in_port(1).connect(shift.out_port(0))
|
||||
sub_zp.in_port(1).connect(if_scale_is_zero.out_port(0))
|
||||
|
||||
mul_scale = Mul(graph, {'name': name + '/mulpiply_by_scale'}).create_node()
|
||||
mul_scale.in_port(0).connect(sub_zp.out_port(0))
|
||||
@@ -218,3 +222,64 @@ class CompressQuantizeWeights(BackReplacementPattern):
|
||||
|
||||
self.quantize_data(fake_quantize, dst_type, quantized_type, mode)
|
||||
self.dequantize_data(fake_quantize, dst_type, quantized_type)
|
||||
|
||||
|
||||
class ZeroPointOptimizer(BackReplacementPattern):
    r"""
    Tries to eliminate the zero-point Subtract that follows a quantized-weight Convert.

    Step 1: A zero_point of 0 is beneficial for performance, so the transformation attempts to
    fuse the Subtract into the int8 weight Constant. This is not always possible because the
    adjusted weights must still fit into the quantized dtype value range.

    Step 2: Once zero_point == 0, the Subtract is redundant and is bypassed entirely.
    """
    enabled = True
    force_clean_up = True

    def run_after(self):
        # Must see the Const -> Convert -> Subtract chains produced by weight compression.
        return [CompressQuantizeWeights]

    def pattern(self):
        # Matches: Const(weights) -> Convert -> Subtract <- Const(zero_point)
        return dict(
            nodes=[
                ('const', dict(type='Const')),
                ('const_d', dict()),
                ('convert', dict(type='Convert')),
                ('convert_d', dict()),
                ('const_zp', dict(type='Const')),
                ('const_zp_d', dict()),
                ('sub', dict(type='Subtract')),
            ],
            edges=[
                ('const', 'const_d'),
                ('const_d', 'convert'),
                ('convert', 'convert_d'),
                ('convert_d', 'sub', {'in': 0}),
                ('const_zp', 'const_zp_d'),
                ('const_zp_d', 'sub', {'in': 1}),
            ]
        )

    def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
        zp_value = match['const_zp'].out_port(0).data.get_value()
        assert zp_value is not None
        cast_node = match['convert']
        sub_node = match['sub']

        # Trivial case: zero_point is already (numerically) zero -- drop the Subtract.
        if np.allclose(zp_value, 0):
            sub_node.out_port(0).get_connection().set_source(cast_node.out_port(0))
            return

        quantized_weights = match['const'].out_port(0).data.get_value()
        # Fusion is only attempted for int8-packed weights with a known value.
        if quantized_weights is None or quantized_weights.dtype != np.int8:
            return
        float_type = cast_node.dst_type

        # Round the zero point into int8 and keep the floating-point residual.
        rounded_zp = np.round(zp_value).astype(np.int8)
        residual_zp = (zp_value - rounded_zp).astype(float_type)

        # Reference result of the original dequantization vs. the fused variant.
        ref_values = quantized_weights.astype(float_type) - zp_value
        fused_values = (quantized_weights - rounded_zp).astype(np.int8) - residual_zp

        # Bail out if fusing changes the result (e.g. int8 range overflow) or the
        # residual zero point is not negligible.
        if not np.allclose(ref_values, fused_values) or not np.allclose(residual_zp, 0, atol=1.e-04):
            return

        # Bake the zero point into the weights and bypass the Subtract.
        match['const_d']['value'] = (quantized_weights - rounded_zp).astype(np.int8)
        sub_node.out_port(0).get_connection().set_source(cast_node.out_port(0))
|
||||
|
||||
@@ -7,7 +7,7 @@ from argparse import Namespace
|
||||
import numpy as np
|
||||
from generator import generator, generate
|
||||
|
||||
from extensions.back.compress_quantized_weights import CompressQuantizeWeights
|
||||
from extensions.back.compress_quantized_weights import CompressQuantizeWeights, ZeroPointOptimizer
|
||||
from extensions.ops.Cast import Cast
|
||||
from extensions.ops.elementwise import Sub, Mul
|
||||
from extensions.ops.fakequantize import FakeQuantize
|
||||
@@ -37,7 +37,7 @@ def nodes_dict(original, transformed=None, levels=255, data=None, il=[-127], ih=
|
||||
|
||||
**regular_op_with_shaped_data(
|
||||
'FQ', shape, {'type': 'FakeQuantize', 'infer': FakeQuantize.infer, 'stop_value_propagation': True,
|
||||
'levels': levels, 'op': 'FakeQuantize'}),
|
||||
'levels': levels, 'op': 'FakeQuantize'}),
|
||||
|
||||
**valued_const_with_data('zp', np.array([0])),
|
||||
**valued_const_with_data('scale', np.array([1])),
|
||||
@@ -49,7 +49,7 @@ def nodes_dict(original, transformed=None, levels=255, data=None, il=[-127], ih=
|
||||
'mul', shape, {'type': 'Multiply', 'op': 'Mul', 'infer': lambda node: eltwise_infer(node, Mul.operation)}),
|
||||
|
||||
**result()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class CompressionQuantizeDequantizeSeparateTest(unittest.TestCase):
|
||||
@@ -248,3 +248,70 @@ class NegativeCompressionTestLevels(unittest.TestCase):
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
@generator
class ZeroPointOptimizerTestClass(unittest.TestCase):
    @staticmethod
    def _quantized_sub_nodes(w, zp):
        # int8 weights -> Convert(float32) -> Subtract(zero_point) -> result
        return {
            **valued_const_with_data('weights', np.array(w, dtype=np.int8)),
            **regular_op_with_shaped_data(
                'cast', len(w), {'type': 'Convert', 'op': 'Cast', 'infer': Cast.infer, 'dst_type': np.float32}),
            **valued_const_with_data('zp', np.array(zp, dtype=np.float32)),
            **regular_op_with_shaped_data(
                'sub', len(w),
                {'type': 'Subtract', 'op': 'Sub', 'infer': lambda node: eltwise_infer(node, Sub.operation)}),
            **result()
        }

    @generate(*[
        ([-10, 7], [-1], [-9, 8], [0]),
        ([-10, 7], [-0.99999999], [-9, 8], [0]),
    ])
    def test_zero_point_optimization(self, weights, zero_point, adj_weights, adj_zero_point):
        # Fusable cases: the Subtract must be folded into the weights and removed.
        full_edges = [
            *connect("weights:0", "0:cast"),
            *connect("cast:0", "0:sub"),
            *connect("zp:0", "1:sub"),
            *connect("sub:0", "0:output"),
        ]
        graph = build_graph(self._quantized_sub_nodes(weights, zero_point), full_edges,
                            nodes_with_edges_only=True)
        ZeroPointOptimizer().find_and_replace_pattern(graph)
        graph.clean_up()

        # Reference: adjusted weights feed the Convert directly, no Subtract left.
        graph_ref = build_graph(self._quantized_sub_nodes(adj_weights, adj_zero_point), [
            *connect("weights:0", "0:cast"),
            *connect("cast:0", "0:output"),
        ], nodes_with_edges_only=True)
        graph_ref.clean_up()

        (flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
        self.assertTrue(flag, resp)

    @generate(*[
        ([-128, 7], [1], [-128, 7], [1]),
        ([127, 7], [-1], [127, 7], [-1]),
    ])
    def test_negative_zero_point_optimization(self, weights, zero_point, adj_weights, adj_zero_point):
        # Non-fusable cases (int8 range overflow): the graph must stay unchanged.
        full_edges = [
            *connect("weights:0", "0:cast"),
            *connect("cast:0", "0:sub"),
            *connect("zp:0", "1:sub"),
            *connect("sub:0", "0:output"),
        ]
        graph = build_graph(self._quantized_sub_nodes(weights, zero_point), full_edges,
                            nodes_with_edges_only=True)
        ZeroPointOptimizer().find_and_replace_pattern(graph)
        graph.clean_up()

        graph_ref = build_graph(self._quantized_sub_nodes(adj_weights, adj_zero_point), full_edges,
                                nodes_with_edges_only=True)
        graph_ref.clean_up()

        (flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
        self.assertTrue(flag, resp)
|
||||
|
||||
Reference in New Issue
Block a user