[Documentation]: Added description for NNCF PTQ (#14437)

This commit is contained in:
Alexander Kozlov 2022-12-15 17:10:40 +04:00 committed by GitHub
parent ece0341377
commit bc685ac8a0
12 changed files with 396 additions and 11 deletions

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:271ec8f099a2b9c617a374934596519d228e67967e6e1d8cebbe05de5d080d3b
size 45899

View File

@ -6,7 +6,7 @@
:maxdepth: 1
:hidden:
-pot_introduction
+ptq_introduction
tmo_introduction
(Experimental) Protecting Model <pot_ranger_README>
@ -19,9 +19,9 @@
- :ref:`Model Optimizer <openvino_docs_MO_DG_Deep_Learning_Model_Optimizer_DevGuide>` implements most of the optimization parameters to a model by default. Yet, you are free to configure mean/scale values, batch size, RGB vs BGR input channels, and other parameters to speed up preprocess of a model (:ref:`Embedding Preprocessing Computation <openvino_docs_MO_DG_Additional_Optimization_Use_Cases>`).
-- :ref:`Post-training Optimization w/ POT <pot_introduction>` is designed to optimize inference of deep learning models by applying post-training methods that do not require model retraining or fine-tuning, for example, post-training 8-bit quantization.
+- :ref:`Post-training Quantization` is designed to optimize the inference of deep learning models by applying post-training methods that do not require model retraining or fine-tuning, for example, post-training 8-bit integer quantization.
-- :ref:`Training-time Optimization w/ NNCF <tmo_introduction>`, a suite of advanced methods for training-time model optimization within the DL framework, such as PyTorch and TensorFlow 2.x. It supports methods, like Quantization-aware Training and Filter Pruning. NNCF-optimized models can be inferred with OpenVINO using all the available workflows.
+- :ref:`Training-time Optimization`, a suite of advanced methods for training-time model optimization within DL frameworks, such as PyTorch and TensorFlow 2.x. It supports methods like Quantization-aware Training and Filter Pruning. NNCF-optimized models can be inferred with OpenVINO using all the available workflows.
@endsphinxdirective

View File

@ -0,0 +1,135 @@
# Basic Quantization Flow {#basic_qauntization_flow}
## Introduction
The basic quantization flow is the simplest way to apply 8-bit quantization to the model. It is available for models in the following frameworks: PyTorch, TensorFlow 2.x, ONNX, and OpenVINO. The basic quantization flow is based on the following steps:
* Set up an environment and install dependencies.
* Prepare the **calibration dataset** that is used to estimate quantization parameters of the activations within the model.
* Call the quantization API to apply 8-bit quantization to the model.
## Set up an Environment
It is recommended to set up a separate Python environment for quantization with NNCF. To do this, run the following commands:
```bash
python3 -m venv nncf_ptq_env
source nncf_ptq_env/bin/activate  # on Windows: nncf_ptq_env\Scripts\activate
```
Install all the packages required to instantiate the model object, for example, the DL framework. After that, install NNCF on top of the environment:
```bash
pip install nncf
```
## Prepare a Calibration Dataset
At this step, create an instance of the `nncf.Dataset` class that represents the calibration dataset. The `nncf.Dataset` class can be a wrapper over the framework dataset object that is used for model training or validation. The class constructor receives the dataset object and the transformation function. For example, if you use PyTorch, you can pass an instance of the `torch.utils.data.DataLoader` object.
The transformation function is a function that takes a sample from the dataset and returns data that can be passed to the model for inference. For example, this function can take a tuple of a data tensor and labels tensor, and return the former while ignoring the latter. The transformation function is used to avoid modifying the dataset code to make it compatible with the quantization API. The function is applied to each sample from the dataset before passing it to the model for inference. The following code snippet shows how to create an instance of the `nncf.Dataset` class:
@sphinxtabset
@sphinxtab{PyTorch}
@snippet docs/optimization_guide/nncf/ptq/code/ptq_torch.py dataset
@endsphinxtab
@sphinxtab{ONNX}
@snippet docs/optimization_guide/nncf/ptq/code/ptq_onnx.py dataset
@endsphinxtab
@sphinxtab{OpenVINO}
@snippet docs/optimization_guide/nncf/ptq/code/ptq_openvino.py dataset
@endsphinxtab
@sphinxtab{TensorFlow}
@snippet docs/optimization_guide/nncf/ptq/code/ptq_tensorflow.py dataset
@endsphinxtab
@endsphinxtabset
If there is no framework dataset object, you can create your own entity that implements the `Iterable` interface in Python and returns data samples suitable for inference. In this case, a transformation function is not required.
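For example, below is a minimal sketch of such a custom data source. The file layout and the sample count here are assumptions for illustration; the samples are expected to be already in the format the model accepts:
```python
import numpy as np
import nncf

# Any re-iterable Python object can serve as the data source.
# Here, preprocessed samples are assumed to be stored as .npy files (hypothetical layout).
calibration_samples = [np.load(f"calibration_data/sample_{i}.npy") for i in range(300)]

# No transformation function is needed because the samples are already model-ready.
calibration_dataset = nncf.Dataset(calibration_samples)
```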
## Run a Quantized Model
Once the dataset is ready and the model object is instantiated, you can apply 8-bit quantization to it:
@sphinxtabset
@sphinxtab{PyTorch}
@snippet docs/optimization_guide/nncf/ptq/code/ptq_torch.py quantization
@endsphinxtab
@sphinxtab{ONNX}
@snippet docs/optimization_guide/nncf/ptq/code/ptq_onnx.py quantization
@endsphinxtab
@sphinxtab{OpenVINO}
@snippet docs/optimization_guide/nncf/ptq/code/ptq_openvino.py quantization
@endsphinxtab
@sphinxtab{TensorFlow}
@snippet docs/optimization_guide/nncf/ptq/code/ptq_tensorflow.py quantization
@endsphinxtab
@endsphinxtabset
> **NOTE**: The `model` is an instance of the `torch.nn.Module` class for PyTorch, `onnx.ModelProto` for ONNX, `tensorflow.Module` for TensorFlow, and `openvino.runtime.Model` for OpenVINO.
After that, the model can be exported into the OpenVINO Intermediate Representation if needed and run faster with OpenVINO.
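For example, a quantized PyTorch model can be exported to ONNX and then converted to IR, while a quantized `openvino.runtime.Model` can be serialized directly with `openvino.runtime.serialize`. The following is a rough sketch; the input shape and file names are placeholders:
```python
import torch
from openvino.runtime import serialize
from openvino.tools.mo import convert_model

# PyTorch: export the quantized model to ONNX first (the input shape is a placeholder).
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(quantized_model, dummy_input, "quantized_model.onnx")

# Convert the ONNX model to OpenVINO IR and save it to disk.
ov_model = convert_model("quantized_model.onnx")
serialize(ov_model, "quantized_model.xml")
```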
## Tune quantization parameters
The `nncf.quantize()` function has several parameters that allow you to tune the quantization process to get a more accurate model. Below is a list of these parameters and their descriptions:
* `model_type` - used to specify the quantization scheme required for a specific type of model. For example, **Transformer** models (BERT, DistilBERT, etc.) require a special quantization scheme to preserve accuracy after quantization.
```python
nncf.quantize(model, dataset, model_type=nncf.ModelType.Transformer)
```
* `preset` - defines the quantization scheme for the model. Two types of presets are available:
    * `PERFORMANCE` (default) - defines symmetric quantization of weights and activations
    * `MIXED` - weights are quantized with symmetric quantization and activations are quantized with asymmetric quantization. This preset is recommended for models with non-ReLU and asymmetric activation functions, e.g. ELU, PReLU, GELU, etc.
```python
nncf.quantize(model, dataset, preset=nncf.Preset.MIXED)
```
* `fast_bias_correction` - when set to `False`, enables a more accurate bias (error) correction algorithm that can be used to improve the accuracy of the model. This parameter is available only for the OpenVINO representation. `True` is used by default.
```python
nncf.quantize(model, dataset, fast_bias_correction=False)
```
* `subset_size` - defines the number of samples from the calibration dataset that will be used to estimate quantization parameters of activations. The default value is 300.
```python
nncf.quantize(model, dataset, subset_size=1000)
```
* `ignored_scope` - this parameter can be used to exclude some layers from the quantization process, for example, the last layer of the model. Below are some examples of how to use this parameter:
* Exclude by layer name:
```python
names = ['layer_1', 'layer_2', 'layer_3']
nncf.quantize(model, dataset, ignored_scope=nncf.IgnoredScope(names=names))
```
* Exclude by layer type:
```python
types = ['Conv2d', 'Linear']
nncf.quantize(model, dataset, ignored_scope=nncf.IgnoredScope(types=types))
```
* Exclude by regular expression:
```python
regex = '.*layer_.*'
nncf.quantize(model, dataset, ignored_scope=nncf.IgnoredScope(patterns=regex))
```
If the accuracy of the quantized model is not satisfactory, you can try to use the [Quantization with accuracy control](@ref quantization_w_accuracy_control) flow.
## See also
* [Example of basic quantization flow in PyTorch](https://github.com/openvinotoolkit/nncf/tree/develop/examples/post_training_quantization/torch/mobilenet_v2)

View File

@ -0,0 +1,49 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#! [dataset]
import nncf
import torch
calibration_loader = torch.utils.data.DataLoader(...)
def transform_fn(data_item):
images, _ = data_item
return images
calibration_dataset = nncf.Dataset(calibration_loader, transform_fn)
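# In the simplest case, the validation dataset below can reuse the same data source as the calibration one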
validation_dataset = nncf.Dataset(calibration_loader, transform_fn)
#! [dataset]
#! [validation]
import numpy as np
import torch
import openvino
from sklearn.metrics import accuracy_score
def validate(model: openvino.runtime.CompiledModel,
validation_loader: torch.utils.data.DataLoader) -> float:
predictions = []
references = []
output = model.outputs[0]
for images, target in validation_loader:
pred = model(images)[output]
predictions.append(np.argmax(pred, axis=1))
references.append(target)
predictions = np.concatenate(predictions, axis=0)
references = np.concatenate(references, axis=0)
return accuracy_score(predictions, references)
#! [validation]
#! [quantization]
model = ... # openvino.runtime.Model object
quantized_model = nncf.quantize_with_accuracy_control(model,
calibration_dataset=calibration_dataset,
validation_dataset=validation_dataset,
validation_fn=validate,
max_drop=0.01)
#! [quantization]

View File

@ -0,0 +1,22 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#! [dataset]
import nncf
import torch
calibration_loader = torch.utils.data.DataLoader(...)
def transform_fn(data_item):
images, _ = data_item
return {input_name: images.numpy()} # input_name should be taken from the model,
# e.g. model.graph.input[0].name
calibration_dataset = nncf.Dataset(calibration_loader, transform_fn)
#! [dataset]
#! [quantization]
model = ... # onnx.ModelProto object
quantized_model = nncf.quantize(model, calibration_dataset)
#! [quantization]

View File

@ -0,0 +1,21 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#! [dataset]
import nncf
import torch
calibration_loader = torch.utils.data.DataLoader(...)
def transform_fn(data_item):
images, _ = data_item
return images.numpy()
calibration_dataset = nncf.Dataset(calibration_loader, transform_fn)
#! [dataset]
#! [quantization]
model = ... # openvino.runtime.Model object
quantized_model = nncf.quantize(model, calibration_dataset)
#! [quantization]

View File

@ -0,0 +1,21 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#! [dataset]
import nncf
import tensorflow_datasets as tfds
calibration_loader = tfds.load(...)
def transform_fn(data_item):
images, _ = data_item
return images
calibration_dataset = nncf.Dataset(calibration_loader, transform_fn)
#! [dataset]
#! [quantization]
model = ... # tensorflow.Module object
quantized_model = nncf.quantize(model, calibration_dataset)
#! [quantization]

View File

@ -0,0 +1,21 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#! [dataset]
import nncf
import torch
calibration_loader = torch.utils.data.DataLoader(...)
def transform_fn(data_item):
images, _ = data_item
return images
calibration_dataset = nncf.Dataset(calibration_loader, transform_fn)
#! [dataset]
#! [quantization]
model = ... # torch.nn.Module object
quantized_model = nncf.quantize(model, calibration_dataset)
#! [quantization]

View File

@ -0,0 +1,22 @@
# Post-training Quantization w/ NNCF (new) {#nncf_ptq_introduction}
@sphinxdirective
.. toctree::
:maxdepth: 1
:hidden:
basic_qauntization_flow
quantization_w_accuracy_control
@endsphinxdirective
Neural Network Compression Framework (NNCF) provides a new post-training quantization API available in Python that is aimed at reusing the code for model training or validation that is usually available with the model in the source framework, for example, PyTorch* or TensorFlow*. The API is cross-framework and currently supports models in the following frameworks: PyTorch, TensorFlow 2.x, ONNX, and OpenVINO.
This API has two main capabilities to apply 8-bit post-training quantization:
* [Basic quantization](@ref basic_qauntization_flow) - the simplest quantization flow that allows you to apply 8-bit integer quantization to the model.
* [Quantization with accuracy control](@ref quantization_w_accuracy_control) - the advanced quantization flow that allows you to apply 8-bit quantization to the model with control of the accuracy metric.
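In a nutshell, the basic flow boils down to a few lines of code. The following condensed sketch for a PyTorch model mirrors the snippets used throughout this guide; the data loader and the model object are placeholders:
```python
import nncf
import torch

calibration_loader = torch.utils.data.DataLoader(...)  # placeholder data source

def transform_fn(data_item):
    # Take an (images, labels) pair from the loader and return only the model input.
    images, _ = data_item
    return images

calibration_dataset = nncf.Dataset(calibration_loader, transform_fn)
model = ...  # torch.nn.Module object (placeholder)
quantized_model = nncf.quantize(model, calibration_dataset)
```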
## See also
* [NNCF GitHub](https://github.com/openvinotoolkit/nncf)
* [Optimizing Models at Training Time](@ref tmo_introduction)

View File

@ -0,0 +1,66 @@
# Quantizing with accuracy control {#quantization_w_accuracy_control}
## Introduction
This is the advanced quantization flow that allows you to apply 8-bit quantization to the model with control of the accuracy metric. This is achieved by keeping the most impactful operations within the model in the original precision. The flow is based on the [Basic 8-bit quantization](@ref basic_qauntization_flow) and has the following differences:
* Besides the calibration dataset, a **validation dataset** is required to compute the accuracy metric. Both datasets can refer to the same data in the simplest case.
* A **validation function**, used to compute the accuracy metric, is required. It can be a function that is already available in the source framework or a custom function.
* Since accuracy validation is run several times during the quantization process, quantization with accuracy control can take more time than the [Basic 8-bit quantization](@ref basic_qauntization_flow) flow.
* The resulting model can provide a smaller performance improvement than the [Basic 8-bit quantization](@ref basic_qauntization_flow) flow because some of the operations are kept in the original precision.
> **NOTE**: Currently, this flow is available only for models in the OpenVINO representation.
The steps for quantization with accuracy control are described below.
## Prepare datasets
This step is similar to the [Basic 8-bit quantization](@ref basic_qauntization_flow) flow. The only difference is that two datasets, calibration and validation, are required.
@sphinxtabset
@sphinxtab{OpenVINO}
@snippet docs/optimization_guide/nncf/ptq/code/ptq_aa_openvino.py dataset
@endsphinxtab
@endsphinxtabset
## Prepare validation function
The validation function receives an `openvino.runtime.CompiledModel` object and a validation dataset, and returns the accuracy metric value. The following code snippet shows an example of a validation function for an OpenVINO model:
@sphinxtabset
@sphinxtab{OpenVINO}
@snippet docs/optimization_guide/nncf/ptq/code/ptq_aa_openvino.py validation
@endsphinxtab
@endsphinxtabset
## Run quantization with accuracy control
Now, you can run quantization with accuracy control. The following code snippet shows an example of quantization with accuracy control for an OpenVINO model:
@sphinxtabset
@sphinxtab{OpenVINO}
@snippet docs/optimization_guide/nncf/ptq/code/ptq_aa_openvino.py quantization
@endsphinxtab
@endsphinxtabset
The `max_drop` parameter defines the accuracy drop threshold. The quantization process stops when the degradation of the accuracy metric on the validation dataset is less than `max_drop`.
The `nncf.quantize_with_accuracy_control()` API supports all the parameters of the `nncf.quantize()` API, so you can combine accuracy control with any of the tuning options described for the basic flow, as in the sketch below.
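For example, a hypothetical configuration that combines accuracy control with a Transformer-specific quantization scheme and a larger calibration subset could look like this:
```python
quantized_model = nncf.quantize_with_accuracy_control(
    model,
    calibration_dataset=calibration_dataset,
    validation_dataset=validation_dataset,
    validation_fn=validate,
    max_drop=0.01,
    # Any nncf.quantize() parameter can also be passed here (assumed configuration):
    model_type=nncf.ModelType.Transformer,
    subset_size=500,
)
```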
## See also
* [Optimizing Models at Training Time](@ref tmo_introduction)

View File

@ -0,0 +1,31 @@
# Quantizing Models Post-training {#ptq_introduction}
@sphinxdirective
.. toctree::
:maxdepth: 1
:hidden:
pot_introduction
nncf_ptq_introduction
@endsphinxdirective
Post-training model optimization is the process of applying special methods that transform a model into a more hardware-friendly representation without retraining or fine-tuning. The most popular and widely-spread method here is 8-bit post-training quantization because:
* It is easy to use.
* It does not hurt accuracy significantly.
* It provides a significant performance improvement.
* It suits many hardware devices available on the market, since most of them support 8-bit computation natively.
8-bit integer quantization lowers the precision of weights and activations to 8 bits, which leads to an almost 4x reduction in the model footprint and significant improvements in inference speed, mostly due to the lower memory bandwidth required for inference. This lowering step is done offline, before the actual inference, so that the model gets transformed into the quantized representation. The process does not require a training dataset or a training pipeline in the source DL framework.
![](../img/quantization_picture.png)
To apply post-training methods in OpenVINO&trade;, you need:
* A floating-point precision model, FP32 or FP16, converted into the OpenVINO&trade; Intermediate Representation (IR) format that can be run on CPU.
* A representative calibration dataset that reflects a typical use case scenario, for example, 300 samples.
* In case of accuracy constraints, a validation dataset and accuracy metrics should be available.
Currently, OpenVINO provides two workflows with post-training quantization capabilities:
* [Post-training Quantization with POT](@ref pot_introduction) - works with models in OpenVINO Intermediate Representation (IR) only.
* [Post-training Quantization with NNCF](@ref nncf_ptq_introduction) - cross-framework solution for model optimization that provides a new simple API for post-training quantization.

View File

@ -1,4 +1,4 @@
-# Quantizing Models Post-training {#pot_introduction}
+# Post-training Quantization w/ POT {#pot_introduction}
@sphinxdirective
@ -16,14 +16,8 @@
@endsphinxdirective
-## Introduction
-Post-training quantization is a model compression technique where the values in a neural network are converted from a 32-bit or 16-bit format to an 8-bit integer format after the network has been fine-tuned on a training dataset. This helps to reduce the model's latency by taking advantage of computationally efficient 8-bit integer arithmetic. It also reduces the model's size and memory footprint.
-Post-training quantization is easy to implement and is a quick way to boost model performance. It only requires a representative dataset, and it can be performed using the Post-training Optimization Tool (POT) in OpenVINO. POT is distributed as part of the [OpenVINO Development Tools](@ref openvino_docs_install_guides_install_dev_tools) package. To apply post-training quantization with POT, you need:
+For the needs of post-training optimization, OpenVINO&trade; provides a **Post-training Optimization Tool (POT)** which supports the **uniform integer quantization** method. This method allows moving from floating-point precision to integer precision (for example, 8-bit) for weights and activations at inference time. It helps to reduce the model size, memory footprint and latency, as well as improve the computational efficiency by using integer arithmetic. During the quantization process, additional operations that contain quantization information are inserted into the model. The actual transition to integer arithmetic happens at model inference.
-* A floating-point precision model, FP32 or FP16, converted into the OpenVINO Intermediate Representation (IR) format.
-* A representative dataset (annotated or unannotated) of around 300 samples that depict typical use cases or scenarios.
-* (Optional) An annotated validation dataset that can be used for checking the model's accuracy.
The post-training quantization algorithm takes samples from the representative dataset, inputs them into the network, and calibrates the network based on the resulting weights and activation values. Once calibration is complete, values in the network are converted to 8-bit integer format.