[ MO ] Keep data type of compressed value (#1143)

JIRA: 34085
This commit is contained in:
Evgenya Stepyreva 2020-06-29 14:56:11 +03:00 committed by GitHub
parent f8b2627c3b
commit 62fba3eadf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 67 additions and 9 deletions

View File

@@ -98,14 +98,10 @@ class CompressQuantizeWeights(BackReplacementPattern):
('weights_const', dict(type='Const')), ('weights_const', dict(type='Const')),
('weights_d', dict(kind='data')), ('weights_d', dict(kind='data')),
('quantize', dict(type='FakeQuantize', levels=lambda x: x is not None and 2 < x <= 256)), ('quantize', dict(type='FakeQuantize', levels=lambda x: x is not None and 2 < x <= 256)),
('quantize_d', dict(kind='data')),
('convolution', dict())
], ],
edges=[ edges=[
('weights_const', 'weights_d'), ('weights_const', 'weights_d'),
('weights_d', 'quantize', {'in': 0}), ('weights_d', 'quantize', {'in': 0}),
('quantize', 'quantize_d'),
('quantize_d', 'convolution', {'in': 1})
] ]
) )
@@ -118,7 +114,9 @@ class CompressQuantizeWeights(BackReplacementPattern):
initial_fake_quantize.in_port(1).get_connection().set_destination(new_fake_quantize.in_port(1)) initial_fake_quantize.in_port(1).get_connection().set_destination(new_fake_quantize.in_port(1))
initial_fake_quantize.in_port(2).get_connection().set_destination(new_fake_quantize.in_port(2)) initial_fake_quantize.in_port(2).get_connection().set_destination(new_fake_quantize.in_port(2))
dst_type = data_type_str_to_np(graph.graph['cmd_params'].data_type) dst_type = match['weights_const'].value.dtype
if np.issubdtype(dst_type, np.floating):
dst_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)
i_min = np.array([0.], dtype=dst_type) i_min = np.array([0.], dtype=dst_type)
i_max = np.array([initial_fake_quantize.levels - 1.], dtype=dst_type) i_max = np.array([initial_fake_quantize.levels - 1.], dtype=dst_type)

View File

@@ -13,11 +13,11 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
import unittest import unittest
from argparse import Namespace from argparse import Namespace
import numpy as np import numpy as np
from generator import generator, generate
from extensions.back.compress_quantized_weights import CompressQuantizeWeights from extensions.back.compress_quantized_weights import CompressQuantizeWeights
from extensions.ops.fakequantize import FakeQuantize from extensions.ops.fakequantize import FakeQuantize
@@ -25,7 +25,8 @@ from mo.front.common.partial_infer.eltwise import eltwise_infer
from mo.graph.graph import Node from mo.graph.graph import Node
from mo.ops.const import Const from mo.ops.const import Const
from mo.utils.ir_engine.compare_graphs import compare_graphs from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph from mo.utils.unittest.graph import build_graph, regular_op_with_shaped_data, regular_op_with_empty_data, \
valued_const_with_data, result, connect
nodes_attributes = { nodes_attributes = {
# placeholder # placeholder
@@ -33,7 +34,7 @@ nodes_attributes = {
'placeholder_data': {'kind': 'data'}, 'placeholder_data': {'kind': 'data'},
# weights # weights
'weights_const': {'type': 'Const', 'kind': 'op'}, 'weights_const': {'type': 'Const', 'kind': 'op', 'value': np.array([], dtype=np.float32)},
'weights_data': {'kind': 'data'}, 'weights_data': {'kind': 'data'},
# quantize # quantize
@@ -446,7 +447,6 @@ class WeightQuantizeTest(unittest.TestCase):
self.assertTrue(flag, resp) self.assertTrue(flag, resp)
def test_accuracy_tensor1(self): def test_accuracy_tensor1(self):
""" """
[1.0, 2.0, 3.0, 4.0] [1.0, 2.0, 3.0, 4.0]
""" """
@@ -814,3 +814,63 @@ class WeightQuantizeTest(unittest.TestCase):
w_array_ref = Node(graph_ref, 'ac_weights').out_port(0).get_destination().data.get_value() w_array_ref = Node(graph_ref, 'ac_weights').out_port(0).get_destination().data.get_value()
self.assertTrue(np.all(w_array == w_array_ref)) self.assertTrue(np.all(w_array == w_array_ref))
@generator
class CompressionDataTypeTest(unittest.TestCase):
    """Verify that CompressQuantizeWeights preserves the weights' data type.

    Integer weight constants must keep their integer dtype untouched, while
    floating-point weights are converted to the dtype implied by the
    --data_type command-line setting ('FP32' -> float32, 'FP16' -> float16).
    """

    @generate(*[
        # (model data_type, original weights dtype[, dtype expected after the transform])
        ('FP32', np.int64),
        ('FP16', np.int64),
        ('FP32', np.int32),
        ('FP16', np.int32),
        ('FP32', np.float64, np.float32),
        ('FP16', np.float64, np.float16),
        ('FP32', np.float32, np.float32),
        ('FP16', np.float32, np.float16),
        ('FP32', np.float16, np.float32),
        ('FP16', np.float16, np.float16),
    ])
    def test_data_type(self, model_dtype, original, transformed=None):
        # When no explicit expectation is given, the dtype must survive unchanged.
        expected_dtype = original if transformed is None else transformed

        node_attrs = {
            **valued_const_with_data('weights', np.ones([1, 2, 3, 4], dtype=original)),
            **valued_const_with_data('int_weights', np.ones([1, 2, 3, 4], dtype=np.uint8)),
            **regular_op_with_shaped_data('cast', [1, 2, 3, 4],
                                          {'type': 'Convert', 'dst_type': expected_dtype}),
            **valued_const_with_data('il', np.array([0])),
            **valued_const_with_data('ih', np.array([254])),
            **valued_const_with_data('ol', np.array([0])),
            **valued_const_with_data('oh', np.array([254])),
            **regular_op_with_shaped_data('FQ', [1, 2, 3, 4],
                                          {'type': 'FakeQuantize', 'infer': FakeQuantize.infer,
                                           'stop_value_propagation': True, 'levels': 255}),
            **result(),
        }

        # Edges shared by both the transformed graph and the reference graph.
        fq_tail = [
            *connect('il:0', '1:FQ'),
            *connect('ih:0', '2:FQ'),
            *connect('ol:0', '3:FQ'),
            *connect('oh:0', '4:FQ'),
            *connect('FQ:0', 'output'),
        ]

        graph = build_graph(node_attrs,
                            [*connect('weights:0', '0:FQ'), *fq_tail],
                            nodes_with_edges_only=True)
        graph.graph['cmd_params'] = Namespace(data_type=model_dtype, keep_shape_ops=True)

        CompressQuantizeWeights().find_and_replace_pattern(graph)
        graph.clean_up()

        # Reference: uint8 weights followed by a Convert back to the expected dtype.
        graph_ref = build_graph(node_attrs,
                                [*connect('int_weights:0', '0:cast'),
                                 *connect('cast:0', '0:FQ'),
                                 *fq_tail],
                                nodes_with_edges_only=True)

        (flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
        self.assertTrue(flag, resp)