[POT] get_num_levels function (#8393)
* added function for calculating the number of discret levels in the input tensors and tests for this function * added function for calculating the number of discret levels in the input tensors and tests for this function * fixed pylint issues * changed the function for delta estimation from mean to min * added empty delta array processing in get_num_levels func and tests for it
This commit is contained in:
parent
dd5efdca85
commit
ac0582c2d9
@ -450,3 +450,24 @@ def create_renamed_layers_mapping(model, stats_layout):
|
|||||||
name_change_to = node['orig_node_name'] if port_id is None else (node['orig_node_name'], port_id)
|
name_change_to = node['orig_node_name'] if port_id is None else (node['orig_node_name'], port_id)
|
||||||
changed_names_map[layer_name] = name_change_to
|
changed_names_map[layer_name] = name_change_to
|
||||||
return changed_names_map
|
return changed_names_map
|
||||||
|
|
||||||
|
|
||||||
|
def get_num_levels(x: np.ndarray) -> int:
|
||||||
|
"""
|
||||||
|
Calculates the number of discret levels of the values
|
||||||
|
in the input NumPy tensor x
|
||||||
|
:param x: the input tensor
|
||||||
|
:return the number of discret value levels in the input tensor x
|
||||||
|
"""
|
||||||
|
NUM_BINS = 256
|
||||||
|
x = x.flatten()
|
||||||
|
hist, _ = np.histogram(x, NUM_BINS)
|
||||||
|
non_empty_bins = [i for i, v in enumerate(hist) if v > 0]
|
||||||
|
deltas = [non_empty_bins[i]-non_empty_bins[i-1] for i in range(1, len(non_empty_bins))]
|
||||||
|
if deltas == []:
|
||||||
|
return 0
|
||||||
|
d = min(deltas)
|
||||||
|
if d == 1:
|
||||||
|
return -1
|
||||||
|
|
||||||
|
return round(NUM_BINS / d)
|
||||||
|
@ -7,6 +7,8 @@ import pytest
|
|||||||
from openvino.tools.pot.statistics.function_selector import AGGREGATION_FN, ACTIVATIONS_STATS_FN, WEIGHTS_STATS_FN, \
|
from openvino.tools.pot.statistics.function_selector import AGGREGATION_FN, ACTIVATIONS_STATS_FN, WEIGHTS_STATS_FN, \
|
||||||
get_aggregation_function, get_stats_function_for_activations, get_stats_function_for_weights, PERCHANNEL, PERTENSOR
|
get_aggregation_function, get_stats_function_for_activations, get_stats_function_for_weights, PERCHANNEL, PERTENSOR
|
||||||
|
|
||||||
|
from openvino.tools.pot.algorithms.quantization.fake_quantize import get_num_levels
|
||||||
|
|
||||||
INPUT_SHAPES = [(2, 2, 1), (2, 2, 2)]
|
INPUT_SHAPES = [(2, 2, 1), (2, 2, 2)]
|
||||||
AGG_INPUTS = [np.reshape(np.array(range(np.prod(shape)), dtype=np.float32), shape) for shape in INPUT_SHAPES]
|
AGG_INPUTS = [np.reshape(np.array(range(np.prod(shape)), dtype=np.float32), shape) for shape in INPUT_SHAPES]
|
||||||
|
|
||||||
@ -137,14 +139,17 @@ WEIGHTS_CH_STATS_FUNCTIONS = [(name, True) for name in
|
|||||||
WEIGHTS_STATS_FN[PERCHANNEL].registry_dict.keys()]
|
WEIGHTS_STATS_FN[PERCHANNEL].registry_dict.keys()]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
NUM_LEVELS_PARAMS = [
|
||||||
'name, transpose', WEIGHTS_CH_STATS_FUNCTIONS,
|
(np.random.randint, (0, 2, (3, 100, 100)), 1, 1),
|
||||||
ids=['{}_{}'.format(fn[0], fn[1]) for fn in WEIGHTS_CH_STATS_FUNCTIONS])
|
(np.random.randint, (-32, 32, (3, 100, 100)), 64, 1),
|
||||||
def test_weights_transpose_function(name, transpose):
|
(np.random.randint, (-32, 32, (3, 100, 100)), 64, 1/512),
|
||||||
fn = get_stats_function_for_weights(name, PERCHANNEL)
|
(np.random.rand, (3, 100, 100), -1, 1),
|
||||||
if name in ['quantile', 'abs_quantile']:
|
(np.random.randint, (0, 1, (3, 100, 100)), 0, 1)
|
||||||
result = fn(INPUT, q=1e-2, transpose=transpose)
|
]
|
||||||
else:
|
|
||||||
result = fn(INPUT, transpose=transpose)
|
|
||||||
expected = GOLD_VALUES_CH_TRANS_WEIGHT_FUNCTIONS[name]
|
@pytest.mark.parametrize('gen_func,params,expected,coef', NUM_LEVELS_PARAMS)
|
||||||
np.testing.assert_almost_equal(result, expected)
|
def test_get_num_levels_function(gen_func, params, expected, coef):
|
||||||
|
test_1 = gen_func(*params) * coef
|
||||||
|
result = get_num_levels(test_1)
|
||||||
|
assert result == expected
|
||||||
|
Loading…
Reference in New Issue
Block a user