import torch
import torch.nn.intrinsic.qat as nniqat
from torch.fx import GraphModule, Node
from torch import fx, nn
from torch.nn import Module
USE_LINK = False
USE_DDP = False
__all__ = ['ptq_reconstruction']
try:
import spring.linklink as link
if not link.is_initialized():
link.initialize()
USE_LINK = True
except (ModuleNotFoundError, AssertionError):
import torch.distributed as dist
if torch.distributed.is_initialized():
USE_DDP = True
import numpy as np
from typing import List
from mqbench.utils.logger import logger
from mqbench.utils.hook import DataSaverHook, StopForwardException
from mqbench.utils import deepcopy_graphmodule, deepcopy_mixedmodule, topology_order, getitem2node
from mqbench.utils.utils import _fix_succ_recursivly
from mqbench.utils.state import enable_quantization, disable_all
import mqbench.nn.intrinsic.qat as qnniqat
_ADAROUND_SUPPORT_TYPE = (torch.nn.Conv2d, torch.nn.Linear)
_FUSED_TYPE = (nniqat.ConvBnReLU2d, nniqat.ConvBn2d, qnniqat.ConvFreezebn2d, qnniqat.ConvFreezebnReLU2d)
_WEIGHTS_MODULE_TYPE = (torch.nn.Conv2d, torch.nn.Linear)
def node2modules(name2modules, nodes):
modules = dict()
for node in nodes:
if node.target in name2modules:
modules[node] = name2modules[node.target]
return modules
def qnode2fpnode(quant_modules, fp32_modules):
quant_named_nodes = {node.target: node for node in quant_modules}
fp32_named_nodes = {node.target: node for node in fp32_modules}
qnode2fpnode_dict = {quant_named_nodes[key]: fp32_named_nodes[key] for key in quant_named_nodes}
return qnode2fpnode_dict
def layer_has_weights(nodes, modules):
has_weights = False
for node in nodes:
if node in modules:
if isinstance(modules[node], _WEIGHTS_MODULE_TYPE):
has_weights = True
break
return has_weights
def lp_loss(pred, tgt, p=2.0):
"""
loss function measured in L_p Norm
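    Example (illustrative): for pred/tgt of shape (B, C, H, W) and p=2, the squared
    error is summed over dim 1 and then averaged over the remaining dims.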
"""
return (pred - tgt).abs().pow(p).sum(1).mean()
def to_device(data, device='cpu'):
if isinstance(data, torch.Tensor):
return data.to(device)
elif isinstance(data, dict):
for key in data:
data[key] = to_device(data[key], device)
return data
elif isinstance(data, list):
for idx, _ in enumerate(data):
data[idx] = to_device(data[idx], device)
return data
else:
return data
def tensor_detach(data):
if isinstance(data, torch.Tensor):
return data.detach()
elif isinstance(data, dict):
for key in data:
data[key] = tensor_detach(data[key])
return data
    elif isinstance(data, list):
        return [tensor_detach(dat) for dat in data]
else:
return data
def save_inp_oup_data(model: GraphModule, inp_module: Module, oup_module: Module, cali_data: list, store_inp=True, store_oup=True,
keep_gpu: bool = True):
"""
Save input data and output data of a particular layer/block over calibration dataset.
:param fp_model: fp_model
:param quant_model: quant_model
:param cali_data: calibration data set
:param keep_gpu: put saved data on GPU for faster optimization
:return: input and output data
"""
device = next(model.parameters()).device
if store_inp:
assert inp_module is not None
inp_saver = DataSaverHook(store_input=store_inp, store_output=False, stop_forward=(not store_oup))
inp_handle = inp_module.register_forward_hook(inp_saver)
if store_oup:
assert oup_module is not None
oup_saver = DataSaverHook(store_input=False, store_output=store_oup, stop_forward=True)
oup_handle = oup_module.register_forward_hook(oup_saver)
cached = ([], [])
with torch.no_grad():
for batch in cali_data:
try:
_ = model(to_device(batch, device))
except StopForwardException:
pass
if store_inp:
if keep_gpu:
cached[0].append([tensor_detach(inp) for inp in inp_saver.input_store])
else:
                    cached[0].append([to_device(tensor_detach(inp), 'cpu') for inp in inp_saver.input_store])  # input_store is a tuple/list of the module's positional inputs
if store_oup:
if keep_gpu:
cached[1].append(tensor_detach(oup_saver.output_store))
else:
cached[1].append(to_device(tensor_detach(oup_saver.output_store), 'cpu'))
if store_inp:
inp_handle.remove()
if store_oup:
oup_handle.remove()
torch.cuda.empty_cache()
return cached
class LinearTempDecay:
def __init__(self, t_max=10000, warm_up=0.2, start_b=20, end_b=2):
self.t_max = t_max
self.start_decay = warm_up * t_max
self.start_b = start_b
self.end_b = end_b
def __call__(self, t):
if t < self.start_decay:
return self.start_b
elif t > self.t_max:
return self.end_b
else:
rel_t = (t - self.start_decay) / (self.t_max - self.start_decay)
return self.end_b + (self.start_b - self.end_b) * max(0.0, (1 - rel_t))
class CosineTempDecay:
def __init__(self, t_max=10000, warm_up=0.2, start_b=20, end_b=2):
self.t_max = t_max
self.start_decay = warm_up * t_max
self.start_b = start_b
self.end_b = end_b
def __call__(self, t):
if t < self.start_decay:
return self.start_b
elif t > self.t_max:
return self.end_b
else:
rel_t = (t - self.start_decay) / (self.t_max - self.start_decay)
return self.end_b + 0.5 * (self.start_b - self.end_b) * (1 + np.cos(rel_t * np.pi))
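# Illustrative values (read off the code above): LinearTempDecay(t_max=10000, warm_up=0.2,
# start_b=20, end_b=2) returns 20 for t < 2000 and then decays linearly towards 2 at
# t = 10000; CosineTempDecay follows a half-cosine between the same endpoints.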
class LossFunction:
    r'''Loss function combining the MSE reconstruction loss with the rounding relaxation
    loss; a temperature decay schedule is used to balance the two terms.
    '''
def __init__(self,
subgraph: Module,
weight: float = 1.,
max_count: int = 10000,
b_range: tuple = (20, 2),
warm_up: float = 0.0,
p: float = 2.):
self.subgraph = subgraph
self.weight = weight
self.loss_start = max_count * warm_up
self.p = p
self.temp_decay = LinearTempDecay(max_count, warm_up=warm_up,
start_b=b_range[0], end_b=b_range[1])
self.count = 0
def __call__(self, pred, tgt):
"""
Compute the total loss for adaptive rounding:
rec_loss is the quadratic output reconstruction loss, round_loss is
a regularization term to optimize the rounding policy
:param pred: output from quantized model
:param tgt: output from FP model
:return: total loss function
"""
self.count += 1
rec_loss = lp_loss(pred, tgt, p=self.p)
b = self.temp_decay(self.count)
if self.count < self.loss_start:
round_loss = 0
else:
round_loss = 0
for layer in self.subgraph.modules():
if isinstance(layer, _ADAROUND_SUPPORT_TYPE):
round_vals = layer.weight_fake_quant.rectified_sigmoid()
round_loss += self.weight * (1 - ((round_vals - .5).abs() * 2).pow(b)).sum()
total_loss = rec_loss + round_loss
if self.count % 500 == 0:
logger.info('Total loss:\t{:.3f} (rec:{:.3f}, round:{:.3f})\tb={:.2f}\tcount={}'.format(
float(total_loss), float(rec_loss), float(round_loss), b, self.count))
return total_loss
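# Note: the round_loss in LossFunction.__call__ implements the regularizer
# f_reg(V) = sum_{i,j} 1 - |2 * h(V_{i,j}) - 1|^beta from the ptq_reconstruction docstring,
# with beta annealed by LinearTempDecay.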
def _flatten_args(node):
    flattened_args = []
    if isinstance(node, dict):
        for v in node.values():
            flattened_args.extend(_flatten_args(v))
    elif isinstance(node, (tuple, list)):
        for n in node:
            flattened_args.extend(_flatten_args(n))
    else:
        flattened_args.extend([node])
    return flattened_args
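# e.g. _flatten_args(((a, b), {'k': c})) returns [a, b, c]; non-container arguments are kept as-is.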
def find_used_times(nodes, target):
used = len([_node for _node in target.users if _node in nodes])
return used
def find_cur_node(layer_node_list):
node_list = []
used_later = []
for idx, node in enumerate(layer_node_list):
for _node in layer_node_list[idx + 1:]:
if node in _flatten_args(_node.args):
used_later.append(node)
break
not_used_later = [node for node in layer_node_list if node not in used_later]
single_branch = dict()
for node in not_used_later:
single_branch[node] = set([node])
q = [node]
while True:
now_args = sum([_flatten_args(_node.args) for _node in q], [])
p = [_node for _node in now_args if isinstance(_node, torch.fx.Node) and find_used_times(layer_node_list, _node) == 1]
single_branch[node] = single_branch[node].union(set(p))
if len(p) == 0:
break
else:
q = p
for node in layer_node_list:
if node.op == 'call_function' or node.op == 'call_method':
continue
if node not in used_later:
break
unwanted = set()
for key in single_branch:
if key is node:
continue
else:
unwanted = unwanted.union(single_branch[key])
layer_node_list = [_node for _node in layer_node_list if _node not in unwanted]
for _node in layer_node_list:
node_list.append(_node)
if _node is node:
return node_list
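# Expected data layout for subgraph_reconstruction (as assembled in ptq_reconstruction below):
# when config.prob < 1.0, cached_inps is the pair (quant_all_inps, fp32_all_inps), each indexed
# as [arg][batch]; otherwise it is quant_all_inps alone. cached_oups holds the fp32 outputs of
# the last node in the subgraph, one entry per calibration batch.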
def subgraph_reconstruction(subgraph, cached_inps, cached_oups, config):
global USE_LINK
global USE_DDP
device = next(subgraph.parameters()).device
w_para, a_para = [], []
w_opt, w_scheduler = None, None
if hasattr(config, 'scale_lr'):
a_para = []
for name, layer in subgraph.named_modules():
if isinstance(layer, _ADAROUND_SUPPORT_TYPE):
weight_quantizer = layer.weight_fake_quant
# assert isinstance(weight_quantizer, adaround_quantizer) is True
weight_quantizer.init(layer.weight.data, config.round_mode)
w_para += [weight_quantizer.alpha]
if isinstance(layer, torch.quantization.FakeQuantizeBase) and 'post_act_fake_quantize' in name:
if hasattr(config, 'scale_lr'):
logger.info('learn the scale for {}'.format(name))
a_para += [layer.scale]
layer.prob = config.prob
if len(a_para) != 0:
a_opt = torch.optim.Adam(a_para, lr=config.scale_lr)
a_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(a_opt, T_max=config.max_count, eta_min=0.)
else:
a_opt, a_scheduler = None, None
w_opt = torch.optim.Adam(w_para)
loss_func = LossFunction(subgraph=subgraph, weight=config.weight, max_count=config.max_count, b_range=config.b_range,
warm_up=config.warm_up)
if any([USE_DDP, USE_LINK]):
world_size = link.get_world_size() if USE_LINK else dist.get_world_size()
else:
world_size = 1
logger.info('The world size is {}.'.format(world_size))
'''start training'''
logger.info('start tuning by adaround')
if config.prob < 1.0:
        # cached_inps is (quant_inps, fp32_inps): [quant/fp32] x args x batch x data
sz = len(cached_inps[0][0])
num_args = len(cached_inps[0])
else:
# cache inps: args x batch x data
sz = len(cached_inps[0])
num_args = len(cached_inps)
for i in range(config.max_count):
idx = np.random.randint(0, sz)
cur_args = []
for a in range(num_args):
if config.prob < 1.0:
cur_inp = to_device(cached_inps[0][a][idx], device)
cur_sym = to_device(cached_inps[1][a][idx], device)
cur_inp = torch.where(torch.rand_like(cur_inp) < config.prob, cur_inp, cur_sym)
else:
cur_inp = to_device(cached_inps[a][idx], device)
cur_args.append(cur_inp)
cur_args = tuple(cur_args)
cur_out = to_device(cached_oups[idx], device)
if a_opt:
a_opt.zero_grad()
w_opt.zero_grad()
out_quant = subgraph(*cur_args)
err = loss_func(out_quant, cur_out)
err /= world_size
err.backward()
if world_size > 1:
for param in w_para:
if USE_LINK:
link.allreduce(param.grad.data)
elif USE_DDP:
dist.all_reduce(param.grad.data)
w_opt.step()
if a_opt:
a_opt.step()
if w_scheduler:
w_scheduler.step()
if a_scheduler:
a_scheduler.step()
torch.cuda.empty_cache()
for name, layer in subgraph.named_modules():
if isinstance(layer, _FUSED_TYPE):
# We need to do bn fold simulation here.
weight_quantizer = layer.weight_fake_quant
scale_factor = layer.bn.weight / torch.sqrt(layer.bn.running_var + layer.bn.eps)
merged_rounded_weight = weight_quantizer.get_hard_value(
layer.weight.data * scale_factor.reshape([-1] + [1] * (len(layer.weight.shape) - 1)))
layer.weight.data = merged_rounded_weight / scale_factor.reshape([-1] + [1] * (len(merged_rounded_weight.shape) - 1))
weight_quantizer.adaround = False
elif isinstance(layer, _ADAROUND_SUPPORT_TYPE):
assert not hasattr(layer, 'bn'), 'Layer {} with type {} has BN ! Should not reach here.'.format(name, type(layer))
weight_quantizer = layer.weight_fake_quant
layer.weight.data = weight_quantizer.get_hard_value(layer.weight.data)
weight_quantizer.adaround = False
if isinstance(layer, torch.quantization.FakeQuantizeBase) and 'post_act_fake_quantize' in name:
            layer.prob = 1.0  # restore prob so that dropping activation quantization only happens during the reconstruction phase
def extract_subgraph(orig_module: nn.Module, nodes: List[fx.Node], output: fx.Node, g2node: dict):
"""
Given lists of nodes from an existing graph that represent a subgraph, returns a submodule that executes that subgraph.
"""
new_graph = fx.Graph()
env = dict()
inp_lst = []
for node in nodes:
for arg in _flatten_args(node.args):
if isinstance(arg, torch.fx.Node):
if arg not in nodes and arg not in inp_lst:
inp_lst.append(node)
if node in g2node:
arg_name = g2node[node].name
else:
arg_name = node.name
new_node = new_graph.placeholder(arg_name)
env[node] = new_node
break
for node in nodes:
if node in inp_lst:
continue
if node in g2node:
node = g2node[node]
new_node = new_graph.node_copy(node, lambda x: env[x])
env[node] = new_node
    # create the output node, otherwise the extracted submodule has no return value
new_graph.output(env[output])
new_graph.lint()
return fx.GraphModule(orig_module, new_graph)
def find_num_nodes(nodes):
num = 0
for node in nodes:
if isinstance(node, Node):
num += 1
return num
# Recommendation: log the result to check that the extracted layer is correct. You can define
# your own layer manually, or extract it automatically like this.
# extract a single chain (linked list) of nodes
def extract_layer(node, fp32_modules):
layer_node_list = []
cur_node = node
    is_next_block = False  # whether the walk was stopped by a block
while True:
logger.debug('cur_node in layer is {}'.format(cur_node))
layer_node_list.append(cur_node) # valid node here
stop = (len(cur_node.users) == 0)
for user in cur_node.users:
if user.target == 'update':
continue
if user.op == 'call_module' and isinstance(
fp32_modules[user], _ADAROUND_SUPPORT_TYPE):
stop = True
# TODO: only short-cut here, consider more here
# TODO: can also use un/completed to check here.
if ('add' in user.name
and user.op in ['call_function', 'call_method']):
stop = True
if user.op == 'output':
is_next_block, stop = True, True
if stop:
break
cur_node = list(cur_node.users.keys())[0]
if find_num_nodes(cur_node.users) > 1:
is_next_block = True
return layer_node_list, is_next_block
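# Illustrative example (hypothetical graph): for a chain conv1 -> relu -> conv2,
# extract_layer(conv1, modules) walks the chain and stops just before the next weighted
# module (or an add / the graph output), returning [conv1, relu] together with a flag
# telling whether the walk ended at a block boundary.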
# Recommendation: log the result to check that the extracted block is correct. You can define
# your own block manually, or extract it automatically like this.
# extract a block, e.g. one containing a short-cut
def extract_block(input_nodes, fp32_modules, depth=0):
if depth > 2:
        # stack at most 2 or 3 layers when there is no short-cut structure
return []
layer_node_list = []
is_block = False
cnt = dict()
    q, p = [], []  # q records completed nodes, p records uncompleted nodes
cur_node = None
for input in input_nodes:
for user in input.users:
if user not in cnt:
cnt[user] = find_num_nodes(user.args)
if cnt[user] > 1:
is_block = True
p.append(user)
cnt[user] -= 1
if cnt[user] == 0:
q.append(user)
p.remove(user)
while len(q) != 0:
cur_node = q.pop(0) # valid node here
logger.debug('cur node is {}'.format(cur_node))
if cur_node.target == 'update':
continue
if len(p) == 0 and len(q) == 0:
break
layer_node_list.append(cur_node)
for user in cur_node.users:
if user not in cnt:
cnt[user] = find_num_nodes(user.args)
if cnt[user] > 1:
is_block = True
p.append(user)
cnt[user] -= 1
if cnt[user] == 0:
q.append(user)
p.remove(user)
logger.debug('uncompleted nodes are {}'.format(p))
if not cur_node:
return layer_node_list
exp_nodes, is_next_block = extract_layer(cur_node, fp32_modules)
if is_block or is_next_block:
return layer_node_list + exp_nodes
else:
return layer_node_list + exp_nodes + extract_block(
[exp_nodes[-1]], fp32_modules, depth + 1)
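# Illustrative example (hypothetical graph): starting from the inputs of a residual block,
# extract_block collects nodes in topological (BFS-like) order, waiting at join points such
# as the short-cut add until all of their predecessors are completed; if no short-cut is
# found, it keeps extending via extract_layer / extract_block until the depth check above
# stops the recursion.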
def ptq_reconstruction(model: GraphModule, cali_data: list, config: dict, graph_module_list: list = None):
r"""
    Reconstruction for AdaRound, BRECQ and QDrop.
Basic optimization objective:
.. math::
\mathop{\arg\min}_{\mathbf{V}}\ \ || Wx-\tilde{W}x ||_F^2 + \lambda f_{reg}(\mathbf{V}),
\tilde{W}=s \cdot clip\left( \left\lfloor\dfrac{W}{s}\right\rfloor+h(\mathbf{V}), n, p \right)
where :math:`h(\mathbf{V}_{i,j})=clip(\sigma(\mathbf{V}_{i,j})(\zeta-\gamma)+\gamma, 0, 1)`, and :math:`f_{reg}(\mathbf{V})=\mathop{\sum}_{i,j}{1-|2h(\mathbf{V}_{i,j})-1|^\beta}`. By annealing on :math:`\beta`, the rounding mask can adapt freely in initial phase and converge to 0 or 1 in later phase.
Args:
model (torch.nn.Module): a prepared GraphModule to do PTQ
        cali_data (List): a list of calibration tensors
config (dict): a config for PTQ reconstruction
        graph_module_list (list): a list of the model's child modules which need quantization. If this is given, the model is partially quantized; otherwise, the whole model is quantized.
>>> sample config : {
pattern: block (str, Available options are [layer, block].)
scale_lr: 4.0e-5 (learning rate for learning step size of activation)
warm_up: 0.2 (0.2 * max_count iters without regularization to floor or ceil)
weight: 0.01 (loss weight for regularization item)
max_count: 20000 (optimization iteration)
            b_range: [20,2] (beta decaying range)
            keep_gpu: True (whether calibration data is kept on GPU or moved to CPU)
            round_mode: learned_hard_sigmoid (ways to reconstruct the weight, currently only support learned_hard_sigmoid)
            prob: 0.5 (dropping probability of QDrop)
}
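    A minimal usage sketch (the EasyDict-style config object and variable names here are
    assumptions for illustration, not part of this function's signature):
    >>> # cfg = EasyDict(pattern='block', scale_lr=4.0e-5, warm_up=0.2, weight=0.01,
    >>> #               max_count=20000, b_range=[20, 2], keep_gpu=True,
    >>> #               round_mode='learned_hard_sigmoid', prob=0.5)
    >>> # quant_model = ptq_reconstruction(prepared_model, cali_data, cfg)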
"""
# assert model is on cuda
if not config.keep_gpu:
cali_data = [to_device(inp, 'cpu') for inp in cali_data]
'''set state first'''
fp32_model = model
fp32_model.eval()
if graph_module_list is None:
assert isinstance(fp32_model, torch.fx.GraphModule)
quant_model = deepcopy_graphmodule(model)
nodes = list(quant_model.graph.nodes)
g2node = getitem2node(quant_model)
fp32_modules = node2modules(dict(fp32_model.named_modules()), fp32_model.graph.nodes)
quant_modules = node2modules(dict(quant_model.named_modules()), quant_model.graph.nodes)
topology_order_by_node = topology_order(quant_model)
else:
quant_model = deepcopy_mixedmodule(model, graph_module_list)
nodes = []
g2node = dict()
fp32_modules = dict()
quant_modules = dict()
topology_order_by_node = {}
topo_cnt = 0
for mname in graph_module_list:
child = getattr(quant_model, mname)
assert isinstance(child, torch.fx.GraphModule)
nodes += list(child.graph.nodes)
g2node.update(getitem2node(child))
for mname in graph_module_list:
fp_child = getattr(fp32_model, mname)
q_child = getattr(quant_model, mname)
# note: the nodes we use is from the quant model, so build q_node2fp_module, rather than fp2fp.
fp_modules = node2modules(dict(fp_child.named_modules()), q_child.graph.nodes)
q_modules = node2modules(dict(q_child.named_modules()), q_child.graph.nodes)
fp32_modules.update(fp_modules)
quant_modules.update(q_modules)
child_topo = topology_order(q_child)
for k in child_topo:
child_topo[k] += topo_cnt
topology_order_by_node.update(child_topo)
topo_cnt += len(topology_order_by_node)
qnode2fpnode_dict = qnode2fpnode(quant_modules, fp32_modules)
quant_model.eval()
disable_all(fp32_model)
enable_quantization(quant_model)
torch.cuda.empty_cache()
checked_nodes = dict()
for node in nodes:
if 'exclude_node_prefix' in config:
cont = False
            for prefix in config['exclude_node_prefix']:
if node.name.startswith(prefix):
cont = True
break
if cont:
logger.info(f'Exclude node {node}')
continue
if node in checked_nodes:
continue
if node.op == "call_module" and isinstance(quant_modules[node], _ADAROUND_SUPPORT_TYPE):
logger.info('prepare {} reconstruction for {}'.format(config.pattern, node))
if config.pattern == 'layer':
layer_node_list, _ = extract_layer(node, quant_modules)
elif config.pattern == 'block':
layer_node_list = extract_block(node.all_input_nodes, quant_modules)
else:
raise NotImplementedError
            # if an 'update' node's source is not used later in the block, remove the update node
            if any(n.target == 'update' for n in layer_node_list):
remove_nodes = []
for idx, n in enumerate(layer_node_list):
if n.target == 'update':
src = n.args[0]
remove = True
for _idx in range(idx + 1, len(layer_node_list)):
if src in _flatten_args(
layer_node_list[_idx].args):
remove = False
break
if remove:
remove_nodes.append(n)
layer_node_list = [n for n in layer_node_list if n not in remove_nodes]
missing_inputs = []
for _node in layer_node_list:
for arg in _flatten_args(_node.args):
if isinstance(arg, torch.fx.Node):
if arg not in layer_node_list and arg not in missing_inputs:
missing_inputs.append(arg)
layer_node_list.extend(missing_inputs)
# replace getitem nodes into its source node
layer_node_list = [n if n not in g2node else g2node[n] for n in layer_node_list]
for _node in layer_node_list:
src = [arg for arg in _flatten_args(_node.args) if arg in g2node]
for arg in src:
_node.args = _fix_succ_recursivly(_node.args, arg, g2node[arg])
layer_node_list = sorted(layer_node_list, key=lambda x: topology_order_by_node[x])
layer_node_list = find_cur_node(layer_node_list)
            if not layer_has_weights(layer_node_list, quant_modules):
                continue
logger.info('the node list is below!')
logger.info(layer_node_list)
fp32_module = fp32_modules[qnode2fpnode_dict[layer_node_list[-1]]]
fp32_all_inps = []
quant_all_inps = []
fp32_final_oups = None
out_is_cached = False
for _node in layer_node_list:
if all([arg in layer_node_list for arg in _flatten_args(_node.args) if isinstance(arg, torch.fx.Node)]):
continue
else:
fp32_inp_module = fp32_modules[qnode2fpnode_dict[_node]]
quant_module = quant_modules[_node]
# fp32 inps: [out_b1, out_b2, ...]
_, fp32_inps = save_inp_oup_data(fp32_model, None, fp32_inp_module, cali_data,
store_inp=False, store_oup=(config.prob < 1.0), keep_gpu=config.keep_gpu)
_, fp32_oups = save_inp_oup_data(fp32_model, None, fp32_module, cali_data,
store_inp=False, store_oup=(not out_is_cached), keep_gpu=config.keep_gpu)
_, quant_inps = save_inp_oup_data(quant_model, None, quant_module, cali_data,
store_inp=False, store_oup=True, keep_gpu=config.keep_gpu)
fp32_all_inps.append(fp32_inps)
quant_all_inps.append(quant_inps)
if not out_is_cached:
fp32_final_oups = fp32_oups
out_is_cached = True
cached_inps = (quant_all_inps, fp32_all_inps) if config.prob < 1.0 else quant_all_inps
cached_oups = fp32_final_oups
quant_modules_by_name = dict()
for node in layer_node_list:
if node.op == 'call_module':
quant_modules_by_name[node.target] = quant_modules[node]
subgraph = extract_subgraph(quant_modules_by_name, layer_node_list,
layer_node_list[-1], g2node)
logger.info(subgraph.code)
subgraph_reconstruction(subgraph, cached_inps, cached_oups, config)
for x in layer_node_list:
checked_nodes[x] = True
disable_all(quant_model)
for node in checked_nodes:
if node.op == 'call_module':
enable_quantization(quant_modules[node])
logger.info(f'set the node {node.target} in quant')
return quant_model