From 62fba3eadf44b1ca2d9fa7cf3699a0b354360d37 Mon Sep 17 00:00:00 2001
From: Evgenya Stepyreva
Date: Mon, 29 Jun 2020 14:56:11 +0300
Subject: [PATCH] [ MO ] Keep data type of compressed value (#1143)

JIRA: 34085
---
 .../back/compress_quantized_weights.py      |  8 +--
 .../back/compress_quantized_weights_test.py | 68 +++++++++++++++++--
 2 files changed, 67 insertions(+), 9 deletions(-)

diff --git a/model-optimizer/extensions/back/compress_quantized_weights.py b/model-optimizer/extensions/back/compress_quantized_weights.py
index 2bd2e40433c..df2500cec97 100644
--- a/model-optimizer/extensions/back/compress_quantized_weights.py
+++ b/model-optimizer/extensions/back/compress_quantized_weights.py
@@ -98,14 +98,10 @@ class CompressQuantizeWeights(BackReplacementPattern):
                 ('weights_const', dict(type='Const')),
                 ('weights_d', dict(kind='data')),
                 ('quantize', dict(type='FakeQuantize', levels=lambda x: x is not None and 2 < x <= 256)),
-                ('quantize_d', dict(kind='data')),
-                ('convolution', dict())
             ],
             edges=[
                 ('weights_const', 'weights_d'),
                 ('weights_d', 'quantize', {'in': 0}),
-                ('quantize', 'quantize_d'),
-                ('quantize_d', 'convolution', {'in': 1})
             ]
         )
 
@@ -118,7 +114,9 @@ class CompressQuantizeWeights(BackReplacementPattern):
         initial_fake_quantize.in_port(1).get_connection().set_destination(new_fake_quantize.in_port(1))
         initial_fake_quantize.in_port(2).get_connection().set_destination(new_fake_quantize.in_port(2))
 
-        dst_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)
+        dst_type = match['weights_const'].value.dtype
+        if np.issubdtype(dst_type, np.floating):
+            dst_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)
 
         i_min = np.array([0.], dtype=dst_type)
         i_max = np.array([initial_fake_quantize.levels - 1.], dtype=dst_type)
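Editorial note on the change above: the pattern no longer requires the FakeQuantize to feed a Convolution, and the compressed-value data type now follows the weights themselves — integer weights keep their original dtype, while floating-point weights are still converted to the precision requested via --data_type. Below is a minimal standalone sketch of that selection rule, assuming only numpy; pick_dst_type and the FP32/FP16 lookup table are illustrative stand-ins for the pass internals and the data_type_str_to_np helper, not the real Model Optimizer API.

    import numpy as np

    # Illustrative stand-in for data_type_str_to_np (not the real helper).
    STR_TO_NP = {'FP32': np.float32, 'FP16': np.float16}

    def pick_dst_type(weights: np.ndarray, cmd_data_type: str) -> np.dtype:
        dst_type = weights.dtype
        if np.issubdtype(dst_type, np.floating):
            # Only floating-point weights follow the --data_type precision.
            dst_type = STR_TO_NP[cmd_data_type]
        return np.dtype(dst_type)

    assert pick_dst_type(np.ones(4, dtype=np.float64), 'FP16') == np.float16
    assert pick_dst_type(np.ones(4, dtype=np.int64), 'FP16') == np.int64  # integer dtype is kept
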
""" - import unittest from argparse import Namespace import numpy as np +from generator import generator, generate from extensions.back.compress_quantized_weights import CompressQuantizeWeights from extensions.ops.fakequantize import FakeQuantize @@ -25,7 +25,8 @@ from mo.front.common.partial_infer.eltwise import eltwise_infer from mo.graph.graph import Node from mo.ops.const import Const from mo.utils.ir_engine.compare_graphs import compare_graphs -from mo.utils.unittest.graph import build_graph +from mo.utils.unittest.graph import build_graph, regular_op_with_shaped_data, regular_op_with_empty_data, \ + valued_const_with_data, result, connect nodes_attributes = { # placeholder @@ -33,7 +34,7 @@ nodes_attributes = { 'placeholder_data': {'kind': 'data'}, # weights - 'weights_const': {'type': 'Const', 'kind': 'op'}, + 'weights_const': {'type': 'Const', 'kind': 'op', 'value': np.array([], dtype=np.float32)}, 'weights_data': {'kind': 'data'}, # quantize @@ -446,7 +447,6 @@ class WeightQuantizeTest(unittest.TestCase): self.assertTrue(flag, resp) def test_accuracy_tensor1(self): - """ [1.0, 2.0, 3.0, 4.0] """ @@ -814,3 +814,63 @@ class WeightQuantizeTest(unittest.TestCase): w_array_ref = Node(graph_ref, 'ac_weights').out_port(0).get_destination().data.get_value() self.assertTrue(np.all(w_array == w_array_ref)) + + +@generator +class CompressionDataTypeTest(unittest.TestCase): + @generate(*[ + ('FP32', np.int64), + ('FP16', np.int64), + ('FP32', np.int32), + ('FP16', np.int32), + ('FP32', np.float64, np.float32), + ('FP16', np.float64, np.float16), + ('FP32', np.float32, np.float32), + ('FP16', np.float32, np.float16), + ('FP32', np.float16, np.float32), + ('FP16', np.float16, np.float16), + ]) + def test_data_type(self, model_dtype, original, transformed=None): + if transformed is None: + transformed = original + + nodes = { + **valued_const_with_data('weights', np.ones([1, 2, 3, 4], dtype=original)), + + **valued_const_with_data('int_weights', np.ones([1, 2, 3, 4], dtype=np.uint8)), + **regular_op_with_shaped_data('cast', [1, 2, 3, 4], {'type': 'Convert', 'dst_type': transformed}), + + **valued_const_with_data('il', np.array([0])), + **valued_const_with_data('ih', np.array([254])), + **valued_const_with_data('ol', np.array([0])), + **valued_const_with_data('oh', np.array([254])), + + **regular_op_with_shaped_data('FQ', [1, 2, 3, 4], {'type': 'FakeQuantize', 'infer': FakeQuantize.infer, + 'stop_value_propagation': True, 'levels': 255}), + **result() + } + + graph = build_graph(nodes, [ + *connect('weights:0', '0:FQ'), + *connect('il:0', '1:FQ'), + *connect('ih:0', '2:FQ'), + *connect('ol:0', '3:FQ'), + *connect('oh:0', '4:FQ'), + *connect('FQ:0', 'output'), + ], nodes_with_edges_only=True) + graph.graph['cmd_params'] = Namespace(data_type=model_dtype, keep_shape_ops=True) + + CompressQuantizeWeights().find_and_replace_pattern(graph) + graph.clean_up() + + graph_ref = build_graph(nodes, [ + *connect('int_weights:0', '0:cast'), + *connect('cast:0', '0:FQ'), + *connect('il:0', '1:FQ'), + *connect('ih:0', '2:FQ'), + *connect('ol:0', '3:FQ'), + *connect('oh:0', '4:FQ'), + *connect('FQ:0', 'output'), + ], nodes_with_edges_only=True) + (flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True) + self.assertTrue(flag, resp)