import torch
from mqbench.fake_quantize.quantize_base import QuantizeBase
from mqbench.utils.hook import PerChannelLoadHook
# True for torch < 1.10, where the per-channel fake-quantize op expects an
# int64 (long) zero_point; see the cast in forward() below. Compare the full
# (major, minor) pair so torch 2.x is not misclassified by its minor version.
_torch_major, _torch_minor = (int(v) for v in torch.__version__.split('.')[:2])
_version_under_1100 = (_torch_major, _torch_minor) < (1, 10)
class FixedFakeQuantize(QuantizeBase):
"""This is actually torch.quantization.FakeQuantize.
"""
def __init__(self, observer, **observer_kwargs):
super(FixedFakeQuantize, self).__init__(observer, **observer_kwargs)
self.register_buffer('scale', torch.tensor([1.0], dtype=torch.float))
self.register_buffer('zero_point', torch.tensor([0], dtype=torch.int))
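        # PerChannelLoadHook is a load_state_dict pre-hook: for per-channel
        # quantizers it resizes the scale/zero_point buffers to the shapes
        # stored in the checkpoint before the values are copied in.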
self.load_state_dict_hook = PerChannelLoadHook(self)
    def forward(self, X):
if self.observer_enabled[0] == 1:
self.activation_post_process(X.detach())
_scale, _zero_point = self.calculate_qparams()
_scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(self.zero_point.device)
if self.scale.shape != _scale.shape:
self.scale.resize_(_scale.shape)
self.zero_point.resize_(_zero_point.shape)
self.scale.copy_(_scale)
self.zero_point.copy_(_zero_point)
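
        # Fake quantization simulates integer quantization in float precision:
        #   X_q   = clamp(round(X / scale) + zero_point, quant_min, quant_max)
        #   X_out = (X_q - zero_point) * scale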
if self.fake_quant_enabled[0] == 1:
if self.is_per_channel:
X = torch.fake_quantize_per_channel_affine(
X, self.scale,
self.zero_point.long() if _version_under_1100 else self.zero_point,
self.ch_axis, self.quant_min, self.quant_max)
else:
X = torch.fake_quantize_per_tensor_affine(
X, self.scale.item(), int(self.zero_point.item()),
self.quant_min, self.quant_max)
return X
def _save_to_state_dict(self, destination, prefix, keep_vars):
# We cannot currently register scalar values as buffers, so need to manually
# specify serialization here.
super(FixedFakeQuantize, self)._save_to_state_dict(destination, prefix, keep_vars)
destination[prefix + 'scale'] = self.scale
destination[prefix + 'zero_point'] = self.zero_point
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
        # Removing this function throws an error that the size of the loaded
        # tensor does not match the original size, i.e., these buffers start
        # out with numel 0 and become numel 1 once they have their first
        # forward pass.
local_state = ['scale', 'zero_point']
for name in local_state:
key = prefix + name
if key in state_dict:
val = state_dict[key]
# Custom handling to allow loading scale and zero_point
# of size N into uninitialized buffers of size 0. The
# buffers are resized here, and the values are copied in
# the default state_dict loading code of the parent.
if name == 'scale':
self.scale.resize_(val.shape)
else:
assert name == 'zero_point'
self.zero_point.resize_(val.shape)
                # For torchscript modules we need to update the attributes here since we do not
                # call the `_load_from_state_dict` function defined in module.py
if torch.jit.is_scripting():
if name == 'scale':
self.scale.copy_(val)
else:
assert name == 'zero_point'
self.zero_point.copy_(val)
elif strict:
missing_keys.append(key)
super(FixedFakeQuantize, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
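

# ---------------------------------------------------------------------------
# Illustrative usage sketch, not part of the library. It assumes MQBench's
# MinMaxObserver and that QuantizeBase forwards observer kwargs such as
# quant_min/quant_max to the observer; the enable_*/disable_* toggles come
# from torch's FakeQuantizeBase, which QuantizeBase extends. Adjust names to
# your installation.
if __name__ == "__main__":
    from mqbench.observer import MinMaxObserver

    fq = FixedFakeQuantize(MinMaxObserver, quant_min=-128, quant_max=127)

    # Calibration: run data through with the observer on and fake-quant off,
    # so scale/zero_point track the observed min/max without altering outputs.
    fq.enable_observer()
    fq.disable_fake_quant()
    with torch.no_grad():
        for _ in range(8):
            fq(torch.randn(4, 16))

    # Inference: freeze the qparams and quantize-dequantize inputs.
    fq.disable_observer()
    fq.enable_fake_quant()
    y = fq(torch.randn(4, 16))
    print('scale:', fq.scale.item(), 'zero_point:', fq.zero_point.item())

    # Sanity check: the per-tensor op used in forward() matches the formula
    #   x_dq = (clamp(round(x / s) + zp, qmin, qmax) - zp) * s
    x = torch.randn(8)
    s, zp, qmin, qmax = 0.05, 3, -128, 127
    manual = (torch.clamp(torch.round(x / s) + zp, qmin, qmax) - zp) * s
    assert torch.allclose(
        torch.fake_quantize_per_tensor_affine(x, s, zp, qmin, qmax), manual)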