Zero point optimization (#6628)

* Zero point optimization

* Expand the equality to zero criteria
This commit is contained in:
Evgenya Stepyreva
2021-07-25 13:33:29 +03:00
committed by GitHub
parent 205c23b382
commit 9acc3dfe68
2 changed files with 147 additions and 15 deletions

View File

@@ -6,10 +6,12 @@ from typing import Dict
import numpy as np
from extensions.ops.Cast import Cast
from extensions.ops.elementwise import Sub, Div, Mul, Negative
from extensions.ops.elementwise import Sub, Div, Mul, Negative, Equal
from extensions.ops.select import Select
from mo.back.replacement import BackReplacementPattern
from mo.graph.graph import Graph, Node
from mo.middle.passes.convert_data_type import data_type_str_to_np, np_data_type_to_destination_type, packed_I4
from mo.middle.pattern_match import apply_pattern
from mo.ops.const import Const
@@ -70,15 +72,7 @@ class CompressQuantizeWeights(BackReplacementPattern):
scale = (output_high - output_low) / (input_high - input_low)
WARNING: division by zero imposes restriction -- input_high can not be equal to input_low
zero_point = input_low - output_low / scale
TODO: steps 5 and 6 are NOT IMPLEMENTED YET
TODO: DOES LPT NEED IT???
Step 5: Having zero_point == 0 is really beneficial for performance, so we try to fuse Subtract up to the Constant.
It is not always possible because of the quantized_dtype possible range of values.
Step 6: (Optional) From the nature of Subtract and Multiply operations they may be optimized out in cases:
zero_point == 0
scale == 1
NOTE: if scale == 0 then zero_point is equal to zero too (achieved through Select operation)
BENEFITS:
Such constant data packing reduces IR size (.bin file size)
@@ -186,14 +180,24 @@ class CompressQuantizeWeights(BackReplacementPattern):
descaled_output_low.in_port(0).connect(out_low)
descaled_output_low.in_port(1).connect(scale.out_port(0))
shift = Sub(graph, {'name': name + '/zero_point'}).create_node()
shift = Sub(graph, {'name': name + '/shift'}).create_node()
shift.in_port(0).connect(in_low)
shift.in_port(1).connect(descaled_output_low.out_port(0))
zero = Const(graph, {'name': name + '/zero', 'value': np.array(0, dtype=dst_type)}).create_node()
scale_eq_zero = Equal(graph, {'name': name + '/scale_eq_zero'}).create_node()
scale_eq_zero.in_port(0).connect(scale.out_port(0))
scale_eq_zero.in_port(1).connect(zero.out_port(0))
if_scale_is_zero = Select(graph, {'name': name + '/zero_point'}).create_node()
if_scale_is_zero.in_port(0).connect(scale_eq_zero.out_port(0))
if_scale_is_zero.in_port(1).connect(zero.out_port(0))
if_scale_is_zero.in_port(2).connect(shift.out_port(0))
# DeQuantize(x) == Mul(Sub(x, zero_point), scale)
sub_zp = Sub(graph, {'name': name + '/minus_zp'}).create_node()
sub_zp.in_port(0).connect(dequantizing_cast.out_port(0))
sub_zp.in_port(1).connect(shift.out_port(0))
sub_zp.in_port(1).connect(if_scale_is_zero.out_port(0))
mul_scale = Mul(graph, {'name': name + '/mulpiply_by_scale'}).create_node()
mul_scale.in_port(0).connect(sub_zp.out_port(0))
@@ -218,3 +222,64 @@ class CompressQuantizeWeights(BackReplacementPattern):
self.quantize_data(fake_quantize, dst_type, quantized_type, mode)
self.dequantize_data(fake_quantize, dst_type, quantized_type)
class ZeroPointOptimizer(BackReplacementPattern):
    r"""
    Tries to get rid of the zero-point Subtract in a Const -> Convert -> Subtract(Const) chain.

    Case 1: the zero point is numerically zero, so the Subtract is bypassed as is.
    Case 2: the weights are int8; the zero point rounded to int8 is subtracted from the weights directly.
        This is accepted only when the int8 result matches the original computation done in the Convert
        destination type and the rounding residual is close to zero. The weights value is then rewritten
        and the Subtract is bypassed.
    In every other situation the matched sub-graph is left untouched.
    """
    enabled = True
    force_clean_up = True

    def run_after(self):
        return [CompressQuantizeWeights]

    def pattern(self):
        # Subtract whose input 0 is Const -> Convert and whose input 1 is a constant zero point
        return {
            'nodes': [
                ('const', {'type': 'Const'}),
                ('const_d', {}),
                ('convert', {'type': 'Convert'}),
                ('convert_d', {}),
                ('const_zp', {'type': 'Const'}),
                ('const_zp_d', {}),
                ('sub', {'type': 'Subtract'}),
            ],
            'edges': [
                ('const', 'const_d'),
                ('const_d', 'convert'),
                ('convert', 'convert_d'),
                ('convert_d', 'sub', {'in': 0}),
                ('const_zp', 'const_zp_d'),
                ('const_zp_d', 'sub', {'in': 1}),
            ],
        }

    @staticmethod
    def _bypass_subtract(sub_node, convert_node):
        # Consumers of the Subtract output are re-sourced to the Convert output
        sub_node.out_port(0).get_connection().set_source(convert_node.out_port(0))

    def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
        zp_value = match['const_zp'].out_port(0).data.get_value()
        assert zp_value is not None
        convert_node, sub_node = match['convert'], match['sub']

        # Case 1: subtracting zero changes nothing
        if np.allclose(zp_value, 0):
            self._bypass_subtract(sub_node, convert_node)
            return

        weights_value = match['const'].out_port(0).data.get_value()
        if weights_value is None or weights_value.dtype != np.int8:
            return
        target_type = convert_node.dst_type

        # Split the zero point into an int8 part and the residual left after rounding
        rounded_zp = np.round(zp_value).astype(np.int8)
        residual_zp = (zp_value - rounded_zp).astype(target_type)

        expected = weights_value.astype(target_type) - zp_value
        shifted_weights = (weights_value - rounded_zp).astype(np.int8)
        actual = shifted_weights - residual_zp

        # Case 2 is applied only if the int8 arithmetic reproduces the original values
        # and the residual zero point is negligible
        is_lossless = np.allclose(expected, actual) and np.allclose(residual_zp, 0, atol=1.e-04)
        if not is_lossless:
            return

        match['const_d']['value'] = shifted_weights
        self._bypass_subtract(sub_node, convert_node)

View File

@@ -7,7 +7,7 @@ from argparse import Namespace
import numpy as np
from generator import generator, generate
from extensions.back.compress_quantized_weights import CompressQuantizeWeights
from extensions.back.compress_quantized_weights import CompressQuantizeWeights, ZeroPointOptimizer
from extensions.ops.Cast import Cast
from extensions.ops.elementwise import Sub, Mul
from extensions.ops.fakequantize import FakeQuantize
@@ -37,7 +37,7 @@ def nodes_dict(original, transformed=None, levels=255, data=None, il=[-127], ih=
**regular_op_with_shaped_data(
'FQ', shape, {'type': 'FakeQuantize', 'infer': FakeQuantize.infer, 'stop_value_propagation': True,
'levels': levels, 'op': 'FakeQuantize'}),
'levels': levels, 'op': 'FakeQuantize'}),
**valued_const_with_data('zp', np.array([0])),
**valued_const_with_data('scale', np.array([1])),
@@ -49,7 +49,7 @@ def nodes_dict(original, transformed=None, levels=255, data=None, il=[-127], ih=
'mul', shape, {'type': 'Multiply', 'op': 'Mul', 'infer': lambda node: eltwise_infer(node, Mul.operation)}),
**result()
}
}
class CompressionQuantizeDequantizeSeparateTest(unittest.TestCase):
@@ -248,3 +248,70 @@ class NegativeCompressionTestLevels(unittest.TestCase):
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
self.assertTrue(flag, resp)
@generator
class ZeroPointOptimizerTestClass(unittest.TestCase):
    """Runs ZeroPointOptimizer on a weights(int8) -> cast -> sub(zp) -> output graph and checks the result."""

    def _describe_nodes(self, weight_values, zp_values):
        # Node descriptions shared by the tested graph and the reference graph
        size = len(weight_values)
        return {
            **valued_const_with_data('weights', np.array(weight_values, dtype=np.int8)),
            **regular_op_with_shaped_data(
                'cast', size, {'type': 'Convert', 'op': 'Cast', 'infer': Cast.infer, 'dst_type': np.float32}),
            **valued_const_with_data('zp', np.array(zp_values, dtype=np.float32)),
            **regular_op_with_shaped_data(
                'sub', size,
                {'type': 'Subtract', 'op': 'Sub', 'infer': lambda node: eltwise_infer(node, Sub.operation)}),
            **result()
        }

    def _edges_with_subtract(self):
        # weights -> cast -> sub -> output, with zp feeding input 1 of sub
        return [
            *connect("weights:0", "0:cast"),
            *connect("cast:0", "0:sub"),
            *connect("zp:0", "1:sub"),
            *connect("sub:0", "0:output"),
        ]

    def _build_and_optimize(self, weight_values, zp_values, edges):
        # Build the graph under test and apply the transformation to it
        graph = build_graph(self._describe_nodes(weight_values, zp_values), edges, nodes_with_edges_only=True)
        ZeroPointOptimizer().find_and_replace_pattern(graph)
        graph.clean_up()
        return graph

    def _check_against_reference(self, graph, ref_weights, ref_zero_point, ref_edges):
        # Build the reference graph and compare it with the transformed one
        graph_ref = build_graph(self._describe_nodes(ref_weights, ref_zero_point), ref_edges,
                                nodes_with_edges_only=True)
        graph_ref.clean_up()

        (flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
        self.assertTrue(flag, resp)

    @generate(*[
        ([-10, 7], [-1], [-9, 8], [0]),
        ([-10, 7], [-0.99999999], [-9, 8], [0]),
    ])
    def test_zero_point_optimization(self, weights, zero_point, adj_weights, adj_zero_point):
        graph = self._build_and_optimize(weights, zero_point, self._edges_with_subtract())

        # The reference graph has no Subtract: the cast output goes straight to the result
        edges_without_subtract = [
            *connect("weights:0", "0:cast"),
            *connect("cast:0", "0:output"),
        ]
        self._check_against_reference(graph, adj_weights, adj_zero_point, edges_without_subtract)

    @generate(*[
        ([-128, 7], [1], [-128, 7], [1]),
        ([127, 7], [-1], [127, 7], [-1]),
    ])
    def test_negative_zero_point_optimization(self, weights, zero_point, adj_weights, adj_zero_point):
        # The reference graph keeps the same topology, Subtract included
        edges = self._edges_with_subtract()
        graph = self._build_and_optimize(weights, zero_point, edges)
        self._check_against_reference(graph, adj_weights, adj_zero_point, edges)