Dassl.pytorch/dassl/modeling/ops/dsbn.py (25 lines of code) (raw):

import torch.nn as nn


class _DSBN(nn.Module):
    """Domain Specific Batch Normalization.

    Holds one independent batch-norm layer per domain and routes the input
    through the layer of the currently selected domain. Domain ``0`` is
    selected on construction; use :meth:`select_bn` to switch.

    Args:
        num_features (int): number of features.
        n_domain (int): number of domains.
        bn_type (str): type of bn. Choices are ['1d', '2d'].

    Raises:
        ValueError: if ``bn_type`` is not one of the supported choices.
    """

    def __init__(self, num_features, n_domain, bn_type):
        super().__init__()
        if bn_type == "1d":
            BN = nn.BatchNorm1d
        elif bn_type == "2d":
            BN = nn.BatchNorm2d
        else:
            raise ValueError(
                f"Unsupported bn_type {bn_type!r}; choices are ['1d', '2d']"
            )

        # One BN layer per domain. The attribute name ``bn`` is part of the
        # state_dict key layout, so it must not be renamed.
        self.bn = nn.ModuleList(BN(num_features) for _ in range(n_domain))

        self.valid_domain_idxs = list(range(n_domain))
        self.n_domain = n_domain
        self.domain_idx = 0

    def select_bn(self, domain_idx=0):
        """Select which domain's batch-norm layer ``forward`` will use.

        Args:
            domain_idx (int): index of the domain; must be contained in
                ``self.valid_domain_idxs``. Defaults to 0.

        Raises:
            AssertionError: if ``domain_idx`` is not a valid domain index.
                The currently selected domain is left unchanged.
        """
        # Explicit check instead of an ``assert`` statement so the validation
        # is not stripped under ``python -O``. AssertionError is kept as the
        # exception type for backward compatibility with existing callers.
        if domain_idx not in self.valid_domain_idxs:
            raise AssertionError(
                f"domain_idx must be one of {self.valid_domain_idxs}, "
                f"got {domain_idx!r}"
            )
        self.domain_idx = domain_idx

    def forward(self, x):
        """Apply the batch-norm layer of the currently selected domain to ``x``."""
        return self.bn[self.domain_idx](x)


class DSBN1d(_DSBN):
    """Domain-specific batch norm built from ``nn.BatchNorm1d`` layers."""

    def __init__(self, num_features, n_domain):
        super().__init__(num_features, n_domain, "1d")


class DSBN2d(_DSBN):
    """Domain-specific batch norm built from ``nn.BatchNorm2d`` layers."""

    def __init__(self, num_features, n_domain):
        super().__init__(num_features, n_domain, "2d")