# Source code for mqbench.fake_quantize.fixed

import torch

from mqbench.fake_quantize.quantize_base import QuantizeBase
from mqbench.utils.hook import PerChannelLoadHook

def _torch_version_tuple(version):
    """Return the ``(major, minor)`` pair parsed from a torch version string.

    Only the leading digits of the first two dot-separated fields are read,
    so local/pre-release suffixes such as ``'1.13.1+cu117'`` or
    ``'2.1.0a0+git1234'`` are accepted.

    Args:
        version: a version string like ``torch.__version__``.

    Returns:
        A ``(major, minor)`` tuple of ints.

    Raises:
        ValueError: if ``version`` does not start with ``<int>.<int>``.
    """
    import re
    match = re.match(r'(\d+)\.(\d+)', version)
    if match is None:
        raise ValueError('Unrecognized torch version string: {!r}'.format(version))
    return int(match.group(1)), int(match.group(2))


# True when the installed torch is older than 1.10. FixedFakeQuantize.forward
# uses this to decide whether to cast zero_point to int64 (``.long()``) before
# calling fake_quantize_per_channel_affine. The whole (major, minor) tuple is
# compared; looking only at the minor field would misclassify torch 2.x as old.
_version_under_1100 = _torch_version_tuple(torch.__version__) < (1, 10)

class FixedFakeQuantize(QuantizeBase):
    """This is actually torch.quantization.FakeQuantize.

    The ``scale`` and ``zero_point`` buffers are refreshed from the observer
    on every forward pass while the observer is enabled, and are used to
    fake-quantize the input (per tensor or per channel) while fake
    quantization is enabled. Flags and attributes such as
    ``observer_enabled``, ``fake_quant_enabled``, ``is_per_channel``,
    ``ch_axis``, ``quant_min`` and ``quant_max`` come from ``QuantizeBase``
    (not visible here).
    """

    def __init__(self, observer, **observer_kwargs):
        """Create the module with one-element ``scale``/``zero_point`` placeholders.

        Args:
            observer: observer passed through to ``QuantizeBase``.
            **observer_kwargs: extra keyword arguments for ``QuantizeBase``.
        """
        super(FixedFakeQuantize, self).__init__(observer, **observer_kwargs)
        # One-element placeholders; forward() and _load_from_state_dict()
        # resize them when the qparams have a different shape (e.g. per channel).
        self.register_buffer('scale', torch.tensor([1.0], dtype=torch.float))
        self.register_buffer('zero_point', torch.tensor([0], dtype=torch.int))
        # NOTE(review): presumably lets checkpoints with per-channel
        # scale/zero_point of a different size load — see mqbench.utils.hook.
        self.load_state_dict_hook = PerChannelLoadHook(self)

    def forward(self, X):
        """Observe ``X`` and/or fake-quantize it, depending on the enable flags.

        Args:
            X: input tensor.

        Returns:
            ``X`` fake-quantized with the current ``scale``/``zero_point``
            when ``fake_quant_enabled[0] == 1``; otherwise ``X`` unchanged.
        """
        if self.observer_enabled[0] == 1:
            self.activation_post_process(X.detach())
            _scale, _zero_point = self.calculate_qparams()
            # Bring the new qparams onto the same devices as the buffers.
            _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(self.zero_point.device)
            # Grow/shrink the buffers in place when the observer output has a
            # different shape (e.g. the first per-channel observation).
            if self.scale.shape != _scale.shape:
                self.scale.resize_(_scale.shape)
                self.zero_point.resize_(_zero_point.shape)
            self.scale.copy_(_scale)
            self.zero_point.copy_(_zero_point)

        if self.fake_quant_enabled[0] == 1:
            if self.is_per_channel:
                # Old torch (< 1.10) wants an int64 zero_point for this op;
                # newer releases take the int32 buffer as-is.
                zero_point = self.zero_point.long() if _version_under_1100 else self.zero_point
                X = torch.fake_quantize_per_channel_affine(
                    X, self.scale, zero_point,
                    self.ch_axis, self.quant_min, self.quant_max)
            else:
                X = torch.fake_quantize_per_tensor_affine(
                    X, self.scale.item(), int(self.zero_point.item()),
                    self.quant_min, self.quant_max)
        return X

    @torch.jit.export
    def extra_repr(self):
        """Summarize flags, range, dtype/qscheme and qparams for ``repr()``.

        Per-channel qparams (``ch_axis != -1``) are shown as ``'List'``
        instead of printing every value.
        """
        return 'fake_quant_enabled={}, observer_enabled={}, ' \
               'quant_min={}, quant_max={}, dtype={}, qscheme={}, ch_axis={}, ' \
               'scale={}, zero_point={}'.format(
                   self.fake_quant_enabled, self.observer_enabled,
                   self.quant_min, self.quant_max,
                   self.dtype, self.qscheme, self.ch_axis,
                   self.scale if self.ch_axis == -1 else 'List',
                   self.zero_point if self.ch_axis == -1 else 'List')

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        """Save module state, then store ``scale``/``zero_point`` explicitly.

        The explicit entries mirror torch.quantization.FakeQuantize; since
        both are registered buffers the parent presumably already wrote
        them, so these assignments just overwrite those entries.
        """
        super(FixedFakeQuantize, self)._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + 'scale'] = self.scale
        destination[prefix + 'zero_point'] = self.zero_point

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Resize ``scale``/``zero_point`` to the checkpoint's shapes, then load.

        Without this override, loading a checkpoint whose scale/zero_point
        hold a different number of elements than the current buffers (e.g.
        N per-channel values into the one-element placeholders) fails the
        size check in the parent's loading code.
        """
        local_state = ['scale', 'zero_point']
        for name in local_state:
            key = prefix + name
            if key in state_dict:
                val = state_dict[key]
                # Resize the buffer here; the parent's default loading code
                # then copies the values in.
                if name == 'scale':
                    self.scale.resize_(val.shape)
                else:
                    assert name == 'zero_point'
                    self.zero_point.resize_(val.shape)
                # For a TorchScript module the values are copied here, because
                # the parent's _load_from_state_dict is not called in that case.
                if torch.jit.is_scripting():
                    if name == 'scale':
                        self.scale.copy_(val)
                    else:
                        assert name == 'zero_point'
                        self.zero_point.copy_(val)
            elif strict:
                # NOTE(review): the parent may report this key as missing too,
                # which would duplicate the entry — confirm.
                missing_keys.append(key)
        super(FixedFakeQuantize, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                                             missing_keys, unexpected_keys, error_msgs)