[MO] MulFakeQuantizeFuse - don't fuse if mul constant has zero or negative values (#7347)
This commit is contained in:
@@ -72,6 +72,8 @@ class MulFakeQuantizeFuse(MiddleReplacementPattern):
|
||||
return
|
||||
|
||||
mul_val = value_port.data.get_value()
|
||||
if np.any(mul_val <= 0):
|
||||
return
|
||||
|
||||
# Direct modifications to quantize 1-st and 2-nd port inputs are performed.
|
||||
# So the data nodes at those inputs shouldn't have more than 1 consumer maximum 2 consumers to the same
|
||||
@@ -80,33 +82,6 @@ class MulFakeQuantizeFuse(MiddleReplacementPattern):
|
||||
|
||||
# TODO: need some special processing for values that exactly equal to threshold
|
||||
|
||||
# Need to flip output_low and output_high for those elements that have multiplier < 0
|
||||
if np.all(mul_val < 0):
|
||||
mi_o_node = quantize.in_port(3).get_source()
|
||||
ma_o_node = quantize.in_port(4).get_source()
|
||||
|
||||
quantize.in_port(3).disconnect()
|
||||
quantize.in_port(4).disconnect()
|
||||
|
||||
mi_o_node.connect(quantize.in_port(4))
|
||||
ma_o_node.connect(quantize.in_port(3))
|
||||
|
||||
elif np.any(mul_val < 0):
|
||||
# Flipping values should be done on exclusive inputs of FakeQuantize node, so we duplicate them if needed
|
||||
resolve_shared_inputs(node=quantize, port_ids_to_duplicate=[3, 4])
|
||||
|
||||
# Successful flipping will be done on broadcasted arrays
|
||||
mi_o_val = quantize.in_port(3).data.get_value()
|
||||
ma_o_val = quantize.in_port(4).data.get_value()
|
||||
mul_val, mi_o_val, ma_o_val = [np.array(a) for a in np.broadcast_arrays(mul_val, mi_o_val, ma_o_val)]
|
||||
|
||||
neg_idx = np.where(mul_val < 0)
|
||||
mi_o_val[neg_idx], ma_o_val[neg_idx] = ma_o_val[neg_idx], mi_o_val[neg_idx]
|
||||
|
||||
# TODO: revert broadcasting where unnecessary
|
||||
quantize.in_port(3).data.set_value(mi_o_val)
|
||||
quantize.in_port(4).data.set_value(ma_o_val)
|
||||
|
||||
quantize.in_port(1).data.set_value(quantize.in_port(1).data.get_value() / mul_val)
|
||||
if quantize.in_node(1).id != quantize.in_node(2).id:
|
||||
quantize.in_port(2).data.set_value(quantize.in_port(2).data.get_value() / mul_val)
|
||||
|
||||
@@ -110,7 +110,7 @@ class MulQuantizeFuseTest(unittest.TestCase):
|
||||
def test_2(self):
|
||||
graph = build_graph(nodes, edges, {
|
||||
'mul': {'can_be_fused': True},
|
||||
'mul_const_data': {'shape': np.array([1]), 'value': np.array([-1])},
|
||||
'mul_const_data': {'shape': np.array([1]), 'value': np.array([2])},
|
||||
'quantize_data': {'shape': np.array([2, 3, 4, 4])},
|
||||
'mi_o_data': {'shape': np.array([1]), 'value': np.array([0])},
|
||||
'ma_o_data': {'shape': np.array([1]), 'value': np.array([1])},
|
||||
@@ -118,11 +118,11 @@ class MulQuantizeFuseTest(unittest.TestCase):
|
||||
graph.stage = 'middle'
|
||||
graph_ref = build_graph(nodes, edges_ref, {
|
||||
'quantize_data': {'shape': np.array([2, 3, 4, 4])},
|
||||
'mul_const_data': {'shape': np.array([1]), 'value': np.array([-1])},
|
||||
'mi_o_data': {'shape': np.array([1]), 'value': np.array([1])},
|
||||
'ma_o_data': {'shape': np.array([1]), 'value': np.array([0])},
|
||||
'mi_i_data': {'shape': np.array([1]), 'value': np.array([10])},
|
||||
'ma_i_data': {'shape': np.array([1]), 'value': np.array([-10])},
|
||||
'mul_const_data': {'shape': np.array([1]), 'value': np.array([2])},
|
||||
'mi_o_data': {'shape': np.array([1]), 'value': np.array([0])},
|
||||
'ma_o_data': {'shape': np.array([1]), 'value': np.array([1])},
|
||||
'mi_i_data': {'shape': np.array([1]), 'value': np.array([-5])},
|
||||
'ma_i_data': {'shape': np.array([1]), 'value': np.array([5])},
|
||||
}, nodes_with_edges_only=True)
|
||||
|
||||
MulFakeQuantizeFuse().find_and_replace_pattern(graph)
|
||||
@@ -131,31 +131,7 @@ class MulQuantizeFuseTest(unittest.TestCase):
|
||||
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_3(self):
|
||||
graph = build_graph(nodes, edges, {
|
||||
'mul': {'can_be_fused': True},
|
||||
'mul_const_data': {'shape': np.array([3, 1, 1]), 'value': np.array([[[-1]], [[1]], [[-1]]])},
|
||||
'quantize_data': {'shape': np.array([2, 3, 4, 4])},
|
||||
'mi_o_data': {'shape': np.array([1, 1, 1, 1]), 'value': np.broadcast_to(np.array([0]), (1, 1, 1, 1))},
|
||||
'ma_o_data': {'shape': np.array([1, 1, 1, 1]), 'value': np.broadcast_to(np.array([1]), (1, 1, 1, 1))},
|
||||
}, nodes_with_edges_only=True)
|
||||
graph.stage = 'middle'
|
||||
graph_ref = build_graph(nodes, edges_ref, {
|
||||
'quantize_data': {'shape': np.array([2, 3, 4, 4])},
|
||||
'mul_const_data': {'shape': np.array([3, 1, 1]), 'value': np.array([[[-1]], [[1]], [[-1]]])},
|
||||
'mi_o_data': {'shape': np.array([1, 3, 1, 1]), 'value': np.array([[[1]], [[0]], [[1]]])},
|
||||
'ma_o_data': {'shape': np.array([1, 3, 1, 1]), 'value': np.array([[[0]], [[1]], [[0]]])},
|
||||
'mi_i_data': {'shape': np.array([1, 3, 1, 1]), 'value': np.array([[[10]], [[-10]], [[10]]])},
|
||||
'ma_i_data': {'shape': np.array([1, 3, 1, 1]), 'value': np.array([[[-10]], [[10]], [[-10]]])},
|
||||
}, nodes_with_edges_only=True)
|
||||
|
||||
MulFakeQuantizeFuse().find_and_replace_pattern(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
|
||||
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def negative_test_1(self):
|
||||
def test_negative_1(self):
|
||||
graph = build_graph(nodes, edges, nodes_with_edges_only=True)
|
||||
graph.stage = 'middle'
|
||||
graph_ref = build_graph(nodes, edges, nodes_with_edges_only=True)
|
||||
@@ -165,7 +141,7 @@ class MulQuantizeFuseTest(unittest.TestCase):
|
||||
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def negative_test_2(self):
|
||||
def test_negative_2(self):
|
||||
graph = build_graph(nodes, edges, {'mul': {'can_be_fused': False}}, nodes_with_edges_only=True)
|
||||
graph.stage = 'middle'
|
||||
graph_ref = build_graph(nodes, edges, {'mul': {'can_be_fused': False}}, nodes_with_edges_only=True)
|
||||
@@ -174,3 +150,54 @@ class MulQuantizeFuseTest(unittest.TestCase):
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
|
||||
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_negative_3(self):
|
||||
graph = build_graph(nodes, edges, {
|
||||
'mul': {'can_be_fused': True},
|
||||
'mul_const_data': {'shape': np.array([1]), 'value': np.array([-1])},
|
||||
'quantize_data': {'shape': np.array([2, 3, 4, 4])},
|
||||
'mi_o_data': {'shape': np.array([1]), 'value': np.array([0])},
|
||||
'ma_o_data': {'shape': np.array([1]), 'value': np.array([1])},
|
||||
}, nodes_with_edges_only=True)
|
||||
graph.stage = 'middle'
|
||||
graph_ref = graph.copy()
|
||||
|
||||
MulFakeQuantizeFuse().find_and_replace_pattern(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
|
||||
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_negative_4(self):
|
||||
graph = build_graph(nodes, edges, {
|
||||
'mul': {'can_be_fused': True},
|
||||
'mul_const_data': {'shape': np.array([3, 1, 1]), 'value': np.array([[[-1]], [[1]], [[-1]]])},
|
||||
'quantize_data': {'shape': np.array([2, 3, 4, 4])},
|
||||
'mi_o_data': {'shape': np.array([1, 1, 1, 1]), 'value': np.broadcast_to(np.array([0]), (1, 1, 1, 1))},
|
||||
'ma_o_data': {'shape': np.array([1, 1, 1, 1]), 'value': np.broadcast_to(np.array([1]), (1, 1, 1, 1))},
|
||||
}, nodes_with_edges_only=True)
|
||||
graph.stage = 'middle'
|
||||
graph_ref = graph.copy()
|
||||
|
||||
MulFakeQuantizeFuse().find_and_replace_pattern(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
|
||||
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_negative_5(self):
|
||||
graph = build_graph(nodes, edges, {
|
||||
'mul': {'can_be_fused': True},
|
||||
'mul_const_data': {'shape': np.array([3, 1, 1]), 'value': np.array([[[0]], [[1]], [[2]]])},
|
||||
'quantize_data': {'shape': np.array([2, 3, 4, 4])},
|
||||
'mi_o_data': {'shape': np.array([1, 1, 1, 1]), 'value': np.broadcast_to(np.array([0]), (1, 1, 1, 1))},
|
||||
'ma_o_data': {'shape': np.array([1, 1, 1, 1]), 'value': np.broadcast_to(np.array([1]), (1, 1, 1, 1))},
|
||||
}, nodes_with_edges_only=True)
|
||||
graph.stage = 'middle'
|
||||
graph_ref = graph.copy()
|
||||
|
||||
MulFakeQuantizeFuse().find_and_replace_pattern(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
|
||||
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
Reference in New Issue
Block a user