parent
f8b2627c3b
commit
62fba3eadf
@ -98,14 +98,10 @@ class CompressQuantizeWeights(BackReplacementPattern):
|
|||||||
('weights_const', dict(type='Const')),
|
('weights_const', dict(type='Const')),
|
||||||
('weights_d', dict(kind='data')),
|
('weights_d', dict(kind='data')),
|
||||||
('quantize', dict(type='FakeQuantize', levels=lambda x: x is not None and 2 < x <= 256)),
|
('quantize', dict(type='FakeQuantize', levels=lambda x: x is not None and 2 < x <= 256)),
|
||||||
('quantize_d', dict(kind='data')),
|
|
||||||
('convolution', dict())
|
|
||||||
],
|
],
|
||||||
edges=[
|
edges=[
|
||||||
('weights_const', 'weights_d'),
|
('weights_const', 'weights_d'),
|
||||||
('weights_d', 'quantize', {'in': 0}),
|
('weights_d', 'quantize', {'in': 0}),
|
||||||
('quantize', 'quantize_d'),
|
|
||||||
('quantize_d', 'convolution', {'in': 1})
|
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -118,6 +114,8 @@ class CompressQuantizeWeights(BackReplacementPattern):
|
|||||||
initial_fake_quantize.in_port(1).get_connection().set_destination(new_fake_quantize.in_port(1))
|
initial_fake_quantize.in_port(1).get_connection().set_destination(new_fake_quantize.in_port(1))
|
||||||
initial_fake_quantize.in_port(2).get_connection().set_destination(new_fake_quantize.in_port(2))
|
initial_fake_quantize.in_port(2).get_connection().set_destination(new_fake_quantize.in_port(2))
|
||||||
|
|
||||||
|
dst_type = match['weights_const'].value.dtype
|
||||||
|
if np.issubdtype(dst_type, np.floating):
|
||||||
dst_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)
|
dst_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)
|
||||||
|
|
||||||
i_min = np.array([0.], dtype=dst_type)
|
i_min = np.array([0.], dtype=dst_type)
|
||||||
|
@ -13,11 +13,11 @@
|
|||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from generator import generator, generate
|
||||||
|
|
||||||
from extensions.back.compress_quantized_weights import CompressQuantizeWeights
|
from extensions.back.compress_quantized_weights import CompressQuantizeWeights
|
||||||
from extensions.ops.fakequantize import FakeQuantize
|
from extensions.ops.fakequantize import FakeQuantize
|
||||||
@ -25,7 +25,8 @@ from mo.front.common.partial_infer.eltwise import eltwise_infer
|
|||||||
from mo.graph.graph import Node
|
from mo.graph.graph import Node
|
||||||
from mo.ops.const import Const
|
from mo.ops.const import Const
|
||||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||||
from mo.utils.unittest.graph import build_graph
|
from mo.utils.unittest.graph import build_graph, regular_op_with_shaped_data, regular_op_with_empty_data, \
|
||||||
|
valued_const_with_data, result, connect
|
||||||
|
|
||||||
nodes_attributes = {
|
nodes_attributes = {
|
||||||
# placeholder
|
# placeholder
|
||||||
@ -33,7 +34,7 @@ nodes_attributes = {
|
|||||||
'placeholder_data': {'kind': 'data'},
|
'placeholder_data': {'kind': 'data'},
|
||||||
|
|
||||||
# weights
|
# weights
|
||||||
'weights_const': {'type': 'Const', 'kind': 'op'},
|
'weights_const': {'type': 'Const', 'kind': 'op', 'value': np.array([], dtype=np.float32)},
|
||||||
'weights_data': {'kind': 'data'},
|
'weights_data': {'kind': 'data'},
|
||||||
|
|
||||||
# quantize
|
# quantize
|
||||||
@ -446,7 +447,6 @@ class WeightQuantizeTest(unittest.TestCase):
|
|||||||
self.assertTrue(flag, resp)
|
self.assertTrue(flag, resp)
|
||||||
|
|
||||||
def test_accuracy_tensor1(self):
|
def test_accuracy_tensor1(self):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
[1.0, 2.0, 3.0, 4.0]
|
[1.0, 2.0, 3.0, 4.0]
|
||||||
"""
|
"""
|
||||||
@ -814,3 +814,63 @@ class WeightQuantizeTest(unittest.TestCase):
|
|||||||
w_array_ref = Node(graph_ref, 'ac_weights').out_port(0).get_destination().data.get_value()
|
w_array_ref = Node(graph_ref, 'ac_weights').out_port(0).get_destination().data.get_value()
|
||||||
|
|
||||||
self.assertTrue(np.all(w_array == w_array_ref))
|
self.assertTrue(np.all(w_array == w_array_ref))
|
||||||
|
|
||||||
|
|
||||||
|
@generator
|
||||||
|
class CompressionDataTypeTest(unittest.TestCase):
|
||||||
|
@generate(*[
|
||||||
|
('FP32', np.int64),
|
||||||
|
('FP16', np.int64),
|
||||||
|
('FP32', np.int32),
|
||||||
|
('FP16', np.int32),
|
||||||
|
('FP32', np.float64, np.float32),
|
||||||
|
('FP16', np.float64, np.float16),
|
||||||
|
('FP32', np.float32, np.float32),
|
||||||
|
('FP16', np.float32, np.float16),
|
||||||
|
('FP32', np.float16, np.float32),
|
||||||
|
('FP16', np.float16, np.float16),
|
||||||
|
])
|
||||||
|
def test_data_type(self, model_dtype, original, transformed=None):
|
||||||
|
if transformed is None:
|
||||||
|
transformed = original
|
||||||
|
|
||||||
|
nodes = {
|
||||||
|
**valued_const_with_data('weights', np.ones([1, 2, 3, 4], dtype=original)),
|
||||||
|
|
||||||
|
**valued_const_with_data('int_weights', np.ones([1, 2, 3, 4], dtype=np.uint8)),
|
||||||
|
**regular_op_with_shaped_data('cast', [1, 2, 3, 4], {'type': 'Convert', 'dst_type': transformed}),
|
||||||
|
|
||||||
|
**valued_const_with_data('il', np.array([0])),
|
||||||
|
**valued_const_with_data('ih', np.array([254])),
|
||||||
|
**valued_const_with_data('ol', np.array([0])),
|
||||||
|
**valued_const_with_data('oh', np.array([254])),
|
||||||
|
|
||||||
|
**regular_op_with_shaped_data('FQ', [1, 2, 3, 4], {'type': 'FakeQuantize', 'infer': FakeQuantize.infer,
|
||||||
|
'stop_value_propagation': True, 'levels': 255}),
|
||||||
|
**result()
|
||||||
|
}
|
||||||
|
|
||||||
|
graph = build_graph(nodes, [
|
||||||
|
*connect('weights:0', '0:FQ'),
|
||||||
|
*connect('il:0', '1:FQ'),
|
||||||
|
*connect('ih:0', '2:FQ'),
|
||||||
|
*connect('ol:0', '3:FQ'),
|
||||||
|
*connect('oh:0', '4:FQ'),
|
||||||
|
*connect('FQ:0', 'output'),
|
||||||
|
], nodes_with_edges_only=True)
|
||||||
|
graph.graph['cmd_params'] = Namespace(data_type=model_dtype, keep_shape_ops=True)
|
||||||
|
|
||||||
|
CompressQuantizeWeights().find_and_replace_pattern(graph)
|
||||||
|
graph.clean_up()
|
||||||
|
|
||||||
|
graph_ref = build_graph(nodes, [
|
||||||
|
*connect('int_weights:0', '0:cast'),
|
||||||
|
*connect('cast:0', '0:FQ'),
|
||||||
|
*connect('il:0', '1:FQ'),
|
||||||
|
*connect('ih:0', '2:FQ'),
|
||||||
|
*connect('ol:0', '3:FQ'),
|
||||||
|
*connect('oh:0', '4:FQ'),
|
||||||
|
*connect('FQ:0', 'output'),
|
||||||
|
], nodes_with_edges_only=True)
|
||||||
|
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
|
||||||
|
self.assertTrue(flag, resp)
|
||||||
|
Loading…
Reference in New Issue
Block a user