INT4 compression (#4912)
* INT4 compression * packing tests * (U)INT4 IR Reader enabling * Cast test header added * Fix graph.clean_up * Review discussions addressed: comments clarification, variable naming, fuller test coverage
This commit is contained in:
parent
8b4837ea62
commit
2e75aafbe2
@ -22,7 +22,7 @@ from extensions.ops.Cast import Cast
|
||||
from extensions.ops.elementwise import Sub, Div, Mul, Negative
|
||||
from mo.back.replacement import BackReplacementPattern
|
||||
from mo.graph.graph import Graph, Node
|
||||
from mo.middle.passes.convert_data_type import data_type_str_to_np, np_data_type_to_destination_type
|
||||
from mo.middle.passes.convert_data_type import data_type_str_to_np, np_data_type_to_destination_type, packed_I4
|
||||
from mo.ops.const import Const
|
||||
|
||||
|
||||
@ -104,6 +104,12 @@ class CompressQuantizeWeights(BackReplacementPattern):
|
||||
|
||||
force_clean_up = True
|
||||
|
||||
QUANTIZATION_MAP = {
|
||||
# max_levels: (np_dtype, quantization_mode)
|
||||
256: (np.int8, "signed"),
|
||||
16: (packed_I4, "signed"),
|
||||
}
|
||||
|
||||
def pattern(self):
|
||||
return dict(
|
||||
nodes=[
|
||||
@ -118,7 +124,7 @@ class CompressQuantizeWeights(BackReplacementPattern):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def quantize_data(fake_quantize: Node, dst_type: type):
|
||||
def quantize_data(fake_quantize: Node, dst_type: type, quantized_type: type, mode: str):
|
||||
graph = fake_quantize.graph
|
||||
name = fake_quantize.soft_get('name', fake_quantize.id)
|
||||
levels = fake_quantize.levels
|
||||
@ -131,8 +137,12 @@ class CompressQuantizeWeights(BackReplacementPattern):
|
||||
fake_quantize.in_port(2).get_connection().set_destination(quantize.in_port(2))
|
||||
|
||||
# calculate output limits for quantized weights
|
||||
i_min = np.array([-(levels // 2)], dtype=dst_type)
|
||||
assert mode in ["signed", "unsigned"]
|
||||
i_min_value = -(levels // 2) if mode == "signed" else 0
|
||||
|
||||
i_min = np.array([i_min_value], dtype=dst_type)
|
||||
i_max = np.array(levels + i_min - 1, dtype=dst_type)
|
||||
|
||||
assert i_max - i_min == levels - 1
|
||||
out_low = Const(graph, dict(name=name + '/Copy/out_low', value=i_min)).create_node()
|
||||
out_high = Const(graph, dict(name=name + '/Copy/out_high', value=i_max)).create_node()
|
||||
@ -144,19 +154,20 @@ class CompressQuantizeWeights(BackReplacementPattern):
|
||||
|
||||
original_const = quantize.in_port(0).get_source().node
|
||||
quantized_data_name = original_const.soft_get('name', original_const.id) + '/quantized'
|
||||
cast = Cast(graph, dict(name=quantized_data_name, dst_type=np.int8, stop_value_propagation=False)).create_node()
|
||||
cast = Cast(graph, dict(name=quantized_data_name, dst_type=quantized_type,
|
||||
stop_value_propagation=False)).create_node()
|
||||
|
||||
quantize.out_port(0).connect(cast.in_port(0))
|
||||
|
||||
cast.out_port(0).connect(fake_quantize.in_port(0))
|
||||
|
||||
@staticmethod
|
||||
def dequantize_data(fake_quantize: Node, dst_type: type) -> Node:
|
||||
def dequantize_data(fake_quantize: Node, dst_type: type, quantized_type: type) -> Node:
|
||||
graph = fake_quantize.graph
|
||||
quantized_data = fake_quantize.in_port(0).get_source().node
|
||||
name = fake_quantize.soft_get('name', fake_quantize.id)
|
||||
|
||||
assert quantized_data.soft_get('type') == 'Convert' and quantized_data.dst_type == np.int8, \
|
||||
assert quantized_data.soft_get('type') == 'Convert' and quantized_data.dst_type == quantized_type, \
|
||||
'Weights aren`t compressed as expected for node {}'.format(fake_quantize.soft_get('name', fake_quantize.id))
|
||||
|
||||
dequantizing_cast = Cast(graph, dict(
|
||||
@ -212,5 +223,11 @@ class CompressQuantizeWeights(BackReplacementPattern):
|
||||
if np.issubdtype(dst_type, np.floating):
|
||||
dst_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)
|
||||
|
||||
self.quantize_data(fake_quantize, dst_type)
|
||||
self.dequantize_data(fake_quantize, dst_type)
|
||||
quantized_type, mode = None, None
|
||||
for quantization_levels in sorted(self.QUANTIZATION_MAP):
|
||||
if quantization_levels >= fake_quantize.levels:
|
||||
quantized_type, mode = self.QUANTIZATION_MAP[quantization_levels]
|
||||
break
|
||||
|
||||
self.quantize_data(fake_quantize, dst_type, quantized_type, mode)
|
||||
self.dequantize_data(fake_quantize, dst_type, quantized_type)
|
||||
|
@ -83,7 +83,7 @@ class CompressionQuantizeDequantizeSeparateTest(unittest.TestCase):
|
||||
self.assertEqual(len(fq_nodes), 1, error_message.format('before', len(fq_nodes)))
|
||||
fake_quantize = fq_nodes[0]
|
||||
|
||||
CompressQuantizeWeights.quantize_data(fake_quantize, original_type)
|
||||
CompressQuantizeWeights.quantize_data(fake_quantize, original_type, np.int8, "signed")
|
||||
graph.clean_up()
|
||||
|
||||
fq_nodes = graph.get_op_nodes(type='FakeQuantize')
|
||||
@ -124,7 +124,7 @@ class CompressionQuantizeDequantizeSeparateTest(unittest.TestCase):
|
||||
self.assertEqual(len(cast_nodes), 1, error_message.format('Convert', 'before', len(cast_nodes)))
|
||||
cast_nodes[0]['need_shape_inference'] = True
|
||||
|
||||
CompressQuantizeWeights.dequantize_data(fq_nodes[0], original_type)
|
||||
CompressQuantizeWeights.dequantize_data(fq_nodes[0], original_type, np.int8)
|
||||
graph.clean_up()
|
||||
|
||||
fq_nodes = graph.get_op_nodes(type='FakeQuantize')
|
||||
|
@ -14,10 +14,12 @@
|
||||
limitations under the License.
|
||||
"""
|
||||
import logging as log
|
||||
import numpy as np
|
||||
|
||||
from mo.front.common.partial_infer.elemental import copy_shape_infer
|
||||
from mo.graph.graph import Node, Graph
|
||||
from mo.middle.passes.convert_data_type import np_data_type_to_precision, convert_blob, np_data_type_to_destination_type
|
||||
from mo.middle.passes.convert_data_type import np_data_type_to_precision, convert_blob, \
|
||||
np_data_type_to_destination_type, packed_I4, packed_U4
|
||||
from mo.ops.op import Op
|
||||
from mo.utils.utils import refer_to_faq_msg
|
||||
|
||||
@ -47,23 +49,77 @@ class Cast(Op):
|
||||
assert node.has_valid('dst_type'), 'Destination type of "Cast" operation should be extracted earlier'
|
||||
node.out_port(0).set_data_type(node.dst_type)
|
||||
|
||||
@staticmethod
|
||||
def helper_value_propagation(node_name, value, dst_type):
|
||||
new_blob, finite_match_count, zero_match_count = convert_blob(value, dst_type)
|
||||
|
||||
if finite_match_count:
|
||||
log.error("{} elements of {} were clipped to infinity while converting an input blob for node '{}' to {}."
|
||||
" ".format(finite_match_count, new_blob.size, node_name, dst_type) + refer_to_faq_msg(76))
|
||||
if zero_match_count:
|
||||
log.warning("{} elements of {} were clipped to zero while converting an input blob for node '{}' to {}."
|
||||
" ".format(zero_match_count, new_blob.size, node_name, dst_type) + refer_to_faq_msg(77))
|
||||
return new_blob
|
||||
|
||||
@staticmethod
|
||||
def custom_type_casting_and_packing(node: Node, value, dst_type):
|
||||
"""
|
||||
Custom types are not supported by numpy but we still need to write it to the .bin file in a compact way.
|
||||
To do so we prepare bit representation of int4/uint4 values and store them in a numpy friendly data type.
|
||||
We pack int4/uint4 values into uint8 type (two int4/uint4 numbers fit in uint8).
|
||||
If the number of elements in the blob is odd we pad them with zero value to be able to fit the bit sequence
|
||||
into the uint8 array.
|
||||
Example: we need to represent 5 elements of int4 dtype
|
||||
we would pad them to 6 element with the last element as zero and we would pack them into 3 uint8 values
|
||||
"""
|
||||
assert dst_type in [packed_U4, packed_I4]
|
||||
|
||||
minimum_regular_dtype = np.uint8 if dst_type == packed_U4 else np.int8
|
||||
# initial casing from the source type to the numpy-friendly type which could absorb all the values of dst_type
|
||||
casted_to_regular_type = Cast.helper_value_propagation(
|
||||
node.soft_get('name', node.id), value, minimum_regular_dtype)
|
||||
|
||||
# packing the values
|
||||
data_shape = node.out_port(0).data.get_shape()
|
||||
assert data_shape is not None
|
||||
data_size = np.prod(data_shape)
|
||||
|
||||
num_bits = 4
|
||||
assert num_bits < 8 and 8 % num_bits == 0, "Packing algorithm for the data types stored in 1, 2 or 4 bits"
|
||||
num_values_fitting_into_uint8 = 8 // num_bits
|
||||
pad = (-data_size) % num_values_fitting_into_uint8
|
||||
|
||||
flattened = casted_to_regular_type.flatten()
|
||||
padded = np.concatenate((flattened, np.zeros([pad], dtype=minimum_regular_dtype)))
|
||||
assert np.prod(padded.shape) % num_values_fitting_into_uint8 == 0
|
||||
|
||||
bit_order_little = (padded[:, None] & (1 << np.arange(num_bits)) > 0).astype(np.uint8)
|
||||
bit_order_big = np.flip(bit_order_little, axis=1)
|
||||
bit_order_big_flattened = bit_order_big.flatten()
|
||||
packed = np.packbits(bit_order_big_flattened, bitorder='big')
|
||||
|
||||
node.out_node(0)['force_shape'] = data_shape.copy()
|
||||
node.out_node(0)['force_type'] = np_data_type_to_precision(dst_type)
|
||||
node.out_port(0).data.set_value(packed)
|
||||
|
||||
@staticmethod
|
||||
def infer(node: Node):
|
||||
assert node.has_valid('dst_type'), 'Destination type of "Cast" operation should be extracted earlier'
|
||||
dst_type = node.dst_type
|
||||
copy_shape_infer(node)
|
||||
if node.has_and_set('stop_value_propagation'):
|
||||
node_name = node.soft_get('name', node.id)
|
||||
dst_type = node.soft_get('dst_type', None)
|
||||
|
||||
assert dst_type is not None, \
|
||||
'Destination type of "Cast" operation should be extracted earlier, but it`s not for node: ' + node_name
|
||||
|
||||
input_shape = node.in_port(0).data.get_shape()
|
||||
assert input_shape is not None
|
||||
node.out_port(0).data.set_shape(input_shape)
|
||||
|
||||
value = node.in_port(0).data.get_value()
|
||||
if value is None or node.has_and_set('stop_value_propagation'):
|
||||
return
|
||||
if node.in_node(0).has_valid('value'):
|
||||
new_blob, finite_match_count, zero_match_count = convert_blob(node.in_node(0).value, dst_type)
|
||||
node.out_port(0).data.set_value(new_blob)
|
||||
|
||||
if finite_match_count:
|
||||
log.error(
|
||||
("{} elements of {} were clipped to infinity while converting an input blob for node '{}' to {}. " +
|
||||
refer_to_faq_msg(76)).format(finite_match_count, new_blob.size, node.name, dst_type))
|
||||
if zero_match_count:
|
||||
log.warning(
|
||||
("{} elements of {} were clipped to zero while converting an input blob for node '{}' to {}. " +
|
||||
refer_to_faq_msg(77)).format(zero_match_count, new_blob.size, node.name, dst_type))
|
||||
|
||||
if dst_type in [packed_U4, packed_I4]: # custom types conversion
|
||||
Cast.custom_type_casting_and_packing(node, value, dst_type)
|
||||
else:
|
||||
node.out_port(0).data.set_value(
|
||||
Cast.helper_value_propagation(node_name, value, dst_type))
|
||||
|
126
model-optimizer/extensions/ops/cast_test.py
Normal file
126
model-optimizer/extensions/ops/cast_test.py
Normal file
@ -0,0 +1,126 @@
|
||||
"""
|
||||
Copyright (C) 2018-2021 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import unittest
|
||||
from generator import generator, generate
|
||||
|
||||
from extensions.ops.Cast import Cast
|
||||
from mo.middle.passes.convert_data_type import packed_U4, packed_I4
|
||||
from mo.middle.passes.infer import partial_infer
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from mo.utils.unittest.graph import valued_const_with_data, regular_op_with_empty_data, \
|
||||
result, build_graph, connect
|
||||
|
||||
nodes = lambda value, dst_type: {
|
||||
**valued_const_with_data('value', np.array(value)),
|
||||
**regular_op_with_empty_data('convert', {'dst_type': dst_type, 'infer': Cast.infer}),
|
||||
**result(),
|
||||
}
|
||||
|
||||
|
||||
@generator
|
||||
class CastTest(unittest.TestCase):
|
||||
"""
|
||||
Example of checking:
|
||||
7 == 0111, padded to 0111 0000, results in 112
|
||||
7 == 0111, 8 == 1000 packed to 0111 1000, results in 120
|
||||
|
||||
-8 == 1000, padded to 1000 0000, results in 128
|
||||
"""
|
||||
|
||||
@generate(*[
|
||||
([0], [0], packed_U4),
|
||||
([1], [16], packed_U4),
|
||||
([2], [32], packed_U4),
|
||||
([3], [48], packed_U4),
|
||||
([4], [64], packed_U4),
|
||||
([5], [80], packed_U4),
|
||||
([6], [96], packed_U4),
|
||||
([7], [112], packed_U4),
|
||||
([8], [128], packed_U4),
|
||||
([9], [144], packed_U4),
|
||||
([10], [160], packed_U4),
|
||||
([11], [176], packed_U4),
|
||||
([12], [192], packed_U4),
|
||||
([13], [208], packed_U4),
|
||||
([14], [224], packed_U4),
|
||||
([15], [240], packed_U4),
|
||||
|
||||
([0, 15], [15], packed_U4),
|
||||
([1, 14], [30], packed_U4),
|
||||
([2, 13], [45], packed_U4),
|
||||
([3, 12], [60], packed_U4),
|
||||
([4, 11], [75], packed_U4),
|
||||
([5, 10], [90], packed_U4),
|
||||
([6, 9], [105], packed_U4),
|
||||
([7, 8], [120], packed_U4),
|
||||
([8, 7], [135], packed_U4),
|
||||
([9, 6], [150], packed_U4),
|
||||
([10, 5], [165], packed_U4),
|
||||
([11, 4], [180], packed_U4),
|
||||
([12, 3], [195], packed_U4),
|
||||
([13, 2], [210], packed_U4),
|
||||
([14, 1], [225], packed_U4),
|
||||
([15, 0], [240], packed_U4),
|
||||
|
||||
([-8], [128], packed_I4),
|
||||
([-7], [144], packed_I4),
|
||||
([-6], [160], packed_I4),
|
||||
([-5], [176], packed_I4),
|
||||
([-4], [192], packed_I4),
|
||||
([-3], [208], packed_I4),
|
||||
([-2], [224], packed_I4),
|
||||
([-1], [240], packed_I4),
|
||||
([0], [0], packed_I4),
|
||||
([1], [16], packed_I4),
|
||||
([2], [32], packed_I4),
|
||||
([3], [48], packed_I4),
|
||||
([4], [64], packed_I4),
|
||||
([5], [80], packed_I4),
|
||||
([6], [96], packed_I4),
|
||||
([7], [112], packed_I4),
|
||||
|
||||
([-8, 7], [135], packed_I4),
|
||||
([-7, 6], [150], packed_I4),
|
||||
([-6, 5], [165], packed_I4),
|
||||
([-5, 4], [180], packed_I4),
|
||||
([-4, 3], [195], packed_I4),
|
||||
([-3, 2], [210], packed_I4),
|
||||
([-2, 1], [225], packed_I4),
|
||||
([-1, 0], [240], packed_I4),
|
||||
([0, -1], [15], packed_I4),
|
||||
([1, -2], [30], packed_I4),
|
||||
([2, -3], [45], packed_I4),
|
||||
([3, -4], [60], packed_I4),
|
||||
([4, -5], [75], packed_I4),
|
||||
([5, -6], [90], packed_I4),
|
||||
([6, -7], [105], packed_I4),
|
||||
([7, -8], [120], packed_I4),
|
||||
])
|
||||
def test_custom_value_propagation(self, value, expected, custom_dtype):
|
||||
graph = build_graph(nodes(value, custom_dtype), [
|
||||
*connect('value', 'convert'), *connect('convert', 'output'),
|
||||
])
|
||||
partial_infer(graph)
|
||||
|
||||
graph_ref = build_graph(nodes(value, custom_dtype), [
|
||||
*connect('value', 'convert'), *connect('convert', 'output')],
|
||||
{'convert_d': {'force_type': custom_dtype, 'force_shape': np.array(value).shape,
|
||||
'value': expected}})
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
@ -100,10 +100,8 @@ class Port:
|
||||
if self.node.graph.stage == 'front':
|
||||
return None
|
||||
else:
|
||||
if self.type == 'in':
|
||||
return self.node.in_node(self.idx, control_flow=self.control_flow).shape
|
||||
else:
|
||||
return self.node.out_node(self.idx, control_flow=self.control_flow).shape
|
||||
node_caller = self.node.in_node if self.type == 'in' else self.node.out_node
|
||||
return node_caller(self.idx, control_flow=self.control_flow).shape
|
||||
|
||||
def _set_shape(self, shape):
|
||||
if self.node.graph.stage == 'front':
|
||||
@ -114,7 +112,8 @@ class Port:
|
||||
self.node.in_node(self.idx, control_flow=self.control_flow).shape = int64_array(shape)
|
||||
else:
|
||||
data_node = self.node.out_node(self.idx, control_flow=self.control_flow)
|
||||
assert data_node.value is None or np.array_equal(data_node.shape, int64_array(shape))
|
||||
assert data_node.value is None or \
|
||||
np.array_equal(data_node.soft_get('force_shape', data_node.shape), int64_array(shape))
|
||||
self.node.out_node(self.idx, control_flow=self.control_flow).shape = int64_array(shape)
|
||||
|
||||
def _get_value(self):
|
||||
@ -135,24 +134,21 @@ class Port:
|
||||
if self.node.graph.stage == 'front':
|
||||
raise Error("set_value is not applicable for graph front phase")
|
||||
else:
|
||||
if self.type == 'in':
|
||||
data_node = self.node.in_node(self.idx, control_flow=self.control_flow)
|
||||
const_node = data_node.in_node(control_flow=self.control_flow)
|
||||
data_node_caller = self.node.in_node if self.type == 'in' else self.node.out_node
|
||||
data_node = data_node_caller(self.idx, control_flow=self.control_flow)
|
||||
const_node = data_node.in_node(control_flow=self.control_flow) if self.type == 'in' else self.node
|
||||
|
||||
# Set value to data node
|
||||
data_node.value = value
|
||||
data_node.shape = int64_array(value.shape)
|
||||
force_shape = data_node.soft_get('force_shape', const_node.soft_get('force_shape', None))
|
||||
shape = int64_array(value.shape if force_shape is None else force_shape)
|
||||
|
||||
# Set value to constant producer
|
||||
if const_node.soft_get('type') == 'Const':
|
||||
const_node.value = value
|
||||
const_node.shape = int64_array(value.shape)
|
||||
else:
|
||||
self.node.out_node(self.idx, control_flow=self.control_flow).value = value
|
||||
self.node.out_node(self.idx, control_flow=self.control_flow).shape = int64_array(value.shape)
|
||||
if self.node.soft_get('type') == 'Const':
|
||||
self.node.value = value
|
||||
self.node.shape = int64_array(value.shape)
|
||||
# Set value to data node
|
||||
data_node.value = value
|
||||
data_node.shape = shape
|
||||
|
||||
# Set value to constant producer
|
||||
if const_node.soft_get('type') == 'Const':
|
||||
const_node.value = value
|
||||
const_node.shape = shape
|
||||
|
||||
def _get_attr(self, item: str):
|
||||
if self.node.graph.stage == 'front':
|
||||
|
@ -17,7 +17,7 @@
|
||||
import unittest
|
||||
|
||||
from mo.graph.graph import Node
|
||||
from mo.utils.unittest.graph import build_graph, regular_op
|
||||
from mo.utils.unittest.graph import build_graph, regular_op, valued_const_with_data, result, connect
|
||||
|
||||
nodes = {
|
||||
**regular_op('input', {'type': 'Parameter'}),
|
||||
@ -139,3 +139,29 @@ class TestPortMethods(unittest.TestCase):
|
||||
op1_node = Node(graph, 'Op1')
|
||||
op1_node.out_port(0).disconnect()
|
||||
self.assertTrue(op1_node.out_port(0).disconnected())
|
||||
|
||||
|
||||
class TestForceShape(unittest.TestCase):
|
||||
def test_set_value_and_shape_with_force_shape_attribute_in_op(self):
|
||||
import numpy as np
|
||||
graph = build_graph({**valued_const_with_data('const', np.array([1, 2, 3])), **result()},
|
||||
[*connect('const', 'output')])
|
||||
|
||||
node = Node(graph, 'const')
|
||||
node['force_shape'] = np.array([2, 5, 7], dtype=np.int64)
|
||||
node.out_port(0).data.set_value(np.zeros(35))
|
||||
self.assertTrue(np.array_equal(node.out_port(0).data.get_shape(), np.array([2, 5, 7], dtype=np.int64)),
|
||||
"node.out_port(0).data.get_shape()={} != [2, 5, 7]".format(node.out_port(0).data.get_shape()))
|
||||
|
||||
def test_set_value_and_shape_with_force_shape_attribute_in_data(self):
|
||||
import numpy as np
|
||||
graph = build_graph({**valued_const_with_data('const', np.array([1, 2, 3])), **result()},
|
||||
[*connect('const', 'output')])
|
||||
|
||||
node = Node(graph, 'const')
|
||||
Node(graph, 'const_d')['force_shape'] = np.array([2, 5, 7], dtype=np.int64)
|
||||
node.out_port(0).data.set_value(np.zeros(30))
|
||||
self.assertTrue(np.array_equal(node.out_port(0).data.get_shape(), np.array([2, 5, 7], dtype=np.int64)),
|
||||
"node.out_port(0).data.get_shape()={} != [2, 5, 7]".format(
|
||||
node.out_port(0).data.get_shape()))
|
||||
|
||||
|
@ -23,11 +23,22 @@ from mo.graph.graph import Node, Graph
|
||||
from mo.utils.error import Error
|
||||
from mo.utils.utils import refer_to_faq_msg
|
||||
|
||||
"""
|
||||
Packed data of custom types are stored in numpy uint8 data type.
|
||||
To distinguish true uint8 and custom data we introduce this class not to store,
|
||||
but to have unique data type in SUPPORTED_DATA_TYPES map
|
||||
"""
|
||||
|
||||
|
||||
class packed_U1(np.generic):
|
||||
# packed U1 and U8 types of data are stored in numpy uint8 data type
|
||||
# to distinguish true uint8 and u1 data we introduce this class not to store,
|
||||
# but to have unique data type in SUPPORTED_DATA_TYPES map
|
||||
pass
|
||||
|
||||
|
||||
class packed_U4(np.generic):
|
||||
pass
|
||||
|
||||
|
||||
class packed_I4(np.generic):
|
||||
pass
|
||||
|
||||
|
||||
@ -44,7 +55,13 @@ SUPPORTED_DATA_TYPES = {
|
||||
'int32': (np.int32, 'I32', 'i32'),
|
||||
'int64': (np.int64, 'I64', 'i64'),
|
||||
'bool': (np.bool, 'BOOL', 'boolean'),
|
||||
|
||||
# custom types
|
||||
'U1': (packed_U1, 'U1', 'u1'),
|
||||
'int4': (packed_I4, 'I4', 'i4'),
|
||||
'uint4': (packed_U4, 'U4', 'u4'),
|
||||
'I4': (packed_I4, 'I4', 'i4'),
|
||||
'U4': (packed_U4, 'U4', 'u4'),
|
||||
}
|
||||
|
||||
|
||||
|
@ -132,8 +132,13 @@ def eliminate_dead_nodes(graph):
|
||||
# During graph clean-up the operation node is removed and the attribute is lost.
|
||||
# This results in permutation of the Const shape in the IR and wrong inference results.
|
||||
# Here we explicitly save the 'nchw_layout' attribute in the data node to prevent permutation."
|
||||
if node_attrs.get('type', None) == 'Const' and node_attrs.get('nchw_layout', False):
|
||||
Node(graph, node_name).out_node()['nchw_layout'] = True
|
||||
if node_attrs.get('type', None) == 'Const':
|
||||
if node_attrs.get('nchw_layout', False):
|
||||
Node(graph, node_name).out_node()['nchw_layout'] = True
|
||||
if np.all(node_attrs.get('force_shape', False)):
|
||||
Node(graph, node_name).out_node()['force_shape'] = node_attrs['force_shape']
|
||||
if node_attrs.get('force_type', False):
|
||||
Node(graph, node_name).out_node()['force_type'] = node_attrs['force_type']
|
||||
|
||||
if not node_attrs['is_output_reachable'] or \
|
||||
(node_attrs['is_const_producer'] and (not node_attrs['is_undead'] or
|
||||
|
@ -310,6 +310,8 @@ class IREngine(object):
|
||||
'I8': (1, np.int8),
|
||||
'U8': (1, np.uint8),
|
||||
'U1': (1, np.uint8),
|
||||
'U4': (1, np.uint8),
|
||||
'I4': (1, np.uint8),
|
||||
'BOOL': (1, np.bool),
|
||||
'BIN': (1, np.uint8),
|
||||
}
|
||||
|
@ -169,7 +169,14 @@ def propagate_const_values(op: Node):
|
||||
|
||||
op['shape'] = out_data_node.shape
|
||||
# Reshape data node value for correct shape
|
||||
op['value'] = np.reshape(value, op.shape)
|
||||
if op['element_type'] in ['u4', 'i4']:
|
||||
# Packed data types are custom from numpy perspective.
|
||||
# Shape from the IR is incompatible with numpy value we store.
|
||||
op['value'] = value
|
||||
op['force_type'] = op['element_type'].upper()
|
||||
op['force_shape'] = op.shape.copy()
|
||||
else:
|
||||
op['value'] = np.reshape(value, op.shape)
|
||||
|
||||
|
||||
def groupconv_to_conv(op: Node):
|
||||
|
Loading…
Reference in New Issue
Block a user