[ MO ] Keep data type of compressed value (#1143)

JIRA: 34085
This commit is contained in:
Evgenya Stepyreva 2020-06-29 14:56:11 +03:00 committed by GitHub
parent f8b2627c3b
commit 62fba3eadf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 67 additions and 9 deletions

View File

@ -98,14 +98,10 @@ class CompressQuantizeWeights(BackReplacementPattern):
('weights_const', dict(type='Const')),
('weights_d', dict(kind='data')),
('quantize', dict(type='FakeQuantize', levels=lambda x: x is not None and 2 < x <= 256)),
('quantize_d', dict(kind='data')),
('convolution', dict())
],
edges=[
('weights_const', 'weights_d'),
('weights_d', 'quantize', {'in': 0}),
('quantize', 'quantize_d'),
('quantize_d', 'convolution', {'in': 1})
]
)
@ -118,7 +114,9 @@ class CompressQuantizeWeights(BackReplacementPattern):
initial_fake_quantize.in_port(1).get_connection().set_destination(new_fake_quantize.in_port(1))
initial_fake_quantize.in_port(2).get_connection().set_destination(new_fake_quantize.in_port(2))
dst_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)
dst_type = match['weights_const'].value.dtype
if np.issubdtype(dst_type, np.floating):
dst_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)
i_min = np.array([0.], dtype=dst_type)
i_max = np.array([initial_fake_quantize.levels - 1.], dtype=dst_type)

View File

@ -13,11 +13,11 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from argparse import Namespace
import numpy as np
from generator import generator, generate
from extensions.back.compress_quantized_weights import CompressQuantizeWeights
from extensions.ops.fakequantize import FakeQuantize
@ -25,7 +25,8 @@ from mo.front.common.partial_infer.eltwise import eltwise_infer
from mo.graph.graph import Node
from mo.ops.const import Const
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph
from mo.utils.unittest.graph import build_graph, regular_op_with_shaped_data, regular_op_with_empty_data, \
valued_const_with_data, result, connect
nodes_attributes = {
# placeholder
@ -33,7 +34,7 @@ nodes_attributes = {
'placeholder_data': {'kind': 'data'},
# weights
'weights_const': {'type': 'Const', 'kind': 'op'},
'weights_const': {'type': 'Const', 'kind': 'op', 'value': np.array([], dtype=np.float32)},
'weights_data': {'kind': 'data'},
# quantize
@ -446,7 +447,6 @@ class WeightQuantizeTest(unittest.TestCase):
self.assertTrue(flag, resp)
def test_accuracy_tensor1(self):
"""
[1.0, 2.0, 3.0, 4.0]
"""
@ -814,3 +814,63 @@ class WeightQuantizeTest(unittest.TestCase):
w_array_ref = Node(graph_ref, 'ac_weights').out_port(0).get_destination().data.get_value()
self.assertTrue(np.all(w_array == w_array_ref))
@generator
class CompressionDataTypeTest(unittest.TestCase):
    """Checks which dtype CompressQuantizeWeights gives the decompression Convert.

    Each case is (model data_type, original weights dtype[, expected Convert dst_type]).
    When the third element is omitted, the original dtype is expected to be kept
    (the integer cases below) — i.e. integer weight constants must NOT be forced
    to the command-line --data_type, while float weights follow it.
    """
    @generate(*[
        # Integer weights: dtype expected to be preserved as-is (no third element,
        # so `transformed` defaults to `original` below).
        ('FP32', np.int64),
        ('FP16', np.int64),
        ('FP32', np.int32),
        ('FP16', np.int32),
        # Float weights: expected to be converted to the model data_type
        # (FP32 -> np.float32, FP16 -> np.float16) regardless of original precision.
        ('FP32', np.float64, np.float32),
        ('FP16', np.float64, np.float16),
        ('FP32', np.float32, np.float32),
        ('FP16', np.float32, np.float16),
        ('FP32', np.float16, np.float32),
        ('FP16', np.float16, np.float16),
    ])
    def test_data_type(self, model_dtype, original, transformed=None):
        # No explicit expectation given -> the transformation should keep the dtype.
        if transformed is None:
            transformed = original
        # One shared node pool for both graphs; each build_graph call below picks
        # only the nodes reachable through its edge list (nodes_with_edges_only=True).
        nodes = {
            # Weights as they appear before the transformation.
            **valued_const_with_data('weights', np.ones([1, 2, 3, 4], dtype=original)),
            # Reference: compressed uint8 weights feeding a decompression Convert.
            # NOTE(review): uint8 presumably matches the 255-level quantization
            # range [0, 254] set up below — confirm against the transformation.
            **valued_const_with_data('int_weights', np.ones([1, 2, 3, 4], dtype=np.uint8)),
            **regular_op_with_shaped_data('cast', [1, 2, 3, 4], {'type': 'Convert', 'dst_type': transformed}),
            # FakeQuantize limits: input/output low = 0, input/output high = 254.
            **valued_const_with_data('il', np.array([0])),
            **valued_const_with_data('ih', np.array([254])),
            **valued_const_with_data('ol', np.array([0])),
            **valued_const_with_data('oh', np.array([254])),
            # stop_value_propagation keeps the FQ from being constant-folded away,
            # so the weight path survives graph.clean_up() for comparison.
            **regular_op_with_shaped_data('FQ', [1, 2, 3, 4], {'type': 'FakeQuantize', 'infer': FakeQuantize.infer,
                                                               'stop_value_propagation': True, 'levels': 255}),
            **result()
        }
        # Graph under test: original-dtype weights -> FakeQuantize -> output.
        graph = build_graph(nodes, [
            *connect('weights:0', '0:FQ'),
            *connect('il:0', '1:FQ'),
            *connect('ih:0', '2:FQ'),
            *connect('ol:0', '3:FQ'),
            *connect('oh:0', '4:FQ'),
            *connect('FQ:0', 'output'),
        ], nodes_with_edges_only=True)
        # The transformation reads the target precision from cmd_params.data_type.
        graph.graph['cmd_params'] = Namespace(data_type=model_dtype, keep_shape_ops=True)
        CompressQuantizeWeights().find_and_replace_pattern(graph)
        graph.clean_up()
        # Reference graph: uint8 weights -> Convert(dst_type=transformed) -> FakeQuantize.
        graph_ref = build_graph(nodes, [
            *connect('int_weights:0', '0:cast'),
            *connect('cast:0', '0:FQ'),
            *connect('il:0', '1:FQ'),
            *connect('ih:0', '2:FQ'),
            *connect('ol:0', '3:FQ'),
            *connect('oh:0', '4:FQ'),
            *connect('FQ:0', 'output'),
        ], nodes_with_edges_only=True)
        # check_op_attrs=True makes the comparison include 'dst_type' on the Convert,
        # which is the attribute this test is really about.
        (flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
        self.assertTrue(flag, resp)