Develop PTQ with MQBench

Suppose we have a PTQ algorithm that needs the inputs and outputs of certain layers for calibration, and that some calibration data is available. Because the model has already been prepared by MQBench, we do not need to implement switching between quant, calib, and float modes ourselves. Only three pieces are needed to implement a PTQ algorithm in MQBench:

  1. A fake quantizer.

  2. Data hooks to get intra-network feature maps.

  3. A loss function used in calibration.
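For step 3, one common choice in layer-wise PTQ is the mean squared error between a layer's float output and its fake-quantized output over the N calibration samples. This is an assumption for illustration; the algorithm you implement may define its own loss, for example by adding method-specific regularization terms:

```latex
\mathcal{L}(\hat{\theta}) = \frac{1}{N} \sum_{n=1}^{N}
  \left\| f_{\theta}(x_n) - \hat{f}_{\hat{\theta}}(\tilde{x}_n) \right\|_2^2
```

Here f_θ is the float layer, f̂ with parameters θ̂ is the same layer with fake-quantized parameters (the ones calibration updates), x_n is the layer's input in the float model, and x̃_n is its input in the quantized model. The two terms correspond to the cached_oups and cached_inps that the PTQ function in this section collects.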

As described in Quantization Function, the PTQ algorithm may require a self-defined fake quantizer.
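The arithmetic at the core of a fake quantizer can be sketched without PyTorch. The function below is a hypothetical, framework-free illustration of uniform affine quantize-dequantize; it is not MQBench's quantizer API. A real MQBench quantizer is a PyTorch module (see Quantization Function), which would also need a gradient rule such as the straight-through estimator to backpropagate through round().

```python
from typing import List


def fake_quantize(values: List[float], scale: float, zero_point: int,
                  quant_min: int, quant_max: int) -> List[float]:
    """Quantize-dequantize with a uniform affine scheme (illustrative sketch).

    q     = clamp(round(x / scale) + zero_point, quant_min, quant_max)
    x_hat = (q - zero_point) * scale
    """
    if scale <= 0:
        raise ValueError("scale must be positive")
    result = []
    for x in values:
        q = round(x / scale) + zero_point        # snap to the integer grid (round-half-to-even)
        q = max(quant_min, min(quant_max, q))    # clamp to the representable range
        result.append((q - zero_point) * scale)  # map back to float
    return result


if __name__ == "__main__":
    # signed 8-bit, symmetric (zero_point = 0); 20.0 is clipped to about 127 * 0.1
    print(fake_quantize([-1.0, 0.02, 0.5, 20.0], scale=0.1, zero_point=0,
                        quant_min=-128, quant_max=127))
```

The result stays in floating point but only takes values on the grid scale * (q - zero_point). This lets calibration measure quantization error inside an otherwise float network.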

During calibration, PTQ algorithms usually adjust the weights by running the quantized layer forward and backward, which requires intra-network feature maps. MQBench provides a hook, mqbench.utils.hooks.DataSaverHook, that captures the input and/or output of a given module. Register it with torch.nn.Module.register_forward_hook as below. Gradients with respect to a module's inputs/outputs can be captured in the same way with a backward hook such as torch.nn.Module.register_full_backward_hook:

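    import torch
    from torch.fx import GraphModule
    from torch.nn import Module
    from mqbench.utils.hooks import DataSaverHook, StopForwardException


    def save_inp_oup_data(model: GraphModule, module: Module, cali_data: list,
                          store_inp=False, store_oup=True):
        # Cache either the inputs or the outputs of `module`, not both at once.
        assert (not store_inp or not store_oup)
        device = next(model.parameters()).device
        # stop_forward=True raises StopForwardException right after `module` runs,
        # so the layers behind it are not evaluated.
        data_saver = DataSaverHook(store_input=store_inp, store_output=store_oup, stop_forward=True)
        handle = module.register_forward_hook(data_saver)
        cached = []
        with torch.no_grad():
            for batch in cali_data:
                try:
                    _ = model(batch.to(device))
                except StopForwardException:
                    pass
                if store_inp:
                    # forward hooks receive the module inputs as a tuple
                    cached.append([inp.detach() for inp in data_saver.input_store])
                if store_oup:
                    cached.append(data_saver.output_store.detach())
        handle.remove()
        return cached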
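To make the hook-and-stop mechanism concrete without PyTorch or MQBench, here is a framework-free toy analogue of save_inp_oup_data. ToyChain, ToyDataSaver, StopForwardToy and collect are hypothetical names invented for this illustration; they mimic, but are not, torch.nn.Module forward hooks or MQBench's DataSaverHook and StopForwardException.

```python
from typing import Callable, Dict, List, Optional

Hook = Callable[[float, float], None]


class StopForwardToy(Exception):
    """Raised from a hook to abort the rest of the forward pass."""


class ToyChain:
    """Stand-in for a network: layers are float -> float callables run in order."""

    def __init__(self, layers: List[Callable[[float], float]]):
        self.layers = layers
        self.hooks: Dict[int, Hook] = {}
        self.layers_run = 0  # counts layer evaluations, to show the early stop

    def register_forward_hook(self, idx: int, hook: Hook) -> Callable[[], None]:
        self.hooks[idx] = hook

        def remove() -> None:  # plays the role of handle.remove()
            self.hooks.pop(idx, None)

        return remove

    def __call__(self, x: float) -> float:
        for i, layer in enumerate(self.layers):
            inp = x
            x = layer(x)
            self.layers_run += 1
            if i in self.hooks:
                self.hooks[i](inp, x)  # the hook sees this layer's input and output
        return x


class ToyDataSaver:
    """Mirrors the idea of DataSaverHook: store input/output, optionally stop."""

    def __init__(self, store_input: bool, store_output: bool, stop_forward: bool):
        self.store_input = store_input
        self.store_output = store_output
        self.stop_forward = stop_forward
        self.input_store: Optional[float] = None
        self.output_store: Optional[float] = None

    def __call__(self, inp: float, out: float) -> None:
        if self.store_input:
            self.input_store = inp
        if self.store_output:
            self.output_store = out
        if self.stop_forward:
            raise StopForwardToy  # later layers are never evaluated


def collect(model: ToyChain, idx: int, data: List[float], store_inp: bool) -> List[float]:
    """Cache the input (store_inp=True) or the output of layer `idx` for every sample."""
    saver = ToyDataSaver(store_input=store_inp, store_output=not store_inp, stop_forward=True)
    remove = model.register_forward_hook(idx, saver)
    cached = []
    try:
        for sample in data:
            try:
                model(sample)
            except StopForwardToy:
                pass
            cached.append(saver.input_store if store_inp else saver.output_store)
    finally:
        remove()
    return cached


if __name__ == "__main__":
    net = ToyChain([lambda x: x + 1, lambda x: x * 2, lambda x: x - 3])
    print(collect(net, 1, [1.0, 2.0], store_inp=False))  # [4.0, 6.0]
    print(net.layers_run)  # 4: the third layer was skipped for both samples
```

The early stop matters because calibration repeats this collection once per layer; stopping right after the target layer avoids running the rest of the network each time.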

With this helper, the PTQ function can be designed like this:

    import torch
    from mqbench.utils.state import enable_quantization, disable_all
    from mqbench.utils.utils import deepcopy_graphmodule

    # Layer types to calibrate; an illustrative choice, adjust it to your algorithm.
    _PTQ_SUPPORT_TYPE = (torch.nn.Conv2d, torch.nn.Linear)


    def PTQ(model, cali_data, *args, **kwargs):
        ptq_model = deepcopy_graphmodule(model)
        # disable the original model's updates
        model.eval()
        # turn the original model into float (disable its fake quantizers)
        disable_all(model)
        # turn the ptq model into quant mode
        enable_quantization(ptq_model)
        nodes = list(model.graph.nodes)
        modules = dict(model.named_modules())
        quant_modules = dict(ptq_model.named_modules())
        for node in nodes:
            if node.op == "call_module" and isinstance(modules[node.target], _PTQ_SUPPORT_TYPE):
                module = modules[node.target]
                quant_module = quant_modules[node.target]
                # targets: outputs of the layer in the float model
                cached_oups = save_inp_oup_data(model, module, cali_data,
                                                store_inp=False, store_oup=True)
                # inputs of the layer in the quantized model, which already
                # include the quantization error of earlier, calibrated layers
                cached_inps = save_inp_oup_data(ptq_model, quant_module, cali_data,
                                                store_inp=True, store_oup=False)
                # placeholder for your algorithm; it updates quant_module's params
                do_your_calibration(quant_module, cached_inps, cached_oups)
        return ptq_model
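The body of do_your_calibration is left to the algorithm. As a minimal stand-in, the sketch below uses a simple technique chosen for illustration: a grid search over the per-tensor weight scale of a bias-free linear layer that minimizes the mean squared reconstruction error between cached float outputs and the fake-quantized layer's outputs. All names are hypothetical, plain Python lists replace tensors so it runs anywhere, and more elaborate PTQ methods optimize their loss with gradients instead of a grid search.

```python
from typing import List, Sequence

Matrix = List[List[float]]
Vector = List[float]


def fake_quant_weight(w: Matrix, scale: float, qmax: int = 127) -> Matrix:
    """Symmetric per-tensor fake quantization: zero_point = 0, range [-qmax - 1, qmax]."""
    return [[max(-qmax - 1, min(qmax, round(v / scale))) * scale for v in row] for row in w]


def linear(w: Matrix, x: Sequence[float]) -> Vector:
    """y = W x for a bias-free linear layer."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def recon_loss(w_q: Matrix, cached_inps: List[Vector], cached_oups: List[Vector]) -> float:
    """Mean over samples of ||y_float - W_q x_quant||^2."""
    total = 0.0
    for x, y in zip(cached_inps, cached_oups):
        y_hat = linear(w_q, x)
        total += sum((a - b) ** 2 for a, b in zip(y, y_hat))
    return total / len(cached_inps)


def candidate_scales(w: Matrix, num_candidates: int = 100, qmax: int = 127) -> List[float]:
    """Scales that clip the weight range at k / num_candidates of max|w|, k = 1..num_candidates."""
    max_abs = max(abs(v) for row in w for v in row)
    if max_abs == 0.0:
        return [1.0]  # all-zero weights: any positive scale is exact
    return [max_abs * k / num_candidates / qmax for k in range(1, num_candidates + 1)]


def calibrate_scale(w: Matrix, cached_inps: List[Vector], cached_oups: List[Vector],
                    num_candidates: int = 100, qmax: int = 127) -> float:
    """Grid search: return the candidate scale with the lowest reconstruction loss."""
    return min(candidate_scales(w, num_candidates, qmax),
               key=lambda s: recon_loss(fake_quant_weight(w, s, qmax), cached_inps, cached_oups))


if __name__ == "__main__":
    w = [[0.9, -0.35], [0.12, 0.6]]
    inps = [[1.0, 0.5], [-0.2, 2.0], [0.7, -1.1]]
    oups = [linear(w, x) for x in inps]  # stand-in for cached float outputs
    s = calibrate_scale(w, inps, oups, qmax=7)  # 4-bit signed weights
    print(s, recon_loss(fake_quant_weight(w, s, qmax=7), inps, oups))
```

In the PTQ function, cached_inps would come from the quantized model and cached_oups from the float model; the usage above feeds the same inputs to both only to keep the example small. Clipping below max|w| can win because a smaller scale gives finer resolution to the bulk of the weights at the cost of saturating a few outliers.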