Dassl.pytorch/dassl/modeling/ops/mixstyle.py (78 lines of code) (raw):

import random
from contextlib import contextmanager

import torch
import torch.nn as nn


def deactivate_mixstyle(m):
    if type(m) == MixStyle:
        m.set_activation_status(False)


def activate_mixstyle(m):
    if type(m) == MixStyle:
        m.set_activation_status(True)


def random_mixstyle(m):
    if type(m) == MixStyle:
        m.update_mix_method("random")


def crossdomain_mixstyle(m):
    if type(m) == MixStyle:
        m.update_mix_method("crossdomain")


@contextmanager
def run_without_mixstyle(model):
    # Assume MixStyle was initially activated
    try:
        model.apply(deactivate_mixstyle)
        yield
    finally:
        model.apply(activate_mixstyle)


@contextmanager
def run_with_mixstyle(model, mix=None):
    # Assume MixStyle was initially deactivated
    if mix == "random":
        model.apply(random_mixstyle)

    elif mix == "crossdomain":
        model.apply(crossdomain_mixstyle)

    try:
        model.apply(activate_mixstyle)
        yield
    finally:
        model.apply(deactivate_mixstyle)


class MixStyle(nn.Module):
    """MixStyle.

    Reference:
      Zhou et al. Domain Generalization with MixStyle. ICLR 2021.
    """

    def __init__(self, p=0.5, alpha=0.1, eps=1e-6, mix="random"):
        """
        Args:
            p (float): probability of using MixStyle.
            alpha (float): parameter of the Beta distribution.
            eps (float): scaling parameter to avoid numerical issues.
            mix (str): how to mix.
        """
        super().__init__()
        self.p = p
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.eps = eps
        self.alpha = alpha
        self.mix = mix
        self._activated = True

    def __repr__(self):
        return (
            f"MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps}, mix={self.mix})"
        )

    def set_activation_status(self, status=True):
        self._activated = status

    def update_mix_method(self, mix="random"):
        self.mix = mix

    def forward(self, x):
        if not self.training or not self._activated:
            return x

        if random.random() > self.p:
            return x

        B = x.size(0)

        # Per-instance channel statistics over the spatial dimensions
        mu = x.mean(dim=[2, 3], keepdim=True)
        var = x.var(dim=[2, 3], keepdim=True)
        sig = (var + self.eps).sqrt()
        mu, sig = mu.detach(), sig.detach()  # stop gradients through the statistics
        x_normed = (x - mu) / sig

        lmda = self.beta.sample((B, 1, 1, 1))
        lmda = lmda.to(x.device)

        if self.mix == "random":
            # random shuffle
            perm = torch.randperm(B)

        elif self.mix == "crossdomain":
            # split into two halves and swap the order
            perm = torch.arange(B - 1, -1, -1)  # inverse index
            perm_b, perm_a = perm.chunk(2)
            perm_b = perm_b[torch.randperm(perm_b.shape[0])]
            perm_a = perm_a[torch.randperm(perm_a.shape[0])]
            perm = torch.cat([perm_b, perm_a], 0)

        else:
            raise NotImplementedError

        # Mix each instance's statistics with those of its permuted partner
        mu2, sig2 = mu[perm], sig[perm]
        mu_mix = mu*lmda + mu2 * (1-lmda)
        sig_mix = sig*lmda + sig2 * (1-lmda)

        return x_normed*sig_mix + mu_mix
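

# A minimal usage sketch, assuming MixStyle is inserted after an early
# convolutional stage; the toy `model` below is a hypothetical stand-in, not
# Dassl's model builder. The context managers toggle every MixStyle module
# found in the module tree via Module.apply.
if __name__ == "__main__":
    model = nn.Sequential(
        nn.Conv2d(3, 64, 3, padding=1),
        MixStyle(p=0.5, alpha=0.1, mix="random"),
        nn.Conv2d(64, 128, 3, padding=1),
    )
    model.train()  # MixStyle is a no-op in eval mode
    x = torch.randn(8, 3, 32, 32)
    y = model(x)  # feature statistics are mixed with probability p

    with run_without_mixstyle(model):
        y_plain = model(x)  # MixStyle deactivated inside this block

    # run_with_mixstyle assumes MixStyle starts deactivated, so switch it off
    # first; "crossdomain" pairs the first batch half with the second, which
    # presumes the two halves come from different domains.
    model.apply(deactivate_mixstyle)
    with run_with_mixstyle(model, mix="crossdomain"):
        y_cross = model(x)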