import torch
import torch.nn as nn
from torch.quantization.fx.fusion_patterns import ConvBNReLUFusion, ModuleReLUFusion
from torch.quantization.fx.quantization_types import QuantizerCls
from torch.fx.graph import Node
import mqbench.nn as qnn
import mqbench.nn.intrinsic as qnni
import mqbench.nn.intrinsic.qat as qnniqat
from mqbench.utils.fusion import fuse_deconv_bn_eval
from mqbench.nn.modules import FrozenBatchNorm2d
class ConvExtendBnReLUFusion(ConvBNReLUFusion):
    def __init__(self, quantizer: QuantizerCls, node: Node):
        # Intentionally skip ConvBNReLUFusion.__init__ and call the next
        # class in the MRO, so the pattern walk below can be redone with
        # FrozenBatchNorm2d support.
        super(ConvBNReLUFusion, self).__init__(quantizer, node)
        self.relu_node = None
        self.bn_node = None
        if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
                (node.op == 'call_module' and type(quantizer.modules[node.target]) == torch.nn.ReLU):
            self.relu_node = node
            assert isinstance(node.args[0], Node)
            node = node.args[0]
        assert node.op == 'call_module'
        if type(quantizer.modules[node.target]) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, FrozenBatchNorm2d]:
            self.bn_node = node
            self.bn = quantizer.modules[self.bn_node.target]
            assert isinstance(node.args[0], Node)
            node = node.args[0]
        assert node.op == 'call_module'
        self.conv_node = node
        self.conv = quantizer.modules[self.conv_node.target]
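
# Hedged illustration (not part of the library): the module/graph shape this
# handler walks. The module names below are hypothetical.
#
#   class _Block(nn.Module):
#       def __init__(self):
#           super().__init__()
#           self.conv = nn.Conv2d(3, 8, 3)
#           self.bn = FrozenBatchNorm2d(8)
#           self.relu = nn.ReLU()
#
#       def forward(self, x):
#           # Traced as conv -> bn -> relu; the handler starts at the relu
#           # node and records bn_node/conv_node while walking backwards.
#           return self.relu(self.bn(self.conv(x)))
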
def fuse_linear_bn(linear, bn):
    r"""Given the linear and bn modules, fuse them and return the fused module.

    Args:
        linear: Module instance of type Linear
        bn: BatchNorm1d instance that needs to be fused with the linear layer

    Examples::

        >>> m1 = nn.Linear(10, 20)
        >>> b1 = nn.BatchNorm1d(20)
        >>> m2 = fuse_linear_bn(m1, b1)
    """
    assert linear.training == bn.training, \
        "Linear and BN both must be in the same mode (train or eval)."

    if linear.training:
        assert bn.affine, 'Only support fusing BatchNorm1d with affine set to True'
        assert bn.track_running_stats, 'Only support fusing BatchNorm1d with track_running_stats set to True'
        return qnn.intrinsic.LinearBn1d(linear, bn)
    else:
        return nn.utils.fusion.fuse_linear_bn_eval(linear, bn)
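
# Hedged usage sketch (shapes from the docstring example above): in eval mode
# the BN statistics are folded into the Linear weights and a plain nn.Linear
# comes back; in train mode a QAT wrapper is returned instead.
#
#   fused_eval = fuse_linear_bn(nn.Linear(10, 20).eval(), nn.BatchNorm1d(20).eval())
#   fused_train = fuse_linear_bn(nn.Linear(10, 20), nn.BatchNorm1d(20))
#   # fused_eval  -> nn.Linear with folded parameters
#   # fused_train -> qnn.intrinsic.LinearBn1d
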
def fuse_deconv_bn(deconv, bn):
    assert deconv.training == bn.training, \
        'DeConv and BN must be in the same mode (train or eval).'

    if deconv.training:
        assert bn.num_features == deconv.out_channels, 'Output channel of ConvTranspose2d must match num_features of BatchNorm2d'
        assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
        assert bn.track_running_stats, 'Only support fusing BatchNorm2d with track_running_stats set to True'
        return qnni.ConvTransposeBn2d(deconv, bn)
    else:
        return fuse_deconv_bn_eval(deconv, bn)
def fuse_deconv_bn_relu(deconv, bn, relu):
    assert deconv.training == bn.training == relu.training, \
        'DeConv, BN and ReLU must all be in the same mode (train or eval).'

    if deconv.training:
        assert bn.num_features == deconv.out_channels, 'Output channel of ConvTranspose2d must match num_features of BatchNorm2d'
        assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
        assert bn.track_running_stats, 'Only support fusing BatchNorm2d with track_running_stats set to True'
        return qnni.ConvTransposeBnReLU2d(deconv, bn, relu)
    else:
        return qnni.ConvTransposeReLU2d(fuse_deconv_bn_eval(deconv, bn), relu)
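
# Hedged usage sketch for the deconv fusers (shapes are illustrative):
#
#   deconv, bn, relu = nn.ConvTranspose2d(8, 16, 2), nn.BatchNorm2d(16), nn.ReLU()
#   fused = fuse_deconv_bn_relu(deconv, bn, relu)
#   # train mode -> qnni.ConvTransposeBnReLU2d (BN kept for QAT);
#   # eval mode  -> BN folded into the deconv, wrapped in qnni.ConvTransposeReLU2d.
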
def fuse_conv_freezebn(conv, bn):
    assert bn.training is False, 'FrozenBatchNorm2d must be in eval mode.'

    if conv.training:
        assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d'
        assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
        assert bn.track_running_stats, 'Only support fusing BatchNorm2d with track_running_stats set to True'
        return qnni.ConvFreezebn2d(conv, bn)
    else:
        return nn.utils.fuse_conv_bn_eval(conv, bn)
def fuse_conv_freezebn_relu(conv, bn, relu):
    assert conv.training == relu.training and bn.training is False, \
        'Conv and ReLU must be in the same mode (train or eval) and BN must be in eval mode.'

    if conv.training:
        assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm'
        assert bn.affine, 'Only support fusing BatchNorm with affine set to True'
        assert bn.track_running_stats, 'Only support fusing BatchNorm with track_running_stats set to True'
        return qnni.ConvFreezebnReLU2d(conv, bn, relu)
    else:
        fused_conv = nn.utils.fuse_conv_bn_eval(conv, bn)
        return nn.intrinsic.ConvReLU2d(fused_conv, relu)
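
# Hedged usage sketch for the frozen-BN fusers: the BN must already be in eval
# mode while the conv may still be training (the usual fine-tuning setup that
# FrozenBatchNorm2d is meant for). Shapes are illustrative.
#
#   conv = nn.Conv2d(3, 8, 3)               # still training
#   bn = FrozenBatchNorm2d(8).eval()        # statistics frozen
#   fused = fuse_conv_freezebn_relu(conv, bn, nn.ReLU())
#   # -> qnni.ConvFreezebnReLU2d while training; with conv in eval mode, the
#   #    BN is folded away and an nn.intrinsic.ConvReLU2d is returned instead.
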
def fuse_deconv_freezebn(deconv, bn):
    assert bn.training is False, 'FrozenBatchNorm2d must be in eval mode.'

    if deconv.training:
        assert bn.num_features == deconv.out_channels, 'Output channel of ConvTranspose2d must match num_features of BatchNorm2d'
        assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
        assert bn.track_running_stats, 'Only support fusing BatchNorm2d with track_running_stats set to True'
        return qnni.ConvTransposeFreezebn2d(deconv, bn)
    else:
        return fuse_deconv_bn_eval(deconv, bn)
def fuse_deconv_freezebn_relu(deconv, bn, relu):
    assert deconv.training == relu.training and bn.training is False, \
        'DeConv and ReLU must be in the same mode (train or eval) and BN must be in eval mode.'

    if deconv.training:
        assert bn.num_features == deconv.out_channels, 'Output channel of ConvTranspose2d must match num_features of BatchNorm2d'
        assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
        assert bn.track_running_stats, 'Only support fusing BatchNorm2d with track_running_stats set to True'
        return qnni.ConvTransposeFreezebnReLU2d(deconv, bn, relu)
    else:
        return qnni.ConvTransposeReLU2d(fuse_deconv_bn_eval(deconv, bn), relu)
fuse_custom_config_dict = {
    "additional_fuser_method_mapping": {
        (torch.nn.Linear, torch.nn.BatchNorm1d): fuse_linear_bn,
        (torch.nn.ConvTranspose2d, torch.nn.BatchNorm2d): fuse_deconv_bn,
        (torch.nn.ConvTranspose2d, torch.nn.BatchNorm2d, torch.nn.ReLU): fuse_deconv_bn_relu,
        (torch.nn.ConvTranspose2d, torch.nn.ReLU): qnni.ConvTransposeReLU2d,
        (nn.Conv2d, FrozenBatchNorm2d, nn.ReLU): fuse_conv_freezebn_relu,
        (nn.Conv2d, FrozenBatchNorm2d): fuse_conv_freezebn,
        (nn.ConvTranspose2d, FrozenBatchNorm2d, nn.ReLU): fuse_deconv_freezebn_relu,
        (nn.ConvTranspose2d, FrozenBatchNorm2d): fuse_deconv_freezebn,
    },
    "additional_fusion_pattern": {
        (torch.nn.BatchNorm1d, torch.nn.Linear):
            ConvBNReLUFusion,
        (torch.nn.BatchNorm2d, torch.nn.ConvTranspose2d):
            ConvBNReLUFusion,
        (torch.nn.ReLU, torch.nn.ConvTranspose2d):
            ConvBNReLUFusion,
        (torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.ConvTranspose2d)):
            ConvBNReLUFusion,
        (torch.nn.functional.relu, torch.nn.ConvTranspose2d):
            ConvBNReLUFusion,
        (torch.nn.functional.relu, (torch.nn.BatchNorm2d, torch.nn.ConvTranspose2d)):
            ConvBNReLUFusion,
        (torch.nn.ReLU, (FrozenBatchNorm2d, torch.nn.Conv2d)):
            ConvExtendBnReLUFusion,
        (FrozenBatchNorm2d, torch.nn.Conv2d):
            ConvExtendBnReLUFusion,
        (torch.nn.ReLU, (FrozenBatchNorm2d, torch.nn.ConvTranspose2d)):
            ConvExtendBnReLUFusion,
        (FrozenBatchNorm2d, torch.nn.ConvTranspose2d):
            ConvExtendBnReLUFusion,
    },
    "additional_qat_module_mappings": {
        nn.ConvTranspose2d: qnn.qat.ConvTranspose2d,
        qnni.LinearBn1d: qnniqat.LinearBn1d,
        qnni.ConvTransposeBn2d: qnniqat.ConvTransposeBn2d,
        qnni.ConvTransposeReLU2d: qnniqat.ConvTransposeReLU2d,
        qnni.ConvTransposeBnReLU2d: qnniqat.ConvTransposeBnReLU2d,
        qnni.ConvFreezebn2d: qnniqat.ConvFreezebn2d,
        qnni.ConvFreezebnReLU2d: qnniqat.ConvFreezebnReLU2d,
        qnni.ConvTransposeFreezebn2d: qnniqat.ConvTransposeFreezebn2d,
        qnni.ConvTransposeFreezebnReLU2d: qnniqat.ConvTransposeFreezebnReLU2d,
        nn.Embedding: qnn.qat.Embedding,
    },
}
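
# Hedged notes on the dict above: fusion patterns follow the FX convention of
# being written back-to-front, so (torch.nn.BatchNorm1d, torch.nn.Linear)
# matches Linear -> BatchNorm1d in the traced graph. The dict uses the
# fuse_custom_config_dict layout accepted by
# torch.quantization.quantize_fx.fuse_fx, e.g. (the model name is hypothetical):
#
#   from torch.quantization import quantize_fx
#   fused = quantize_fx.fuse_fx(model.eval(), fuse_custom_config_dict)
#
# In practice the entries are merged into the global defaults below, because
# the fuser does not forward additional_fuser_method_mapping.
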
def _sort_fusion_patterns(pats):
    """Move every ModuleReLUFusion entry to the end of the ordered pattern
    dict so that the longer ConvBNReLUFusion patterns are matched first."""
    keys = []
    for key in pats.keys():
        if pats[key] is ModuleReLUFusion:
            keys.append(key)
    for key in keys:
        pats.move_to_end(key)
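
# Sketch of the effect on a toy OrderedDict (the real DEFAULT_FUSION_PATTERNS
# maps pattern tuples to fusion handler classes):
#
#   from collections import OrderedDict
#   pats = OrderedDict([('bn_relu', ModuleReLUFusion),
#                       ('conv_bn_relu', ConvBNReLUFusion)])
#   _sort_fusion_patterns(pats)
#   list(pats)  # ['conv_bn_relu', 'bn_relu'] -- ReLU-only fusions now last
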

# Since additional_fuser_method_mapping will not be set (fuser.py:54 does not
# pass this dict through), update the default mappings in place instead.
from torch.quantization.fuser_method_mappings import DEFAULT_OP_LIST_TO_FUSER_METHOD
from torch.quantization.fx.pattern_utils import DEFAULT_FUSION_PATTERNS
from torch.quantization.quantization_mappings import DEFAULT_QAT_MODULE_MAPPINGS

DEFAULT_OP_LIST_TO_FUSER_METHOD.update(
    fuse_custom_config_dict['additional_fuser_method_mapping'])
DEFAULT_FUSION_PATTERNS.update(
    fuse_custom_config_dict['additional_fusion_pattern'])
# Give longer patterns priority: Conv + BN + ReLU should match ConvBnReLU
# before BNReLU, i.e. anything registered with ConvBNReLUFusion must come
# before ModuleReLUFusion.
_sort_fusion_patterns(DEFAULT_FUSION_PATTERNS)
DEFAULT_QAT_MODULE_MAPPINGS.update(
    fuse_custom_config_dict['additional_qat_module_mappings'])
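
# After this module is imported, the standard FX pipeline picks up the extra
# patterns through the merged defaults. Hedged sanity check:
#
#   assert DEFAULT_OP_LIST_TO_FUSER_METHOD[
#       (torch.nn.Linear, torch.nn.BatchNorm1d)] is fuse_linear_bn
#   assert DEFAULT_QAT_MODULE_MAPPINGS[nn.ConvTranspose2d] is qnn.qat.ConvTranspose2d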