[MO] MulFakeQuantizeFuse - don't fuse if mul constant has zero or negative values (#7347)

This commit is contained in:
Mateusz Tabaka
2021-09-08 12:26:49 +02:00
committed by GitHub
parent 722891756f
commit 990b7e67da
2 changed files with 61 additions and 59 deletions

View File

@@ -72,6 +72,8 @@ class MulFakeQuantizeFuse(MiddleReplacementPattern):
return
mul_val = value_port.data.get_value()
if np.any(mul_val <= 0):
return
# Direct modifications to quantize 1-st and 2-nd port inputs are performed.
# So the data nodes at those inputs shouldn't have more than 1 consumer maximum 2 consumers to the same
@@ -80,33 +82,6 @@ class MulFakeQuantizeFuse(MiddleReplacementPattern):
# TODO: need some special processing for values that exactly equal to threshold
# Need to flip output_low and output_high for those elements that have multiplier < 0
if np.all(mul_val < 0):
mi_o_node = quantize.in_port(3).get_source()
ma_o_node = quantize.in_port(4).get_source()
quantize.in_port(3).disconnect()
quantize.in_port(4).disconnect()
mi_o_node.connect(quantize.in_port(4))
ma_o_node.connect(quantize.in_port(3))
elif np.any(mul_val < 0):
# Flipping values should be done on exclusive inputs of FakeQuantize node, so we duplicate them if needed
resolve_shared_inputs(node=quantize, port_ids_to_duplicate=[3, 4])
# Successful flipping will be done on broadcasted arrays
mi_o_val = quantize.in_port(3).data.get_value()
ma_o_val = quantize.in_port(4).data.get_value()
mul_val, mi_o_val, ma_o_val = [np.array(a) for a in np.broadcast_arrays(mul_val, mi_o_val, ma_o_val)]
neg_idx = np.where(mul_val < 0)
mi_o_val[neg_idx], ma_o_val[neg_idx] = ma_o_val[neg_idx], mi_o_val[neg_idx]
# TODO: revert broadcasting where unnecessary
quantize.in_port(3).data.set_value(mi_o_val)
quantize.in_port(4).data.set_value(ma_o_val)
quantize.in_port(1).data.set_value(quantize.in_port(1).data.get_value() / mul_val)
if quantize.in_node(1).id != quantize.in_node(2).id:
quantize.in_port(2).data.set_value(quantize.in_port(2).data.get_value() / mul_val)

View File

@@ -110,7 +110,7 @@ class MulQuantizeFuseTest(unittest.TestCase):
def test_2(self):
graph = build_graph(nodes, edges, {
'mul': {'can_be_fused': True},
'mul_const_data': {'shape': np.array([1]), 'value': np.array([-1])},
'mul_const_data': {'shape': np.array([1]), 'value': np.array([2])},
'quantize_data': {'shape': np.array([2, 3, 4, 4])},
'mi_o_data': {'shape': np.array([1]), 'value': np.array([0])},
'ma_o_data': {'shape': np.array([1]), 'value': np.array([1])},
@@ -118,11 +118,11 @@ class MulQuantizeFuseTest(unittest.TestCase):
graph.stage = 'middle'
graph_ref = build_graph(nodes, edges_ref, {
'quantize_data': {'shape': np.array([2, 3, 4, 4])},
'mul_const_data': {'shape': np.array([1]), 'value': np.array([-1])},
'mi_o_data': {'shape': np.array([1]), 'value': np.array([1])},
'ma_o_data': {'shape': np.array([1]), 'value': np.array([0])},
'mi_i_data': {'shape': np.array([1]), 'value': np.array([10])},
'ma_i_data': {'shape': np.array([1]), 'value': np.array([-10])},
'mul_const_data': {'shape': np.array([1]), 'value': np.array([2])},
'mi_o_data': {'shape': np.array([1]), 'value': np.array([0])},
'ma_o_data': {'shape': np.array([1]), 'value': np.array([1])},
'mi_i_data': {'shape': np.array([1]), 'value': np.array([-5])},
'ma_i_data': {'shape': np.array([1]), 'value': np.array([5])},
}, nodes_with_edges_only=True)
MulFakeQuantizeFuse().find_and_replace_pattern(graph)
@@ -131,31 +131,7 @@ class MulQuantizeFuseTest(unittest.TestCase):
self.assertTrue(flag, resp)
def test_3(self):
graph = build_graph(nodes, edges, {
'mul': {'can_be_fused': True},
'mul_const_data': {'shape': np.array([3, 1, 1]), 'value': np.array([[[-1]], [[1]], [[-1]]])},
'quantize_data': {'shape': np.array([2, 3, 4, 4])},
'mi_o_data': {'shape': np.array([1, 1, 1, 1]), 'value': np.broadcast_to(np.array([0]), (1, 1, 1, 1))},
'ma_o_data': {'shape': np.array([1, 1, 1, 1]), 'value': np.broadcast_to(np.array([1]), (1, 1, 1, 1))},
}, nodes_with_edges_only=True)
graph.stage = 'middle'
graph_ref = build_graph(nodes, edges_ref, {
'quantize_data': {'shape': np.array([2, 3, 4, 4])},
'mul_const_data': {'shape': np.array([3, 1, 1]), 'value': np.array([[[-1]], [[1]], [[-1]]])},
'mi_o_data': {'shape': np.array([1, 3, 1, 1]), 'value': np.array([[[1]], [[0]], [[1]]])},
'ma_o_data': {'shape': np.array([1, 3, 1, 1]), 'value': np.array([[[0]], [[1]], [[0]]])},
'mi_i_data': {'shape': np.array([1, 3, 1, 1]), 'value': np.array([[[10]], [[-10]], [[10]]])},
'ma_i_data': {'shape': np.array([1, 3, 1, 1]), 'value': np.array([[[-10]], [[10]], [[-10]]])},
}, nodes_with_edges_only=True)
MulFakeQuantizeFuse().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
self.assertTrue(flag, resp)
def negative_test_1(self):
def test_negative_1(self):
graph = build_graph(nodes, edges, nodes_with_edges_only=True)
graph.stage = 'middle'
graph_ref = build_graph(nodes, edges, nodes_with_edges_only=True)
@@ -165,7 +141,7 @@ class MulQuantizeFuseTest(unittest.TestCase):
self.assertTrue(flag, resp)
def negative_test_2(self):
def test_negative_2(self):
graph = build_graph(nodes, edges, {'mul': {'can_be_fused': False}}, nodes_with_edges_only=True)
graph.stage = 'middle'
graph_ref = build_graph(nodes, edges, {'mul': {'can_be_fused': False}}, nodes_with_edges_only=True)
@@ -174,3 +150,54 @@ class MulQuantizeFuseTest(unittest.TestCase):
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
self.assertTrue(flag, resp)
def test_negative_3(self):
graph = build_graph(nodes, edges, {
'mul': {'can_be_fused': True},
'mul_const_data': {'shape': np.array([1]), 'value': np.array([-1])},
'quantize_data': {'shape': np.array([2, 3, 4, 4])},
'mi_o_data': {'shape': np.array([1]), 'value': np.array([0])},
'ma_o_data': {'shape': np.array([1]), 'value': np.array([1])},
}, nodes_with_edges_only=True)
graph.stage = 'middle'
graph_ref = graph.copy()
MulFakeQuantizeFuse().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
self.assertTrue(flag, resp)
def test_negative_4(self):
graph = build_graph(nodes, edges, {
'mul': {'can_be_fused': True},
'mul_const_data': {'shape': np.array([3, 1, 1]), 'value': np.array([[[-1]], [[1]], [[-1]]])},
'quantize_data': {'shape': np.array([2, 3, 4, 4])},
'mi_o_data': {'shape': np.array([1, 1, 1, 1]), 'value': np.broadcast_to(np.array([0]), (1, 1, 1, 1))},
'ma_o_data': {'shape': np.array([1, 1, 1, 1]), 'value': np.broadcast_to(np.array([1]), (1, 1, 1, 1))},
}, nodes_with_edges_only=True)
graph.stage = 'middle'
graph_ref = graph.copy()
MulFakeQuantizeFuse().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
self.assertTrue(flag, resp)
def test_negative_5(self):
graph = build_graph(nodes, edges, {
'mul': {'can_be_fused': True},
'mul_const_data': {'shape': np.array([3, 1, 1]), 'value': np.array([[[0]], [[1]], [[2]]])},
'quantize_data': {'shape': np.array([2, 3, 4, 4])},
'mi_o_data': {'shape': np.array([1, 1, 1, 1]), 'value': np.broadcast_to(np.array([0]), (1, 1, 1, 1))},
'ma_o_data': {'shape': np.array([1, 1, 1, 1]), 'value': np.broadcast_to(np.array([1]), (1, 1, 1, 1))},
}, nodes_with_edges_only=True)
graph.stage = 'middle'
graph_ref = graph.copy()
MulFakeQuantizeFuse().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
self.assertTrue(flag, resp)