Dassl.pytorch/dassl/modeling/ops/conv.py (78 lines of code) (raw):

import torch.nn as nn

from .attention import Attention

__all__ = ["Conv2dDynamic"]


class Conv2dDynamic(nn.Module):
    """Conv2dDynamic from `"Dynamic Domain Generalization"
    <https://github.com/MetaVisionLab/DDG>`_.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        padding: int,
        bias: bool = True,
        squeeze: int = None,
        attention_in_channels: int = None
    ) -> None:
        super(Conv2dDynamic, self).__init__()

        if kernel_size // 2 != padding:
            # Only when this condition holds do different kernel sizes
            # produce feature maps of the same spatial size.
            # Let I, K, S, P, O: O = (I + 2P - K) // S + 1. If P = K // 2,
            # then O = (I - K % 2) // S + 1, so two kernel sizes with the
            # same parity yield the same output size once P is adjusted.
            raise ValueError("`padding` must be equal to `kernel_size // 2`.")
        if kernel_size % 2 == 0:
            raise ValueError(
                "`kernel_size` must be odd because the kernel templates "
                "used here are odd (e.g. kernel_size=1)."
            )

        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias
        )
        # Kernel templates: a grouped n x n conv, a 1 x 1 conv, and two
        # asymmetric (n x 1 and 1 x n) convs, mixed by attention weights.
        self.kernel_templates = nn.ModuleDict()
        self.kernel_templates["conv_nn"] = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=min(in_channels, out_channels),
            bias=bias
        )
        self.kernel_templates["conv_11"] = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=stride,
            padding=0,
            bias=bias
        )
        self.kernel_templates["conv_n1"] = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(kernel_size, 1),
            stride=stride,
            padding=(padding, 0),
            bias=bias
        )
        self.kernel_templates["conv_1n"] = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(1, kernel_size),
            stride=stride,
            padding=(0, padding),
            bias=bias
        )
        self.attention = Attention(
            attention_in_channels if attention_in_channels else in_channels,
            4,
            squeeze,
            bias=bias
        )

    def forward(self, x, attention_x=None):
        attention_x = x if attention_x is None else attention_x
        # Attention weights of shape (N, 4): one mixing weight per template.
        y = self.attention(attention_x)

        out = self.conv(x)
        # Add each template's response, scaled by its per-sample weight.
        for i, template in enumerate(self.kernel_templates):
            out += self.kernel_templates[template](x) * y[:, i].view(-1, 1, 1, 1)

        return out
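
A minimal usage sketch (not part of the file above), assuming that `Attention` returns one weight per kernel template, i.e. a tensor of shape (N, 4), and that `squeeze` acts as the reduction ratio of its bottleneck. The shapes and hyperparameters below are illustrative only.

import torch

from dassl.modeling.ops.conv import Conv2dDynamic

# padding must equal kernel_size // 2, as enforced in __init__.
conv = Conv2dDynamic(
    in_channels=64,
    out_channels=128,
    kernel_size=3,
    stride=1,
    padding=1,
    squeeze=16,  # assumed reduction ratio for the attention bottleneck
)

x = torch.randn(8, 64, 32, 32)
out = conv(x)
print(out.shape)  # expected: torch.Size([8, 128, 32, 32])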