Naive QAT

Quantization-aware training only requires a few additional operations compared to an ordinary fine-tuning run.

1. Prepare the FP32 model first.

import torchvision.models as models
from mqbench.convert_deploy import convert_deploy
from mqbench.prepare_by_platform import prepare_by_platform, BackendType
from mqbench.utils.state import enable_calibration, enable_quantization

# first, initialize the FP32 model with pretrained parameters.
model = models.__dict__["resnet18"](pretrained=True)
model.train()
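
The loops in step 3 iterate over a data object that this guide does not define. Below is a minimal sketch of such a loader, assuming an ImageNet-style folder layout; the dataset path, batch size, and transforms are illustrative placeholders, not part of MQBench.

import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder

# Hypothetical input pipeline: the path, batch size, and transforms
# below are placeholder choices for illustration only.
transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
dataset = ImageFolder("path/to/train", transform=transform)
data = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)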

2. Choose your backend.

# backend options
backend = BackendType.Tensorrt
# backend = BackendType.SNPE
# backend = BackendType.PPLW8A16
# backend = BackendType.NNIE
# backend = BackendType.Vitis
# backend = BackendType.ONNX_QNN
# backend = BackendType.PPLCUDA
# backend = BackendType.OPENVINO
# backend = BackendType.Tengine_u8
# backend = BackendType.Tensorrt_NLP

3. Prepare the model for quantization.

# trace the model and insert fake-quantize nodes for the chosen backend
model = prepare_by_platform(model, backend)

# calibration loop: observers collect activation statistics;
# fake quantization is not applied yet
model.eval()
enable_calibration(model)
for i, batch in enumerate(data):
    # do forward procedures
    ...

# training loop: fake-quantize nodes are now active in both
# the forward and backward passes
model.train()
enable_quantization(model)
for i, batch in enumerate(data):
    # do forward procedures
    ...

    # do backward and optimization
    ...
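
Filled in, the two loops might look like the sketch below. The device handling, loss, optimizer settings, and the cap on calibration batches are illustrative choices, not MQBench requirements.

import torch

# A concrete version of the two loops above; all training
# hyperparameters here are placeholder choices.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# calibration: forward passes only, so observers gather ranges
model.eval()
enable_calibration(model)
with torch.no_grad():
    for i, (images, _) in enumerate(data):
        model(images.to(device))
        if i >= 100:  # a limited number of batches is typically enough
            break

# QAT fine-tuning with fake quantization active
model.train()
enable_quantization(model)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
for images, labels in data:
    images, labels = images.to(device), labels.to(device)
    loss = criterion(model(images), labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()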

4. Export the quantized model.

# define the dummy input shape for model export
input_shape = {'data': [10, 3, 224, 224]}
convert_deploy(model, backend, input_shape)
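
convert_deploy traces the calibrated model and writes backend-specific deployment artifacts. If you want to control where the files land, a sketch like the one below may work; the output_path and model_name keywords are an assumption based on recent MQBench releases, so verify them against your installed version's signature.

# Sketch: exporting with an explicit output location. The output_path
# and model_name keywords are assumptions; check your MQBench version.
model.eval()
convert_deploy(model, backend, input_shape,
               output_path='./deploy', model_name='resnet18_qat')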

Now you know how to conduct naive QAT with MQBench. If you want to learn more about customizing backends, check Learn MQBench configuration.