parent f8b2627c3b
commit 62fba3eadf
@@ -98,14 +98,10 @@ class CompressQuantizeWeights(BackReplacementPattern):
                ('weights_const', dict(type='Const')),
                ('weights_d', dict(kind='data')),
                ('quantize', dict(type='FakeQuantize', levels=lambda x: x is not None and 2 < x <= 256)),
                ('quantize_d', dict(kind='data')),
                ('convolution', dict())
            ],
            edges=[
                ('weights_const', 'weights_d'),
                ('weights_d', 'quantize', {'in': 0}),
                ('quantize', 'quantize_d'),
                ('quantize_d', 'convolution', {'in': 1})
            ]
        )
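For context (not part of the patch): the pattern above matches a constant weights path Const -> data -> FakeQuantize -> data -> Convolution with 2 < levels <= 256. Below is a minimal NumPy sketch of the FakeQuantize semantics applied to such weights, with illustrative names, assuming the usual clamp/round formula. The intent of the transformation, as the reference graphs in the tests below encode it, is to keep only the rounded grid values (q here) as a small-integer constant and leave the dequantization arithmetic in the graph.

import numpy as np

def fake_quantize(x, in_low, in_high, out_low, out_high, levels):
    # clamp to the input range, snap to one of `levels` grid points, map to the output range
    x = np.clip(x, in_low, in_high)
    q = np.round((x - in_low) / (in_high - in_low) * (levels - 1))
    return q / (levels - 1) * (out_high - out_low) + out_low

weights = np.array([0.1, 0.4, 0.9], dtype=np.float32)
print(fake_quantize(weights, 0.0, 1.0, -1.0, 1.0, levels=256))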
@@ -118,7 +114,9 @@ class CompressQuantizeWeights(BackReplacementPattern):
         initial_fake_quantize.in_port(1).get_connection().set_destination(new_fake_quantize.in_port(1))
         initial_fake_quantize.in_port(2).get_connection().set_destination(new_fake_quantize.in_port(2))
 
-        dst_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)
+        dst_type = match['weights_const'].value.dtype
+        if np.issubdtype(dst_type, np.floating):
+            dst_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)
 
         i_min = np.array([0.], dtype=dst_type)
         i_max = np.array([initial_fake_quantize.levels - 1.], dtype=dst_type)
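The behavioral change in this hunk: i_min and i_max are now created in the dtype of the original weights constant when that dtype is an integer type, and only floating weights fall back to the precision requested on the command line. A rough standalone sketch of the selection logic follows; select_dst_type and fp_map are illustrative names, assuming FP32/FP16 map to np.float32/np.float16 as data_type_str_to_np does.

import numpy as np

def select_dst_type(weights_value: np.ndarray, cmd_data_type: str) -> np.dtype:
    fp_map = {'FP32': np.float32, 'FP16': np.float16}
    dst_type = weights_value.dtype
    if np.issubdtype(dst_type, np.floating):
        # floating weights follow the requested model precision
        dst_type = np.dtype(fp_map[cmd_data_type])
    return dst_type

print(select_dst_type(np.ones(3, dtype=np.int64), 'FP16'))    # int64 is preserved
print(select_dst_type(np.ones(3, dtype=np.float64), 'FP16'))  # float64 -> float16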
@@ -13,11 +13,11 @@
 See the License for the specific language governing permissions and
 limitations under the License.
"""

import unittest
from argparse import Namespace

import numpy as np
from generator import generator, generate

from extensions.back.compress_quantized_weights import CompressQuantizeWeights
from extensions.ops.fakequantize import FakeQuantize
@@ -25,7 +25,8 @@ from mo.front.common.partial_infer.eltwise import eltwise_infer
 from mo.graph.graph import Node
 from mo.ops.const import Const
 from mo.utils.ir_engine.compare_graphs import compare_graphs
-from mo.utils.unittest.graph import build_graph
+from mo.utils.unittest.graph import build_graph, regular_op_with_shaped_data, regular_op_with_empty_data, \
+    valued_const_with_data, result, connect
 
 nodes_attributes = {
     # placeholder
@@ -33,7 +34,7 @@ nodes_attributes = {
     'placeholder_data': {'kind': 'data'},
 
     # weights
-    'weights_const': {'type': 'Const', 'kind': 'op'},
+    'weights_const': {'type': 'Const', 'kind': 'op', 'value': np.array([], dtype=np.float32)},
     'weights_data': {'kind': 'data'},
 
     # quantize
@@ -446,7 +447,6 @@ class WeightQuantizeTest(unittest.TestCase):
        self.assertTrue(flag, resp)

    def test_accuracy_tensor1(self):

        """
        [1.0, 2.0, 3.0, 4.0]
        """
@@ -814,3 +814,63 @@ class WeightQuantizeTest(unittest.TestCase):
         w_array_ref = Node(graph_ref, 'ac_weights').out_port(0).get_destination().data.get_value()
 
         self.assertTrue(np.all(w_array == w_array_ref))
+
+
+@generator
+class CompressionDataTypeTest(unittest.TestCase):
+    @generate(*[
+        ('FP32', np.int64),
+        ('FP16', np.int64),
+        ('FP32', np.int32),
+        ('FP16', np.int32),
+        ('FP32', np.float64, np.float32),
+        ('FP16', np.float64, np.float16),
+        ('FP32', np.float32, np.float32),
+        ('FP16', np.float32, np.float16),
+        ('FP32', np.float16, np.float32),
+        ('FP16', np.float16, np.float16),
+    ])
+    def test_data_type(self, model_dtype, original, transformed=None):
+        if transformed is None:
+            transformed = original
+
+        nodes = {
+            **valued_const_with_data('weights', np.ones([1, 2, 3, 4], dtype=original)),
+
+            **valued_const_with_data('int_weights', np.ones([1, 2, 3, 4], dtype=np.uint8)),
+            **regular_op_with_shaped_data('cast', [1, 2, 3, 4], {'type': 'Convert', 'dst_type': transformed}),
+
+            **valued_const_with_data('il', np.array([0])),
+            **valued_const_with_data('ih', np.array([254])),
+            **valued_const_with_data('ol', np.array([0])),
+            **valued_const_with_data('oh', np.array([254])),
+
+            **regular_op_with_shaped_data('FQ', [1, 2, 3, 4], {'type': 'FakeQuantize', 'infer': FakeQuantize.infer,
+                                                               'stop_value_propagation': True, 'levels': 255}),
+            **result()
+        }
+
+        graph = build_graph(nodes, [
+            *connect('weights:0', '0:FQ'),
+            *connect('il:0', '1:FQ'),
+            *connect('ih:0', '2:FQ'),
+            *connect('ol:0', '3:FQ'),
+            *connect('oh:0', '4:FQ'),
+            *connect('FQ:0', 'output'),
+        ], nodes_with_edges_only=True)
+        graph.graph['cmd_params'] = Namespace(data_type=model_dtype, keep_shape_ops=True)
+
+        CompressQuantizeWeights().find_and_replace_pattern(graph)
+        graph.clean_up()
+
+        graph_ref = build_graph(nodes, [
+            *connect('int_weights:0', '0:cast'),
+            *connect('cast:0', '0:FQ'),
+            *connect('il:0', '1:FQ'),
+            *connect('ih:0', '2:FQ'),
+            *connect('ol:0', '3:FQ'),
+            *connect('oh:0', '4:FQ'),
+            *connect('FQ:0', 'output'),
+        ], nodes_with_edges_only=True)
+        (flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
+        self.assertTrue(flag, resp)
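A sanity check, not part of the patch, of the property the reference graph above encodes (uint8 Const -> Convert -> FakeQuantize over [0, levels - 1]): storing the rounded FakeQuantize grid as uint8 and dequantizing it reproduces the original float FakeQuantize output. The formula assumes the standard FakeQuantize semantics; all names and ranges here are illustrative.

import numpy as np

levels = 255
in_low, in_high = 0.0, 254.0
out_low, out_high = -1.0, 1.0

w = np.random.uniform(in_low, in_high, size=(1, 2, 3, 4)).astype(np.float32)

# original FakeQuantize on float weights
q = np.round((np.clip(w, in_low, in_high) - in_low) / (in_high - in_low) * (levels - 1))
fq_float = q / (levels - 1) * (out_high - out_low) + out_low

# compressed form: uint8 constant plus dequantization
int_w = q.astype(np.uint8)  # fits: the grid values lie in [0, levels - 1] = [0, 254]
fq_compressed = int_w.astype(np.float32) / (levels - 1) * (out_high - out_low) + out_low

assert np.allclose(fq_float, fq_compressed)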