diff --git a/model-optimizer/extensions/middle/MulFakeQuantizeFuse.py b/model-optimizer/extensions/middle/MulFakeQuantizeFuse.py
index 70045685e95..3bbd670ed16 100644
--- a/model-optimizer/extensions/middle/MulFakeQuantizeFuse.py
+++ b/model-optimizer/extensions/middle/MulFakeQuantizeFuse.py
@@ -72,6 +72,8 @@ class MulFakeQuantizeFuse(MiddleReplacementPattern):
             return
 
         mul_val = value_port.data.get_value()
+        if np.any(mul_val <= 0):
+            return
 
         # Direct modifications to quantize 1-st and 2-nd port inputs are performed.
         # So the data nodes at those inputs shouldn't have more than 1 consumer maximum 2 consumers to the same
@@ -80,33 +82,6 @@ class MulFakeQuantizeFuse(MiddleReplacementPattern):
 
         # TODO: need some special processing for values that exactly equal to threshold
 
-        # Need to flip output_low and output_high for those elements that have multiplier < 0
-        if np.all(mul_val < 0):
-            mi_o_node = quantize.in_port(3).get_source()
-            ma_o_node = quantize.in_port(4).get_source()
-
-            quantize.in_port(3).disconnect()
-            quantize.in_port(4).disconnect()
-
-            mi_o_node.connect(quantize.in_port(4))
-            ma_o_node.connect(quantize.in_port(3))
-
-        elif np.any(mul_val < 0):
-            # Flipping values should be done on exclusive inputs of FakeQuantize node, so we duplicate them if needed
-            resolve_shared_inputs(node=quantize, port_ids_to_duplicate=[3, 4])
-
-            # Successful flipping will be done on broadcasted arrays
-            mi_o_val = quantize.in_port(3).data.get_value()
-            ma_o_val = quantize.in_port(4).data.get_value()
-            mul_val, mi_o_val, ma_o_val = [np.array(a) for a in np.broadcast_arrays(mul_val, mi_o_val, ma_o_val)]
-
-            neg_idx = np.where(mul_val < 0)
-            mi_o_val[neg_idx], ma_o_val[neg_idx] = ma_o_val[neg_idx], mi_o_val[neg_idx]
-
-            # TODO: revert broadcasting where unnecessary
-            quantize.in_port(3).data.set_value(mi_o_val)
-            quantize.in_port(4).data.set_value(ma_o_val)
-
         quantize.in_port(1).data.set_value(quantize.in_port(1).data.get_value() / mul_val)
         if quantize.in_node(1).id != quantize.in_node(2).id:
             quantize.in_port(2).data.set_value(quantize.in_port(2).data.get_value() / mul_val)
diff --git a/model-optimizer/unit_tests/extensions/middle/MulQuantizeFuse_test.py b/model-optimizer/unit_tests/extensions/middle/MulQuantizeFuse_test.py
index 34b3fccf810..0fafb51addf 100644
--- a/model-optimizer/unit_tests/extensions/middle/MulQuantizeFuse_test.py
+++ b/model-optimizer/unit_tests/extensions/middle/MulQuantizeFuse_test.py
@@ -110,7 +110,7 @@ class MulQuantizeFuseTest(unittest.TestCase):
     def test_2(self):
         graph = build_graph(nodes, edges, {
             'mul': {'can_be_fused': True},
-            'mul_const_data': {'shape': np.array([1]), 'value': np.array([-1])},
+            'mul_const_data': {'shape': np.array([1]), 'value': np.array([2])},
             'quantize_data': {'shape': np.array([2, 3, 4, 4])},
             'mi_o_data': {'shape': np.array([1]), 'value': np.array([0])},
             'ma_o_data': {'shape': np.array([1]), 'value': np.array([1])},
@@ -118,11 +118,11 @@ class MulQuantizeFuseTest(unittest.TestCase):
         graph.stage = 'middle'
         graph_ref = build_graph(nodes, edges_ref, {
             'quantize_data': {'shape': np.array([2, 3, 4, 4])},
-            'mul_const_data': {'shape': np.array([1]), 'value': np.array([-1])},
-            'mi_o_data': {'shape': np.array([1]), 'value': np.array([1])},
-            'ma_o_data': {'shape': np.array([1]), 'value': np.array([0])},
-            'mi_i_data': {'shape': np.array([1]), 'value': np.array([10])},
-            'ma_i_data': {'shape': np.array([1]), 'value': np.array([-10])},
+            'mul_const_data': {'shape': np.array([1]), 'value': np.array([2])},
+            'mi_o_data': {'shape': np.array([1]), 'value': np.array([0])},
+            'ma_o_data': {'shape': np.array([1]), 'value': np.array([1])},
+            'mi_i_data': {'shape': np.array([1]), 'value': np.array([-5])},
+            'ma_i_data': {'shape': np.array([1]), 'value': np.array([5])},
         }, nodes_with_edges_only=True)
 
         MulFakeQuantizeFuse().find_and_replace_pattern(graph)
@@ -131,31 +131,7 @@ class MulQuantizeFuseTest(unittest.TestCase):
 
         self.assertTrue(flag, resp)
 
-    def test_3(self):
-        graph = build_graph(nodes, edges, {
-            'mul': {'can_be_fused': True},
-            'mul_const_data': {'shape': np.array([3, 1, 1]), 'value': np.array([[[-1]], [[1]], [[-1]]])},
-            'quantize_data': {'shape': np.array([2, 3, 4, 4])},
-            'mi_o_data': {'shape': np.array([1, 1, 1, 1]), 'value': np.broadcast_to(np.array([0]), (1, 1, 1, 1))},
-            'ma_o_data': {'shape': np.array([1, 1, 1, 1]), 'value': np.broadcast_to(np.array([1]), (1, 1, 1, 1))},
-        }, nodes_with_edges_only=True)
-        graph.stage = 'middle'
-        graph_ref = build_graph(nodes, edges_ref, {
-            'quantize_data': {'shape': np.array([2, 3, 4, 4])},
-            'mul_const_data': {'shape': np.array([3, 1, 1]), 'value': np.array([[[-1]], [[1]], [[-1]]])},
-            'mi_o_data': {'shape': np.array([1, 3, 1, 1]), 'value': np.array([[[1]], [[0]], [[1]]])},
-            'ma_o_data': {'shape': np.array([1, 3, 1, 1]), 'value': np.array([[[0]], [[1]], [[0]]])},
-            'mi_i_data': {'shape': np.array([1, 3, 1, 1]), 'value': np.array([[[10]], [[-10]], [[10]]])},
-            'ma_i_data': {'shape': np.array([1, 3, 1, 1]), 'value': np.array([[[-10]], [[10]], [[-10]]])},
-        }, nodes_with_edges_only=True)
-
-        MulFakeQuantizeFuse().find_and_replace_pattern(graph)
-
-        (flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
-
-        self.assertTrue(flag, resp)
-
-    def negative_test_1(self):
+    def test_negative_1(self):
         graph = build_graph(nodes, edges, nodes_with_edges_only=True)
         graph.stage = 'middle'
         graph_ref = build_graph(nodes, edges, nodes_with_edges_only=True)
@@ -165,7 +141,7 @@ class MulQuantizeFuseTest(unittest.TestCase):
 
         self.assertTrue(flag, resp)
 
-    def negative_test_2(self):
+    def test_negative_2(self):
         graph = build_graph(nodes, edges, {'mul': {'can_be_fused': False}}, nodes_with_edges_only=True)
         graph.stage = 'middle'
         graph_ref = build_graph(nodes, edges, {'mul': {'can_be_fused': False}}, nodes_with_edges_only=True)
@@ -174,3 +150,54 @@ class MulQuantizeFuseTest(unittest.TestCase):
         (flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
 
         self.assertTrue(flag, resp)
+
+    def test_negative_3(self):
+        graph = build_graph(nodes, edges, {
+            'mul': {'can_be_fused': True},
+            'mul_const_data': {'shape': np.array([1]), 'value': np.array([-1])},
+            'quantize_data': {'shape': np.array([2, 3, 4, 4])},
+            'mi_o_data': {'shape': np.array([1]), 'value': np.array([0])},
+            'ma_o_data': {'shape': np.array([1]), 'value': np.array([1])},
+        }, nodes_with_edges_only=True)
+        graph.stage = 'middle'
+        graph_ref = graph.copy()
+
+        MulFakeQuantizeFuse().find_and_replace_pattern(graph)
+
+        (flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
+
+        self.assertTrue(flag, resp)
+
+    def test_negative_4(self):
+        graph = build_graph(nodes, edges, {
+            'mul': {'can_be_fused': True},
+            'mul_const_data': {'shape': np.array([3, 1, 1]), 'value': np.array([[[-1]], [[1]], [[-1]]])},
+            'quantize_data': {'shape': np.array([2, 3, 4, 4])},
+            'mi_o_data': {'shape': np.array([1, 1, 1, 1]), 'value': np.broadcast_to(np.array([0]), (1, 1, 1, 1))},
+            'ma_o_data': {'shape': np.array([1, 1, 1, 1]), 'value': np.broadcast_to(np.array([1]), (1, 1, 1, 1))},
+        }, nodes_with_edges_only=True)
+        graph.stage = 'middle'
+        graph_ref = graph.copy()
+
+        MulFakeQuantizeFuse().find_and_replace_pattern(graph)
+
+        (flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
+
+        self.assertTrue(flag, resp)
+
+    def test_negative_5(self):
+        graph = build_graph(nodes, edges, {
+            'mul': {'can_be_fused': True},
+            'mul_const_data': {'shape': np.array([3, 1, 1]), 'value': np.array([[[0]], [[1]], [[2]]])},
+            'quantize_data': {'shape': np.array([2, 3, 4, 4])},
+            'mi_o_data': {'shape': np.array([1, 1, 1, 1]), 'value': np.broadcast_to(np.array([0]), (1, 1, 1, 1))},
+            'ma_o_data': {'shape': np.array([1, 1, 1, 1]), 'value': np.broadcast_to(np.array([1]), (1, 1, 1, 1))},
+        }, nodes_with_edges_only=True)
+        graph.stage = 'middle'
+        graph_ref = graph.copy()
+
+        MulFakeQuantizeFuse().find_and_replace_pattern(graph)
+
+        (flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
+
+        self.assertTrue(flag, resp)