mqbench package

Subpackages

Submodules

mqbench.advanced_ptq

mqbench.advanced_ptq.ptq_reconstruction(model: GraphModule, cali_data: list, config: dict, graph_module_list: Optional[list] = None)

Reconstruction for AdaRound, BRECQ, and QDrop. Basic optimization objective:

\[ \begin{aligned}&\mathop{\arg\min}_{\mathbf{V}}\ \ \| Wx-\tilde{W}x \|_F^2 + \lambda f_{reg}(\mathbf{V}),\\&\tilde{W}=s \cdot clip\left( \left\lfloor\dfrac{W}{s}\right\rfloor+h(\mathbf{V}), n, p \right)\end{aligned} \]

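where \(s\) is the quantization step size, \(n\) and \(p\) are the lower and upper integer bounds, \(\sigma\) is the sigmoid function, \(\zeta\) and \(\gamma\) are stretch parameters, \(h(\mathbf{V}_{i,j})=clip(\sigma(\mathbf{V}_{i,j})(\zeta-\gamma)+\gamma, 0, 1)\), and \(f_{reg}(\mathbf{V})=\mathop{\sum}_{i,j}{1-|2h(\mathbf{V}_{i,j})-1|^\beta}\) is a regularizer weighted by \(\lambda\) (the weight entry of the config). By annealing \(\beta\), the rounding mask can adapt freely in the initial phase and converges to 0 or 1 in the later phase.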
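A minimal, framework-free sketch of \(h\) and \(f_{reg}\) as written above may help. The values \(\zeta = 1.1\) and \(\gamma = -0.1\) are an assumption (the ones commonly used with AdaRound), and the function names are hypothetical, not part of mqbench.

```python
import math
from typing import Sequence

# Assumed stretch parameters (common AdaRound choice); not read from mqbench.
ZETA, GAMMA = 1.1, -0.1


def rectified_sigmoid(v: float, zeta: float = ZETA, gamma: float = GAMMA) -> float:
    """h(V) = clip(sigmoid(V) * (zeta - gamma) + gamma, 0, 1)."""
    sig = 1.0 / (1.0 + math.exp(-v))
    return min(max(sig * (zeta - gamma) + gamma, 0.0), 1.0)


def round_reg(vs: Sequence[float], beta: float) -> float:
    """f_reg(V) = sum over all entries of 1 - |2 h(V) - 1| ** beta."""
    return sum(1.0 - abs(2.0 * rectified_sigmoid(v) - 1.0) ** beta for v in vs)


# An undecided mask entry (h = 0.5) costs the full penalty of 1 ...
print(round_reg([0.0], beta=2.0))
# ... while saturated entries (h = 0 or h = 1) cost nothing.
print(round_reg([10.0, -10.0], beta=2.0))
```

With a large \(\beta\), \(1-|2h-1|^\beta\) is close to 1 for almost every \(h\) strictly between 0 and 1, so the penalty is nearly flat and the mask moves freely; as \(\beta\) decays, the penalty increasingly favors \(h\) near 0 or 1.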

Parameters:
  • model (torch.fx.GraphModule) – a prepared GraphModule on which to perform PTQ

  • cali_data (List) – a list of calibration tensors

  • config (dict) – a configuration for PTQ reconstruction

  • graph_module_list (list, optional) – a list of the model’s child modules that need quantization. If given, the model is partially quantized; otherwise, the whole model is quantized.

>>> sample config : {
        pattern: block (str; available options are [layer, block])
        scale_lr: 4.0e-5 (learning rate for the activation step size)
        warm_up: 0.2 (the first 0.2 * max_count iterations run without the regularization that pushes weights to floor or ceil)
        weight: 0.01 (loss weight of the regularization term)
        max_count: 20000 (number of optimization iterations)
        b_range: [20, 2] (decay range of beta)
        keep_gpu: True (whether calibration data is kept on GPU or on CPU)
        round_mode: learned_hard_sigmoid (method for reconstructing the weight; currently only learned_hard_sigmoid is supported)
        prob: 0.5 (dropping probability of QDrop)
    }
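The warm_up, weight, max_count, and b_range entries together define when the regularizer is active and how \(\beta\) is annealed. The sketch below assumes \(\beta\) is held at b_range[0] during warm-up and then decays linearly to b_range[1]; the actual schedule inside mqbench.advanced_ptq may differ, and beta_at / reg_weight_at are hypothetical helpers.

```python
def beta_at(step: int, config: dict) -> float:
    """Assumed schedule: hold b_range[0] during warm-up, then decay linearly to b_range[1]."""
    start_b, end_b = config["b_range"]
    start_decay = config["warm_up"] * config["max_count"]
    if step < start_decay:
        return float(start_b)
    rel = (step - start_decay) / (config["max_count"] - start_decay)
    return end_b + (start_b - end_b) * max(0.0, 1.0 - rel)


def reg_weight_at(step: int, config: dict) -> float:
    """Regularizer weight: zero during warm-up, config["weight"] afterwards."""
    return 0.0 if step < config["warm_up"] * config["max_count"] else config["weight"]


config = {"max_count": 20000, "warm_up": 0.2, "weight": 0.01, "b_range": [20, 2]}
for step in (0, 4000, 12000, 20000):
    print(step, beta_at(step, config), reg_weight_at(step, config))
```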

mqbench.convert_deploy

mqbench.convert_onnx

mqbench.custom_quantizer

mqbench.custom_symbolic_opset

mqbench.custom_symbolic_opset.fake_quantize_per_channel_affine(g, x, scale, zero_point, ch_axis, quant_min, quant_max)
mqbench.custom_symbolic_opset.fake_quantize_per_tensor_affine(g, x, scale, zero_point, quant_min, quant_max)
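Both functions take a graph g as their first argument, which matches the signature of PyTorch ONNX symbolic functions, so they presumably describe how the fake-quantize ops are exported. The numeric behavior of the op itself is quantize, clamp, dequantize. A framework-free sketch of that arithmetic follows; the function names are hypothetical, and the per-channel variant assumes ch_axis = 0.

```python
from typing import List, Sequence


def fake_quant_per_tensor(x: Sequence[float], scale: float, zero_point: int,
                          quant_min: int, quant_max: int) -> List[float]:
    """Quantize to integers, clamp to [quant_min, quant_max], then dequantize."""
    out = []
    for v in x:
        q = round(v / scale) + zero_point  # Python round() is round-half-to-even
        q = min(max(q, quant_min), quant_max)
        out.append((q - zero_point) * scale)
    return out


def fake_quant_per_channel(rows: Sequence[Sequence[float]], scales: Sequence[float],
                           zero_points: Sequence[int], quant_min: int,
                           quant_max: int) -> List[List[float]]:
    """ch_axis = 0: row i uses scales[i] and zero_points[i]."""
    return [fake_quant_per_tensor(r, s, z, quant_min, quant_max)
            for r, s, z in zip(rows, scales, zero_points)]


# 20.0 and -20.0 fall outside the int8 grid and are clamped.
print(fake_quant_per_tensor([0.26, 20.0, -20.0], 0.1, 0, -128, 127))
```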

mqbench.fuser_method_mappings

class mqbench.fuser_method_mappings.ConvExtendBnReLUFusion(quantizer: Any, node: Node)

Bases: ConvBNReLUFusion

mqbench.fuser_method_mappings.fuse_conv_freezebn(conv, bn)
mqbench.fuser_method_mappings.fuse_conv_freezebn_relu(conv, bn, relu)
mqbench.fuser_method_mappings.fuse_deconv_bn(deconv, bn)
mqbench.fuser_method_mappings.fuse_deconv_bn_relu(deconv, bn, relu)
mqbench.fuser_method_mappings.fuse_deconv_freezebn(deconv, bn)
mqbench.fuser_method_mappings.fuse_deconv_freezebn_relu(deconv, bn, relu)
mqbench.fuser_method_mappings.fuse_linear_bn(linear, bn)

Given the linear and BN modules, fuses them and returns the fused module.

Parameters:
  • linear – Module instance of type Linear

  • bn – BatchNorm instance that needs to be fused with the linear module

Examples:

>>> import torch.nn as nn
>>> from mqbench.fuser_method_mappings import fuse_linear_bn
>>> m1 = nn.Linear(10, 20)
>>> b1 = nn.BatchNorm1d(20)
>>> m2 = fuse_linear_bn(m1, b1)
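In eval mode, fusing a Linear layer with a BatchNorm amounts to standard BN folding: each output row of the weight and its bias are rescaled by gamma / sqrt(var + eps) and shifted by the BN statistics. The framework-free sketch below shows only that arithmetic; fold_linear_bn is a hypothetical name, and whether fuse_linear_bn returns a plain folded Linear or a trainable fused module in training mode is not stated here.

```python
import math
from typing import List, Sequence, Tuple


def fold_linear_bn(weight: Sequence[Sequence[float]], bias: Sequence[float],
                   bn_mean: Sequence[float], bn_var: Sequence[float],
                   bn_gamma: Sequence[float], bn_beta: Sequence[float],
                   eps: float = 1e-5) -> Tuple[List[List[float]], List[float]]:
    """Fold BN(y) = gamma * (y - mean) / sqrt(var + eps) + beta into y = W x + b."""
    fused_w, fused_b = [], []
    for row, b, mu, var, g, beta in zip(weight, bias, bn_mean, bn_var, bn_gamma, bn_beta):
        k = g / math.sqrt(var + eps)  # per-output-feature scaling factor
        fused_w.append([w * k for w in row])
        fused_b.append((b - mu) * k + beta)
    return fused_w, fused_b


W, b = fold_linear_bn([[1.0, 2.0], [3.0, -1.0]], [0.5, -0.5],
                      [0.1, 0.2], [4.0, 1.0], [2.0, 0.5], [0.0, 1.0], eps=0.0)
print(W, b)
```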

mqbench.fusion_method

mqbench.observer

class mqbench.observer.ClipStdObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False, quant_min=None, quant_max=None, ch_axis=-1, pot_scale=False, std_scale=2.6, factory_kwargs=None)

Bases: ObserverBase

Clips the observed range based on the standard deviation of x, scaled by std_scale.

forward(x_orig)

Records the running minimum and maximum of x.

max_val: Tensor
min_val: Tensor
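One plausible reading of "clip std", offered here only as an assumption, is that the observed [min, max] range is tightened to mean ± std_scale · std. The sketch below implements that reading; clip_std_range is hypothetical, and the use of the sample (rather than population) standard deviation is also an assumption.

```python
import statistics
from typing import Sequence, Tuple


def clip_std_range(x: Sequence[float], std_scale: float = 2.6) -> Tuple[float, float]:
    """Hypothetical: clip [min(x), max(x)] to mean +/- std_scale * std."""
    mean = statistics.fmean(x)
    std = statistics.stdev(x)  # sample std (assumption)
    lo = max(min(x), mean - std_scale * std)
    hi = min(max(x), mean + std_scale * std)
    return lo, hi


# 99 zeros and one outlier: the upper bound is cut well below the outlier.
print(clip_std_range([0.0] * 99 + [100.0]))
```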
class mqbench.observer.EMAMSEObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False, quant_min=None, quant_max=None, ch_axis=-1, pot_scale=False, p=2.0, ema_ratio=0.9, factory_kwargs=None)

Bases: ObserverBase

MSE-based observer: chooses the quantization range that minimizes the L_p error (p=2.0 gives the MSE), smoothed across batches with an exponential moving average controlled by ema_ratio.

forward(x_orig)

Records the running minimum and maximum of x.

lp_loss(pred, tgt, dim=None)

Loss function measured in the L_p norm.

max_val: Tensor
min_val: Tensor
mse(x: Tensor, x_min: Tensor, x_max: Tensor, iter=80)
mse_perchannel(x: Tensor, x_min: Tensor, x_max: Tensor, iter=80, ch_axis=0)
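The iter=80 argument of mse suggests an iterative range search. A minimal sketch of one such search, assuming the range is shrunk in 1% steps and scored with the L_p loss, is shown below; the step size, the affine qparams helper, and all function names are assumptions, and the moving-average part used by EMAMSEObserver is omitted.

```python
from typing import List, Sequence, Tuple


def lp_loss(pred: Sequence[float], tgt: Sequence[float], p: float = 2.0) -> float:
    """Mean of |pred - tgt| ** p (p = 2 gives the MSE)."""
    return sum(abs(a - b) ** p for a, b in zip(pred, tgt)) / len(tgt)


def affine_qparams(x_min: float, x_max: float, qmin: int, qmax: int) -> Tuple[float, int]:
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)  # the range must contain zero
    scale = max((x_max - x_min) / (qmax - qmin), 1e-8)
    zero_point = min(max(qmin - round(x_min / scale), qmin), qmax)
    return scale, zero_point


def fake_quant(x: Sequence[float], scale: float, zp: int, qmin: int, qmax: int) -> List[float]:
    return [(min(max(round(v / scale) + zp, qmin), qmax) - zp) * scale for v in x]


def mse_search(x: Sequence[float], qmin: int = 0, qmax: int = 255,
               p: float = 2.0, iters: int = 80) -> Tuple[float, float]:
    """Shrink [min(x), max(x)] in 1% steps; keep the range with the lowest L_p error."""
    x_min, x_max = min(x), max(x)
    best_score, best_range = float("inf"), (x_min, x_max)
    for i in range(iters):
        shrink = 1.0 - 0.01 * i
        lo, hi = x_min * shrink, x_max * shrink
        scale, zp = affine_qparams(lo, hi, qmin, qmax)
        score = lp_loss(fake_quant(x, scale, zp, qmin, qmax), x, p)
        if score < best_score:
            best_score, best_range = score, (lo, hi)
    return best_range


data = [i / 100 for i in range(100)] + [10.0]
print(mse_search(data, qmin=0, qmax=15))  # 4-bit grid, data with one outlier
```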
class mqbench.observer.EMAMinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False, quant_min=None, quant_max=None, ch_axis=-1, pot_scale=False, ema_ratio=0.9, factory_kwargs=None)

Bases: ObserverBase

Exponential moving average of the min/max across batches, controlled by ema_ratio.

forward(x_orig)

Records the running minimum and maximum of x.

max_val: Tensor
min_val: Tensor
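A moving average of this kind can be sketched as new = ema_ratio · old + (1 − ema_ratio) · current. The class below is a hypothetical stand-in, and initializing the range from the first batch is an assumption of the sketch.

```python
from typing import Optional, Sequence, Tuple


class EMAMinMaxSketch:
    """Hypothetical stand-in: exponential moving average of per-batch min/max."""

    def __init__(self, ema_ratio: float = 0.9):
        self.ema_ratio = ema_ratio
        self.min_val: Optional[float] = None
        self.max_val: Optional[float] = None

    def update(self, batch: Sequence[float]) -> Tuple[float, float]:
        cur_min, cur_max = min(batch), max(batch)
        if self.min_val is None:  # the first batch initializes the range (assumption)
            self.min_val, self.max_val = cur_min, cur_max
        else:
            r = self.ema_ratio
            self.min_val = r * self.min_val + (1.0 - r) * cur_min
            self.max_val = r * self.max_val + (1.0 - r) * cur_max
        return self.min_val, self.max_val


obs = EMAMinMaxSketch(ema_ratio=0.9)
obs.update([0.0, 1.0])
print(obs.update([-1.0, 11.0]))  # a single extreme batch moves the range only by 10%
```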
class mqbench.observer.EMAQuantileObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False, quant_min=None, quant_max=None, ch_axis=-1, pot_scale=False, ema_ratio=0.9, threshold=0.99999, bins=2048, factory_kwargs=None)

Bases: ObserverBase

Moving average of a quantile-based range across batches.

forward(x_orig)

Records the running minimum and maximum of x.

max_val: Tensor
min_val: Tensor
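The threshold and bins arguments suggest the quantile is read from a histogram. The sketch below assumes a histogram of |x| with bins equal-width bins, takes the center of the first bin at which the cumulative count reaches threshold · len(x) as the clip value, and clips the observed range with it. These details are assumptions, the function names are hypothetical, and the moving-average step is omitted.

```python
from typing import Sequence, Tuple


def quantile_clip_value(x: Sequence[float], threshold: float = 0.99999,
                        bins: int = 2048) -> float:
    """Center of the first |x|-histogram bin whose cumulative count reaches threshold * len(x)."""
    max_abs = max(abs(v) for v in x)
    if max_abs == 0.0:
        return 0.0
    width = max_abs / bins
    hist = [0] * bins
    for v in x:
        hist[min(int(abs(v) / width), bins - 1)] += 1
    target = threshold * len(x)
    total = 0
    for i, count in enumerate(hist):
        total += count
        if total >= target:
            return (i + 0.5) * width
    return max_abs


def quantile_range(x: Sequence[float], threshold: float = 0.99999,
                   bins: int = 2048) -> Tuple[float, float]:
    clip = quantile_clip_value(x, threshold, bins)
    return max(min(x), -clip), min(max(x), clip)


print(quantile_range([0.5] * 999 + [100.0], threshold=0.99, bins=100))
```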
class mqbench.observer.LSQObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False, quant_min=None, quant_max=None, ch_axis=-1, pot_scale=False, factory_kwargs=None)

Bases: ObserverBase

Observer for LSQ (Learned Step Size Quantization).

calculate_qparams()

Calculates the quantization parameters.

forward(x_orig)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

max_val: Tensor
min_val: Tensor
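The LSQ paper initializes the step size as s = 2 · mean(|x|) / sqrt(Q_P), where Q_P is the largest positive quantization level. A sketch of that initialization follows; it is taken from the paper, and whether LSQObserver computes exactly this is an assumption. lsq_init_scale is a hypothetical name.

```python
import math
from typing import Sequence


def lsq_init_scale(x: Sequence[float], quant_max: int) -> float:
    """LSQ paper initialization: s = 2 * mean(|x|) / sqrt(Q_P), with Q_P = quant_max."""
    mean_abs = sum(abs(v) for v in x) / len(x)
    return 2.0 * mean_abs / math.sqrt(quant_max)


print(lsq_init_scale([1.0, -1.0, 3.0, -3.0], quant_max=4))
```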
class mqbench.observer.LSQPlusObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False, quant_min=None, quant_max=None, ch_axis=-1, pot_scale=False, factory_kwargs=None)

Bases: ObserverBase

Observer for LSQ+ (LSQ with a learnable offset).

calculate_qparams()

Calculates the quantization parameters.

forward(x_orig)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

max_val: Tensor
min_val: Tensor
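For weights, the LSQ+ paper initializes the scale from the mean and standard deviation rather than from mean(|x|): s = max(|μ − 3σ|, |μ + 3σ|) / 2^(b−1). The sketch below follows that formula; treating it as what LSQPlusObserver does is an assumption, as is the use of the population standard deviation. lsq_plus_init_scale is a hypothetical name.

```python
import statistics
from typing import Sequence


def lsq_plus_init_scale(x: Sequence[float], bits: int) -> float:
    """LSQ+ paper weight init: s = max(|mu - 3 sigma|, |mu + 3 sigma|) / 2 ** (bits - 1)."""
    mu = statistics.fmean(x)
    sigma = statistics.pstdev(x)  # population std (assumption)
    return max(abs(mu - 3 * sigma), abs(mu + 3 * sigma)) / 2 ** (bits - 1)


print(lsq_plus_init_scale([-1.0, 1.0], bits=4))
```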
class mqbench.observer.MSEObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False, quant_min=None, quant_max=None, ch_axis=-1, pot_scale=False, p=2.0, factory_kwargs=None)

Bases: ObserverBase

MSE-based observer: chooses the quantization range that minimizes the L_p error (p=2.0 gives the MSE) over the whole calibration dataset.

forward(x_orig)

Records the running minimum and maximum of x.

lp_loss(pred, tgt, dim=None)

Loss function measured in the L_p norm.

max_val: Tensor
min_val: Tensor
mse(x: Tensor, x_min: Tensor, x_max: Tensor, iter=80)
mse_perchannel(x: Tensor, x_min: Tensor, x_max: Tensor, iter=80, ch_axis=0)
class mqbench.observer.MinMaxFloorObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False, quant_min=None, quant_max=None, ch_axis=-1, pot_scale=False, factory_kwargs=None)

Bases: ObserverBase

Calculates the min/max over the whole calibration dataset, using floor instead of round.

calculate_qparams()

Calculates the quantization parameters.

forward(x_orig)

Records the running minimum and maximum of x.

max_val: Tensor
min_val: Tensor
set_quant_type(qtype)
class mqbench.observer.MinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False, quant_min=None, quant_max=None, ch_axis=-1, pot_scale=False, factory_kwargs=None)

Bases: ObserverBase

Calculates the min/max over the whole calibration dataset.

forward(x_orig)

Records the running minimum and maximum of x.

max_val: Tensor
min_val: Tensor
class mqbench.observer.ObserverBase(dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False, quant_min=None, quant_max=None, ch_axis=-1, pot_scale=False, factory_kwargs=None)

Bases: _ObserverBase

Supports per-tensor and per-channel quantization. The arguments are:

  • dtype – quant_min/quant_max can be inferred from dtype, so this argument is not actually needed.

  • qscheme – quantization scheme.

  • reduce_range – specific to fbgemm, used to avoid overflow.

  • quant_min – minimum fixed-point value.

  • quant_max – maximum fixed-point value.

  • ch_axis – per-channel axis, or -1 for per-tensor.

  • pot_scale – indicates whether the scale is a power of two.

All arguments except pot_scale are similar to those of the PyTorch observers.

calculate_qparams() Tuple[Tensor, Tensor]

Calculates the quantization parameters.

extra_repr()

Set the extra representation of the module

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

max_val: Tensor
min_val: Tensor
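As a sketch of how a scale and zero point can be derived from min_val/max_val, the function below follows the usual PyTorch observer conventions for symmetric and affine schemes. The pot_scale branch, which snaps the scale to the nearest power of two in log space before the zero point is computed, is an assumption about how pot_scale is applied; calculate_qparams here is a standalone hypothetical function, not the method above.

```python
import math
from typing import Tuple


def calculate_qparams(min_val: float, max_val: float, quant_min: int, quant_max: int,
                      symmetric: bool, pot_scale: bool = False,
                      eps: float = 1e-8) -> Tuple[float, int]:
    min_val, max_val = min(min_val, 0.0), max(max_val, 0.0)  # keep zero representable
    if symmetric:
        scale = max(max(-min_val, max_val) / ((quant_max - quant_min) / 2), eps)
    else:
        scale = max((max_val - min_val) / (quant_max - quant_min), eps)
    if pot_scale:
        scale = 2.0 ** round(math.log2(scale))  # nearest power of two (assumption)
    if symmetric:
        zero_point = 0  # assumes a signed range such as [-128, 127]
    else:
        zero_point = min(max(quant_min - round(min_val / scale), quant_min), quant_max)
    return scale, zero_point


print(calculate_qparams(-1.0, 3.0, 0, 255, symmetric=False))
print(calculate_qparams(-1.0, 3.0, -128, 127, symmetric=True, pot_scale=True))
```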
class mqbench.observer.PoTModeObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False, quant_min=None, quant_max=None, ch_axis=-1, pot_scale=False, factory_kwargs=None)

Bases: ObserverBase

Records the most frequent power-of-two (PoT) scale of x.

calculate_qparams()

Calculates the quantization parameters.

eps: torch.Tensor
forward(x_orig)

Records the running minimum and maximum of x.

max_val: Tensor
min_val: Tensor
set_quant_type(qtype)
training: bool
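"Most frequent PoT scale" can be read as: compute a power-of-two scale for each batch, count how often each exponent occurs, and keep the mode. The class below sketches that reading; the symmetric per-batch scale and rounding in log space are assumptions, and PoTModeSketch is hypothetical.

```python
import math
from collections import Counter
from typing import Sequence


class PoTModeSketch:
    """Hypothetical: count the power-of-two exponent of each batch and keep the mode."""

    def __init__(self, quant_max: int = 127, eps: float = 1e-12):
        self.quant_max = quant_max
        self.eps = eps
        self.counts: Counter = Counter()

    def update(self, batch: Sequence[float]) -> None:
        max_abs = max(max(abs(v) for v in batch), self.eps)
        # Symmetric scale for this batch, snapped to a power of two in log space.
        exponent = round(math.log2(max_abs / self.quant_max))
        self.counts[exponent] += 1

    def scale(self) -> float:
        exponent, _ = self.counts.most_common(1)[0]
        return 2.0 ** exponent


pot = PoTModeSketch()
for batch in ([1.0, -0.5], [-1.1], [8.0]):
    pot.update(batch)
print(pot.scale())  # two batches vote for 2**-7, one for 2**-4
```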

mqbench.prepare_by_platform

mqbench.prepare_by_platform.prepare_by_platform(model: Module, deploy_backend: BackendType, prepare_custom_config_dict: Dict[str, Any] = {}, custom_tracer: Optional[Tracer] = None)
Parameters:
  • model (torch.nn.Module) – the floating-point model to prepare for quantization

  • deploy_backend (BackendType) – the target deployment backend

  • prepare_custom_config_dict (dict) – extra options for preparation, described below

  • custom_tracer (Tracer, optional) – a custom torch.fx Tracer used to trace the model

>>> prepare_custom_config_dict : {
        extra_qconfig_dict: Dict, see get_qconfig_by_platform for explanations.
        extra_quantizer_dict: Extra params for the quantizer.
        preserve_attr: Dict, attributes of the model that should be preserved
            after prepare, because symbolic_trace only stores attributes that
            are used in forward. To preserve model.func1 and model.backbone.func2,
            use {"": ["func1"], "backbone": ["func2"]}.
        The attributes below are inherited from PyTorch.
        concrete_args: Specify input for model tracing.
        extra_fuse_dict: Specify extra fusing patterns and functions.
    }
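For illustration, a prepare_custom_config_dict could look like the following. The preserve_attr entry mirrors the example in the text; the extra_qconfig_dict key names (w_observer, a_observer) are assumptions and should be checked against get_qconfig_by_platform. The call itself needs mqbench and a model, so it is shown only as a comment, and the BackendType member name there is also an assumption.

```python
# Hypothetical example; only the preserve_attr entry comes from the text above.
prepare_custom_config_dict = {
    # Keep model.func1 and model.backbone.func2 after prepare.
    "preserve_attr": {"": ["func1"], "backbone": ["func2"]},
    # Assumed key names; the observer classes are the ones documented above.
    "extra_qconfig_dict": {
        "w_observer": "MinMaxObserver",
        "a_observer": "EMAMinMaxObserver",
    },
}

# Typical call (requires mqbench and a model instance; not executed here):
# from mqbench.prepare_by_platform import prepare_by_platform, BackendType
# model = prepare_by_platform(model, BackendType.Tensorrt, prepare_custom_config_dict)
```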

Module contents