Source code for mqbench.nn.intrinsic.modules.fused

from torch.nn.intrinsic import _FusedModule
from torch.nn import Linear, BatchNorm1d, BatchNorm2d, ReLU, ConvTranspose2d, Conv2d
from mqbench.nn.modules import FrozenBatchNorm2d

class LinearBn1d(_FusedModule):
    r"""This is a sequential container which calls the Linear and Batch Norm 1d modules.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, linear, bn):
        assert type(linear) == Linear and type(bn) == BatchNorm1d, \
            'Incorrect types for input modules {} {}'.format(
                type(linear), type(bn))
        super().__init__(linear, bn)
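Since _FusedModule subclasses torch.nn.Sequential, LinearBn1d simply runs the linear layer and then the batch norm until quantization swaps it for the corresponding fused module. A minimal usage sketch (illustrative only, with arbitrary layer sizes; not part of the module source):

>>> import torch
>>> linear = Linear(16, 32)
>>> bn = BatchNorm1d(32)
>>> fused = LinearBn1d(linear, bn)
>>> fused(torch.randn(4, 16)).shape
torch.Size([4, 32])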
class ConvTransposeBn2d(_FusedModule):
    r"""Sequential container for a ConvTranspose2d and a BatchNorm2d module."""
    def __init__(self, deconv, bn):
        assert type(deconv) == ConvTranspose2d and type(bn) == BatchNorm2d, \
            'Incorrect types for input modules {} {}'.format(
                type(deconv), type(bn))
        super().__init__(deconv, bn)


class ConvTransposeBnReLU2d(_FusedModule):
    r"""Sequential container for a ConvTranspose2d, a BatchNorm2d and a ReLU module."""
    def __init__(self, deconv, bn, relu):
        assert type(deconv) == ConvTranspose2d and type(bn) == BatchNorm2d and type(relu) == ReLU, \
            'Incorrect types for input modules {} {} {}'.format(
                type(deconv), type(bn), type(relu))
        super().__init__(deconv, bn, relu)


class ConvTransposeReLU2d(_FusedModule):
    r"""Sequential container for a ConvTranspose2d and a ReLU module."""
    def __init__(self, deconv, relu):
        assert type(deconv) == ConvTranspose2d and type(relu) == ReLU, \
            'Incorrect types for input modules {} {}'.format(
                type(deconv), type(relu))
        super().__init__(deconv, relu)
class ConvBn2d(_FusedModule):
    r"""Sequential container for a Conv2d and a BatchNorm2d module."""
    def __init__(self, conv, bn):
        assert type(conv) == Conv2d and type(bn) == BatchNorm2d, \
            'Incorrect types for input modules {} {}'.format(
                type(conv), type(bn))
        super().__init__(conv, bn)


class ConvBnReLU2d(_FusedModule):
    r"""Sequential container for a Conv2d, a BatchNorm2d and a ReLU module."""
    def __init__(self, conv, bn, relu):
        assert type(conv) == Conv2d and type(bn) == BatchNorm2d and type(relu) == ReLU, \
            'Incorrect types for input modules {} {} {}'.format(
                type(conv), type(bn), type(relu))
        super().__init__(conv, bn, relu)


class ConvReLU2d(_FusedModule):
    r"""Sequential container for a Conv2d and a ReLU module."""
    def __init__(self, conv, relu):
        assert type(conv) == Conv2d and type(relu) == ReLU, \
            'Incorrect types for input modules {} {}'.format(
                type(conv), type(relu))
        super().__init__(conv, relu)
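The Conv2d containers follow the same pattern; note that the asserts check exact types, so subclasses of Conv2d, BatchNorm2d or ReLU are rejected. A minimal sketch for ConvBnReLU2d (illustrative only, with arbitrary hyper-parameters; not part of the module source):

>>> import torch
>>> conv = Conv2d(3, 8, kernel_size=3, padding=1)
>>> bn = BatchNorm2d(8)
>>> relu = ReLU()
>>> fused = ConvBnReLU2d(conv, bn, relu)
>>> fused(torch.randn(1, 3, 16, 16)).shape
torch.Size([1, 8, 16, 16])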
class ConvFreezebn2d(_FusedModule):
    r"""Sequential container for a Conv2d and a FrozenBatchNorm2d module."""
    def __init__(self, conv, bn):
        assert type(conv) == Conv2d and type(bn) == FrozenBatchNorm2d, \
            'Incorrect types for input modules {} {}'.format(
                type(conv), type(bn))
        super().__init__(conv, bn)


class ConvFreezebnReLU2d(_FusedModule):
    r"""Sequential container for a Conv2d, a FrozenBatchNorm2d and a ReLU module."""
    def __init__(self, conv, bn, relu):
        assert type(conv) == Conv2d and type(bn) == FrozenBatchNorm2d and type(relu) == ReLU, \
            'Incorrect types for input modules {} {} {}'.format(
                type(conv), type(bn), type(relu))
        super().__init__(conv, bn, relu)


class ConvTransposeFreezebn2d(_FusedModule):
    r"""Sequential container for a ConvTranspose2d and a FrozenBatchNorm2d module."""
    def __init__(self, deconv, bn):
        assert type(deconv) == ConvTranspose2d and type(bn) == FrozenBatchNorm2d, \
            'Incorrect types for input modules {} {}'.format(
                type(deconv), type(bn))
        super().__init__(deconv, bn)


class ConvTransposeFreezebnReLU2d(_FusedModule):
    r"""Sequential container for a ConvTranspose2d, a FrozenBatchNorm2d and a ReLU module."""
    def __init__(self, deconv, bn, relu):
        assert type(deconv) == ConvTranspose2d and type(bn) == FrozenBatchNorm2d and type(relu) == ReLU, \
            'Incorrect types for input modules {} {} {}'.format(
                type(deconv), type(bn), type(relu))
        super().__init__(deconv, bn, relu)
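The frozen-BN containers pair a (transposed) convolution with mqbench's FrozenBatchNorm2d instead of a regular BatchNorm2d. A minimal construction sketch, assuming FrozenBatchNorm2d is built from the channel count as in torchvision (check the actual constructor in mqbench.nn.modules):

>>> conv = Conv2d(3, 8, kernel_size=3, padding=1)
>>> frozen_bn = FrozenBatchNorm2d(8)  # assumed constructor: number of channels
>>> fused = ConvFreezebn2d(conv, frozen_bn)
>>> isinstance(fused, _FusedModule)
True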