Files
openvino/docs/optimization_guide/nncf/code/qat_torch.py
Ilya Churaev 0c9abf43a9 Updated copyright headers (#15124)
* Updated copyright headers

* Revert "Fixed linker warnings in docs snippets on Windows (#15119)"

This reverts commit 372699ec49.
2023-01-16 11:02:17 +04:00

63 lines
2.0 KiB
Python

# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#! [imports]
import torch
import nncf # Important - should be imported right after torch
from nncf import NNCFConfig
from nncf.torch import create_compressed_model, register_default_init_args
#! [imports]
# Snippet `nncf_congig`: build the NNCF config for default 8-bit quantization
# from a plain dict, then register default init args with `train_loader`.
# NOTE(review): the marker name `nncf_congig` looks like a typo of
# `nncf_config`, but it is presumably referenced by name from the docs pages —
# confirm and rename both sides together, if at all.
# `train_loader` is not defined in this file; it stands for the reader's own
# data loader.
#! [nncf_congig]
nncf_config_dict = {
    "input_info": {"sample_size": [1, 3, 224, 224]}, # input shape required for model tracing
    "compression": {
        "algorithm": "quantization", # 8-bit quantization with default settings
    },
}
nncf_config = NNCFConfig.from_dict(nncf_config_dict)
nncf_config = register_default_init_args(nncf_config, train_loader) # train_loader is an instance of torch.utils.data.DataLoader
#! [nncf_congig]
# Snippet `wrap_model`: wrap the user's model with NNCF. `create_compressed_model`
# returns the compression controller and the wrapped model, and `model` is
# rebound to the wrapped model. `TorchModel` is a placeholder for the reader's
# own model class and is not defined in this file.
#! [wrap_model]
model = TorchModel() # instance of torch.nn.Module
compression_ctrl, model = create_compressed_model(model, nncf_config)
#! [wrap_model]
# Snippet `distributed`: call `compression_ctrl.distributed()` before the
# training loop starts, as the inline comment requires.
# NOTE(review): presumably only needed for multi-process/distributed training —
# confirm against the NNCF docs.
#! [distributed]
compression_ctrl.distributed() # call it before the training loop
#! [distributed]
# Snippet `tune_model`: skeleton of a fine-tuning loop that shows where the NNCF
# scheduler hooks go. `epoch_step()` is called once at the start of every epoch
# and `step()` once per training iteration. The bare `...` lines are
# placeholders for the reader's own code.
#! [tune_model]
... # fine-tuning preparations, e.g. dataset, loss, optimizer setup, etc.
# tune quantized model for 5 epochs as the baseline
for epoch in range(5):
    compression_ctrl.scheduler.epoch_step() # Epoch control API
    for data in train_loader:
        compression_ctrl.scheduler.step() # Training iteration control API
        ... # training loop body
#! [tune_model]
# Snippet `export`: ask the compression controller to export the compressed
# model to the ONNX file `compressed_model.onnx` (a relative path).
#! [export]
compression_ctrl.export_model("compressed_model.onnx")
#! [export]
# Snippet `save_checkpoint`: save the model weights together with NNCF's
# compression state, which the `load_checkpoint` snippet passes back to
# `create_compressed_model`. `path_to_checkpoint` is a placeholder and is not
# defined in this file.
# The "other entries" placeholder inside the dict is a comment rather than a
# bare `...`, because an element without a key in a dict display is a SyntaxError.
#! [save_checkpoint]
checkpoint = {
    'state_dict': model.state_dict(),
    'compression_state': compression_ctrl.get_compression_state(),
    # ... the rest of the user-defined objects to save
}
torch.save(checkpoint, path_to_checkpoint)
#! [save_checkpoint]
# Snippet `load_checkpoint`: resume from a checkpoint written by `save_checkpoint`.
# The saved compression state is passed to `create_compressed_model` to rebuild
# the compressed model, and then the saved weights are loaded into it.
# A fresh, uncompressed `TorchModel()` is created first, as in `wrap_model`,
# because earlier in this file `model` was already rebound to the NNCF-wrapped model.
# NOTE: `torch.load` unpickles the file, so only load trusted checkpoints.
#! [load_checkpoint]
model = TorchModel() # instance of the original (uncompressed) torch.nn.Module
resuming_checkpoint = torch.load(path_to_checkpoint)
compression_state = resuming_checkpoint['compression_state']
compression_ctrl, model = create_compressed_model(model, nncf_config, compression_state=compression_state)
state_dict = resuming_checkpoint['state_dict']
model.load_state_dict(state_dict)
#! [load_checkpoint]