diff --git a/tools/pot/openvino/tools/pot/algorithms/quantization/bias_correction/algorithm.py b/tools/pot/openvino/tools/pot/algorithms/quantization/bias_correction/algorithm.py index 7627c401c87..3f19883a3d6 100644 --- a/tools/pot/openvino/tools/pot/algorithms/quantization/bias_correction/algorithm.py +++ b/tools/pot/openvino/tools/pot/algorithms/quantization/bias_correction/algorithm.py @@ -18,7 +18,8 @@ from ....graph.transformer import GraphTransformer from ....samplers.creator import create_sampler from ....statistics.functions import activations as asf from ....statistics.functions import aggregation as agf -from ....statistics.statistics import TensorStatisticAxis +from ....statistics.statistics import TensorStatisticAxis, TensorStatistic +from ..utils import get_input_shape_for_bias from ....utils.launcher import IELauncher from ....utils.logger import get_logger @@ -355,7 +356,7 @@ class BiasCorrection(Algorithm): q_outputs.append(asf.mean_per_channel_axis(q_output[add_name], add_name, channel=self._channel_axis)) q_output = agf.mean(q_outputs) - add_out_shape = nu.get_input_shape_for_bias(params['node_bias_add']) + add_out_shape = get_input_shape_for_bias(self._fp32_statistics, params['node_bias_add'].fullname) axis_channel = self.get_channel_axis(add_name) bias_shift_value = fp32_output - q_output bias_shape = np.ones(len(add_out_shape), dtype=np.int) @@ -430,6 +431,8 @@ class BiasCorrection(Algorithm): type='mean', inplace_statistics=self.config['inplace_statistics'], channel=self._channel_axis)} + statistics_layout[add_node_name]["shape"] = TensorStatistic(func=lambda x, **kwargs: x.shape, + shape_for_inference=True) layers_mapping = fqut.create_renamed_layers_mapping(quantized_model, statistics_layout) self._stats_collector.register(self.name, statistics_layout, self._sampler, layers_mapping) diff --git a/tools/pot/openvino/tools/pot/algorithms/quantization/fast_bias_correction/algorithm.py b/tools/pot/openvino/tools/pot/algorithms/quantization/fast_bias_correction/algorithm.py index 9dfa5553faf..077a63b5602 100755 --- a/tools/pot/openvino/tools/pot/algorithms/quantization/fast_bias_correction/algorithm.py +++ b/tools/pot/openvino/tools/pot/algorithms/quantization/fast_bias_correction/algorithm.py @@ -14,7 +14,8 @@ from ....graph.special_operations import OPERATIONS_WITH_BIAS, OPERATIONS_CHANNE from ....samplers.creator import create_sampler from ....statistics.functions import activations as asf from ....statistics.functions import aggregation as agf -from ....statistics.statistics import TensorStatisticAxis +from ....statistics.statistics import TensorStatisticAxis, TensorStatistic +from ..utils import get_input_shape_for_bias from ....utils.launcher import IELauncher from ....utils.logger import get_logger @@ -52,7 +53,6 @@ class FastBiasCorrection(Algorithm): mu.nx_type_infer(model) activations_statistics = self._stats_collector.get_statistics_for_algorithm(self.name) nodes_with_bias = mu.get_nodes_by_type(model, [op['type'] for op in OPERATIONS_WITH_BIAS]) - inputs_shape_nodes_with_bias = self.get_inputs_shape(nodes_with_bias) self.find_channel_axis(model) launcher = IELauncher() @@ -75,7 +75,8 @@ class FastBiasCorrection(Algorithm): input_node = nu.get_node_input(input_node, 0) quantized_node = nu.get_node_input(op_node, 0) - input_shape = inputs_shape_nodes_with_bias[input_node.fullname] + input_node_name = get_quantized_input_key(quantized_node) + input_shape = get_input_shape_for_bias(activations_statistics, input_node_name) op_model = mu.build_model_for_node(model, input_node.fullname, input_shape, op_node, remove_bias=True, target_device=self._config['target_device']) @@ -84,7 +85,6 @@ class FastBiasCorrection(Algorithm): bias = nu.get_bias_for_node(op_node) after_biased_conv = nu.get_node_output(bias, 0)[0] - input_node_name = get_quantized_input_key(quantized_node) fp32_inputs = agf.mean(activations_statistics[input_node_name]["mean_per_channel"]) fp32_outputs = agf.mean(activations_statistics[after_biased_conv.fullname]["mean_per_channel"]) @@ -92,7 +92,8 @@ class FastBiasCorrection(Algorithm): launcher, input_node.fullname, input_shape, op_model, fp32_inputs, fp32_outputs) current_bias_value = nu.get_node_value(bias_node) # Reshaped since bias are broadcasted - add_out_shape = nu.get_input_shape_for_bias(after_biased_conv) + after_biased_conv_node_name = get_quantized_input_key(after_biased_conv) + add_out_shape = get_input_shape_for_bias(activations_statistics, after_biased_conv_node_name) bias_shape = np.ones(len(add_out_shape), dtype=np.int) axis_channel = self.get_channel_axis(input_node_name) bias_shape[axis_channel] = add_out_shape[axis_channel] @@ -152,6 +153,14 @@ class FastBiasCorrection(Algorithm): "mean_per_channel": TensorStatisticAxis(inplace_statistics=inplace_statistics, granularity='perchannel', type='mean', channel=self._channel_axis)} + inputs_outputs_layout[input_name]["shape"] = TensorStatistic(func=lambda x, **kwargs: x.shape, + shape_for_inference=True) + if nu.get_bias_for_node(op_node): + bias = nu.get_bias_for_node(op_node) + after_biased_conv = nu.get_node_output(bias, 0)[0] + after_biased_conv_name = get_quantized_input_key(after_biased_conv) + inputs_outputs_layout[after_biased_conv_name] = \ + {"shape": TensorStatistic(func=lambda x, **kwargs: x.shape, shape_for_inference=True)} return inputs_outputs_layout @@ -203,23 +212,3 @@ class FastBiasCorrection(Algorithm): if isinstance(node, tuple): return self._channel_axis[node[0]] return self._channel_axis[node] - - def get_inputs_shape(self, nodes_with_bias): - sampler = create_sampler(self._engine, 1, False, 0) - calculate_input_shape = {} - for op_node in nodes_with_bias: - input_node = nu.get_node_input(op_node, 0) - if input_node.type == 'FakeQuantize': - input_node = nu.get_node_input(input_node, 0) - calculate_input_shape[input_node.fullname] = {'shape_node': lambda x: x.shape} - calculate_metrics = self._engine.calculate_metrics - self._engine.calculate_metrics = False - self._engine.inference_for_shape = True - _, inputs_shape = self._engine.predict(calculate_input_shape, sampler) - self._engine.inference_for_shape = False - self._engine.calculate_metrics = calculate_metrics - for node_name, shape_node in inputs_shape.items(): - inputs_shape[node_name] = shape_node['shape_node'][0] - if len(inputs_shape[node_name]) > 1: - inputs_shape[node_name] = (1, *inputs_shape[node_name][1:]) - return inputs_shape diff --git a/tools/pot/openvino/tools/pot/algorithms/quantization/utils.py b/tools/pot/openvino/tools/pot/algorithms/quantization/utils.py index 94bf92e6009..f8bee54511a 100644 --- a/tools/pot/openvino/tools/pot/algorithms/quantization/utils.py +++ b/tools/pot/openvino/tools/pot/algorithms/quantization/utils.py @@ -4,6 +4,8 @@ from copy import deepcopy from pathlib import Path +from scipy.stats import mode + from .range_estimator import get_range_estimator_config from ...api.engine import Engine from ...configs.hardware_config import HardwareConfig @@ -338,6 +340,13 @@ def get_stat_name_by_config(config, stat_type): return '_'.join(name_list) +def get_input_shape_for_bias(activations_statistics, input_node_name): + input_shape = mode(activations_statistics[input_node_name]['shape'])[0][0] + if len(input_shape) > 1: + input_shape[0] = 1 + return input_shape + + def get_ignored_operations(model): operation = {"transformer": [{"type": "Add"}, {"type": "Power"}, {"type": "Squeeze"}, {"type": "Multiply"}, diff --git a/tools/pot/openvino/tools/pot/api/engine.py b/tools/pot/openvino/tools/pot/api/engine.py index 3acbfa31d2e..c841b5453c5 100644 --- a/tools/pot/openvino/tools/pot/api/engine.py +++ b/tools/pot/openvino/tools/pot/api/engine.py @@ -24,8 +24,6 @@ class Engine(ABC): self._statistic_graph_builder = StatisticGraphBuilder() self._stat_requests_number = self.config.get('stat_requests_number', None) self._eval_requests_number = self.config.get('eval_requests_number', None) - self.inference_for_shape = False - self.calculate_metrics = True def set_model(self, model): """ Set/reset model to instance of engine class diff --git a/tools/pot/openvino/tools/pot/engines/ac_engine.py b/tools/pot/openvino/tools/pot/engines/ac_engine.py index c4e11af2f1a..81ac1e7a81e 100644 --- a/tools/pot/openvino/tools/pot/engines/ac_engine.py +++ b/tools/pot/openvino/tools/pot/engines/ac_engine.py @@ -248,7 +248,7 @@ class ACEngine(Engine): if not stats_layout: return dataset_index = kwargs['dataset_indices'][0] - append_stats(self._accumulated_layer_stats, stats_layout, value, dataset_index, self.inference_for_shape) + append_stats(self._accumulated_layer_stats, stats_layout, value, dataset_index) @staticmethod def _set_requests_number(params, requests_number): diff --git a/tools/pot/openvino/tools/pot/engines/ie_engine.py b/tools/pot/openvino/tools/pot/engines/ie_engine.py index 26cdaf8d567..fd01482722f 100644 --- a/tools/pot/openvino/tools/pot/engines/ie_engine.py +++ b/tools/pot/openvino/tools/pot/engines/ie_engine.py @@ -202,7 +202,7 @@ class IEEngine(Engine): :param annotations: list of annotations [(img_id, annotation)] """ dataset_index = annotations[0][0] if annotations is not None and annotations[0][0] else 0 - append_stats(self._accumulated_layer_stats, stats_layout, outputs, dataset_index, self.inference_for_shape) + append_stats(self._accumulated_layer_stats, stats_layout, outputs, dataset_index) def _update_metrics(self, output, annotations, need_metrics_per_sample=False): """ Updates metrics. diff --git a/tools/pot/openvino/tools/pot/engines/simplified_engine.py b/tools/pot/openvino/tools/pot/engines/simplified_engine.py index 4714206d6ec..d6fa22ccc90 100644 --- a/tools/pot/openvino/tools/pot/engines/simplified_engine.py +++ b/tools/pot/openvino/tools/pot/engines/simplified_engine.py @@ -20,4 +20,4 @@ class SimplifiedEngine(IEEngine): batch_annotations, batch_meta, need_metrics_per_sample): # Collect statistics if stats_layout: - append_stats(self._accumulated_layer_stats, stats_layout, predictions, 0, self.inference_for_shape) + append_stats(self._accumulated_layer_stats, stats_layout, predictions, 0) diff --git a/tools/pot/openvino/tools/pot/engines/utils.py b/tools/pot/openvino/tools/pot/engines/utils.py index d6ccdd9153d..73a47970326 100644 --- a/tools/pot/openvino/tools/pot/engines/utils.py +++ b/tools/pot/openvino/tools/pot/engines/utils.py @@ -13,10 +13,10 @@ from ..utils.utils import convert_output_key logger = get_logger(__name__) -def append_stats(accumulated_layer_stats, stats_layout, value, dataset_index, inference_for_shape): +def append_stats(accumulated_layer_stats, stats_layout, value, dataset_index): inplace_stats_mapping = get_inplace_stats_mapping(stats_layout) if isinstance(value, list): - value = parse_sequential_stats(value, stats_layout, inference_for_shape) + value = parse_sequential_stats(value, stats_layout) else: value = process_raw_output(value) for layer, stats in stats_layout.items(): @@ -29,7 +29,7 @@ def append_stats(accumulated_layer_stats, stats_layout, value, dataset_index, in (dataset_index, compute_statistic(stat_fn, value, layer_stat_name))) -def parse_sequential_stats(value_sequential, stats_layout, inference_for_shape): +def parse_sequential_stats(value_sequential, stats_layout): stat_names_by_layer, old_names_mapping = get_per_layer_stat_mapping(stats_layout) activation_seq = defaultdict(lambda: []) for value in value_sequential: @@ -40,7 +40,8 @@ def parse_sequential_stats(value_sequential, stats_layout, inference_for_shape): for layer, act_seq in activation_seq.items(): seq_len = len(act_seq[0].shape) - if inference_for_shape: + if isinstance(stat_names_by_layer[layer], Statistic) and \ + stat_names_by_layer[layer].kwargs.get('shape_for_inference', False): activation_seq[layer] = act_seq[0] continue if not isinstance(stat_names_by_layer[layer], Statistic) or \ diff --git a/tools/pot/openvino/tools/pot/graph/node_utils.py b/tools/pot/openvino/tools/pot/graph/node_utils.py index 9ef9604c9d5..a4bd28f6f22 100644 --- a/tools/pot/openvino/tools/pot/graph/node_utils.py +++ b/tools/pot/openvino/tools/pot/graph/node_utils.py @@ -209,18 +209,6 @@ def node_with_quantized_weights(node): return False -def get_input_shape_for_bias(op_node): - """ - Generate input shape for bias node - :param op_node: output node for bias - :return: new shape - """ - input_shape = get_input_shape(op_node, 0).copy() - if len(input_shape) > 1: - input_shape[0] = 1 - return input_shape - - def get_input_data_value(node: Node, port: int): """ Return value of data node for needed node at port diff --git a/tools/pot/tests/test_engines.py b/tools/pot/tests/test_engines.py index 807ed5003db..6402c7f290b 100644 --- a/tools/pot/tests/test_engines.py +++ b/tools/pot/tests/test_engines.py @@ -44,8 +44,7 @@ def run_append_stats_test(engine): fc_layer_mock = create_ng_mock(['fc_layer']) value = {conv_layer_mock: sample_tensor, fc_layer_mock: sample_tensor} ref_value = {'conv_layer': sample_tensor, 'fc_layer': sample_tensor} - append_stats(engine._accumulated_layer_stats, stats_layout, value, - dataset_index=0, inference_for_shape=False) + append_stats(engine._accumulated_layer_stats, stats_layout, value, dataset_index=0) for layer, accumulated_value in engine._accumulated_layer_stats.items(): assert np.array_equal(accumulated_value[stat_name][0][1], ref_value[layer]) @@ -58,8 +57,7 @@ def run_append_stats_test(engine): {'conv_layer': sample_tensor, 'fc_layer': sample_tensor}, {'conv_layer': sample_tensor, 'fc_layer': sample_tensor}, ] - append_stats(engine._accumulated_layer_stats, stats_layout, value, - dataset_index=0, inference_for_shape=False) + append_stats(engine._accumulated_layer_stats, stats_layout, value, dataset_index=0) for layer, accumulated_value in engine._accumulated_layer_stats.items(): assert np.array_equal( accumulated_value[stat_name][0][1][:, 0], ref_value[0][layer] diff --git a/tools/pot/tests/test_sanity.py b/tools/pot/tests/test_sanity.py index 05bc3f96acd..cb8d6369a11 100644 --- a/tools/pot/tests/test_sanity.py +++ b/tools/pot/tests/test_sanity.py @@ -28,7 +28,7 @@ TEST_MODELS = [ {}, 'CPU'), ('mobilenet-v2-pytorch', 'pytorch', 'DefaultQuantization', 'mixed', 300, {'accuracy@top1': 0.731, - 'accuracy@top5': 0.906}, + 'accuracy@top5': 0.908}, {}, 'CPU'), ('mobilenet-v1-1.0-224-tf', 'tf', 'DefaultQuantization', 'performance', 100, {'accuracy@top1': 0.728, @@ -39,10 +39,10 @@ TEST_MODELS = [ 'accuracy@top5': 0.911}, {}, 'CPU'), - ('mobilenet-ssd', 'caffe', 'AccuracyAwareQuantization', 'performance', 300, {'map': 0.674}, + ('mobilenet-ssd', 'caffe', 'AccuracyAwareQuantization', 'performance', 300, {'map': 0.6801}, {'metric_subset_ratio': 1.0, 'max_iter_num': 1, 'metrics': [{'name': 'map', 'baseline_value': 0.669}]}, 'CPU'), - ('mobilenet-ssd', 'caffe', 'AccuracyAwareQuantization', 'performance', 300, {'map': 0.674}, + ('mobilenet-ssd', 'caffe', 'AccuracyAwareQuantization', 'performance', 300, {'map': 0.6801}, {'metric_subset_ratio': 1.0, 'max_iter_num': 1, 'tune_hyperparams': True, 'metrics': [{'name': 'map', 'baseline_value': 0.669}]}, 'CPU'), @@ -51,9 +51,9 @@ TEST_MODELS = [ # {'drop_type': 'relative', 'max_iter_num': 1, 'accuracy_drop': 0.005, 'metrics': [ # {'name': 'accuracy@top1', 'baseline_value': 0.431}]}, 'GNA'), - ('mtcnn', 'caffe', 'DefaultQuantization', 'performance', 1, {'recall': 0.76, 'map': 0.6844}, {}, 'CPU'), + ('mtcnn', 'caffe', 'DefaultQuantization', 'performance', 1, {'recall': 0.76, 'map': 0.6618}, {}, 'CPU'), - ('mtcnn', 'caffe', 'DefaultQuantization', 'performance', 2, {'recall': 0.76, 'map': 0.6638}, + ('mtcnn', 'caffe', 'DefaultQuantization', 'performance', 2, {'recall': 0.68, 'map': 0.4406}, {'use_fast_bias': False}, 'CPU'), ('octave-resnet-26-0.25', 'mxnet', 'DefaultQuantization', 'performance', 300, {'accuracy@top1': 0.766, 'accuracy@top5': 0.927}, {'use_fast_bias': False}, 'CPU'), @@ -173,7 +173,7 @@ SIMPLIFIED_TEST_MODELS = [ ('mobilenet-v2-pytorch', 'pytorch', 'DefaultQuantization', 'performance', {'accuracy@top1': 0.701, 'accuracy@top5': 0.91}, []), ('mobilenet-v2-pytorch', 'pytorch', 'DefaultQuantization', 'performance', - {'accuracy@top1': 0.709, 'accuracy@top5': 0.906}, ['--input_shape=[1,3,?,?]']) + {'accuracy@top1': 0.71, 'accuracy@top5': 0.906}, ['--input_shape=[1,3,?,?]']) ] diff --git a/tools/pot/tests/test_statistic_builder.py b/tools/pot/tests/test_statistic_builder.py index f7bc5c0e3c1..44cac44d9c1 100644 --- a/tools/pot/tests/test_statistic_builder.py +++ b/tools/pot/tests/test_statistic_builder.py @@ -32,11 +32,11 @@ TEST_MODELS = [ 'quantile', 'abs_quantile'), ('mobilenetv2_example', 'pytorch', 'symmetric', True, ActivationChannelAlignment, 'mixed', 'perchannel', 1, None, None), - ('squeezenet1_1_example', 'pytorch', 'symmetric', True, FastBiasCorrection, 'mixed', 'perchannel', 0, + ('squeezenet1_1_example', 'pytorch', 'symmetric', True, FastBiasCorrection, 'mixed', 'perchannel', 42, None, None), - ('mobilenetv2_ssd_example', 'pytorch', 'symmetric', True, FastBiasCorrection, 'mixed', 'perchannel', 0, + ('mobilenetv2_ssd_example', 'pytorch', 'symmetric', True, FastBiasCorrection, 'mixed', 'perchannel', 117, None, None), - ('mobilenet_v3_small_example', 'pytorch', 'symmetric', True, BiasCorrection, 'mixed', 'perchannel', 1, + ('mobilenet_v3_small_example', 'pytorch', 'symmetric', True, BiasCorrection, 'mixed', 'perchannel', 53, None, None) ] diff --git a/tools/pot/tests/test_statistics_collector.py b/tools/pot/tests/test_statistics_collector.py index d495dda5356..382bad21b2e 100644 --- a/tools/pot/tests/test_statistics_collector.py +++ b/tools/pot/tests/test_statistics_collector.py @@ -69,12 +69,12 @@ def test_statistics_collector_subsets(tmp_path, models, model_name, model_framew for algo_name, algo_val in local_out.items(): for node_name, node_val in algo_val.items(): for stats_name, stats_val in node_val.items(): - local_out[algo_name][node_name][stats_name] = [v.tolist() for v in stats_val] + local_out[algo_name][node_name][stats_name] = [np.array(v).tolist() for v in stats_val] json.dump(local_out, local_file) for algo_name, algo_val in out.items(): for node_name, node_val in algo_val.items(): for stats_name, stats_val in node_val.items(): - if stats_name == 'batch_mean_param_in': + if stats_name in ['batch_mean_param_in', 'shape']: continue ref_stats_vals = refs[algo_name][node_name][stats_name] for ref_vals, vals in zip(ref_stats_vals, stats_val):