Files
openvino/docs/optimization_guide/nncf/code/qat_tf.py
Ilya Churaev 0c9abf43a9 Updated copyright headers (#15124)
* Updated copyright headers

* Revert "Fixed linker warnings in docs snippets on Windows (#15119)"

This reverts commit 372699ec49.
2023-01-16 11:02:17 +04:00

70 lines
2.5 KiB
Python

# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#! [imports]
import tensorflow as tf
from nncf import NNCFConfig
from nncf.tensorflow import create_compressed_model, register_default_init_args
#! [imports]
#! [nncf_congig]
nncf_config_dict = {
    # Input shape required for model tracing. Keras models use the channels-last
    # (NHWC) layout by default, so the sample is [batch, height, width, channels].
    "input_info": {"sample_size": [1, 224, 224, 3]},
    "compression": {
        "algorithm": "quantization",  # 8-bit quantization with default settings
    },
}
nncf_config = NNCFConfig.from_dict(nncf_config_dict)
# Register the data NNCF uses to initialize the compression algorithm;
# train_dataset is an instance of tf.data.Dataset.
nncf_config = register_default_init_args(nncf_config, train_dataset, batch_size=1)
#! [nncf_congig]
#! [wrap_model]
model = KerasModel() # instance of the tensorflow.keras.Model
# Wrap the model with NNCF: returns the compression controller and the model
# with the compression described in nncf_config applied.
compression_ctrl, model = create_compressed_model(model, nncf_config)
#! [wrap_model]
#! [distributed]
# Call this once on the compression controller before the training starts.
compression_ctrl.distributed() # call it before the training
#! [distributed]
#! [tune_model]
from nncf.tensorflow import create_compression_callbacks
... # fine-tuning preparations, e.g. dataset, loss, optimizer setup, etc.
# create compression callbacks to control optimization parameters and dump compression statistics
compression_callbacks = create_compression_callbacks(compression_ctrl, log_dir="./compression_log")
# tune quantized model for 5 epochs the same way as the baseline
model.fit(train_dataset, epochs=5, callbacks=compression_callbacks)
#! [tune_model]
#! [export]
# Export the compressed model to a TensorFlow Frozen Graph (.pb) file.
compression_ctrl.export_model("compressed_model.pb") # export to Frozen Graph
#! [export]
#! [save_checkpoint]
from nncf.tensorflow.utils.state import TFCompressionState
from nncf.tensorflow.callbacks.checkpoint_callback import CheckpointManagerCallback
# Store the NNCF compression state in the checkpoint alongside the model so the
# compressed model can be rebuilt from it when the checkpoint is loaded.
checkpoint = tf.train.Checkpoint(model=model,
                                 compression_state=TFCompressionState(compression_ctrl),
                                 # ... the rest of the user-defined objects to save
                                 )
callbacks = []
callbacks.append(CheckpointManagerCallback(checkpoint, path_to_checkpoint))
...
model.fit(..., callbacks=callbacks)
#! [save_checkpoint]
#! [load_checkpoint]
from nncf.tensorflow.utils.state import TFCompressionStateLoader
# Step 1: restore only the NNCF compression state from the checkpoint.
checkpoint = tf.train.Checkpoint(compression_state=TFCompressionStateLoader())
checkpoint.restore(path_to_checkpoint)
compression_state = checkpoint.compression_state.state
# Step 2: rebuild the compressed model using the restored compression state.
compression_ctrl, model = create_compressed_model(model, nncf_config, compression_state)
# Step 3: restore the model weights and the rest of the saved objects.
checkpoint = tf.train.Checkpoint(model=model,
                                 # ... the rest of the user-defined objects to restore
                                 )
checkpoint.restore(path_to_checkpoint)
#! [load_checkpoint]