Dassl.pytorch/dassl/modeling/ops/efdmix.py (75 lines of code) (raw):

import random from contextlib import contextmanager import torch import torch.nn as nn def deactivate_efdmix(m): if type(m) == EFDMix: m.set_activation_status(False) def activate_efdmix(m): if type(m) == EFDMix: m.set_activation_status(True) def random_efdmix(m): if type(m) == EFDMix: m.update_mix_method("random") def crossdomain_efdmix(m): if type(m) == EFDMix: m.update_mix_method("crossdomain") @contextmanager def run_without_efdmix(model): # Assume MixStyle was initially activated try: model.apply(deactivate_efdmix) yield finally: model.apply(activate_efdmix) @contextmanager def run_with_efdmix(model, mix=None): # Assume MixStyle was initially deactivated if mix == "random": model.apply(random_efdmix) elif mix == "crossdomain": model.apply(crossdomain_efdmix) try: model.apply(activate_efdmix) yield finally: model.apply(deactivate_efdmix) class EFDMix(nn.Module): """EFDMix. Reference: Zhang et al. Exact Feature Distribution Matching for Arbitrary Style Transfer and Domain Generalization. CVPR 2022. """ def __init__(self, p=0.5, alpha=0.1, eps=1e-6, mix="random"): """ Args: p (float): probability of using MixStyle. alpha (float): parameter of the Beta distribution. eps (float): scaling parameter to avoid numerical issues. mix (str): how to mix. """ super().__init__() self.p = p self.beta = torch.distributions.Beta(alpha, alpha) self.eps = eps self.alpha = alpha self.mix = mix self._activated = True def __repr__(self): return ( f"MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps}, mix={self.mix})" ) def set_activation_status(self, status=True): self._activated = status def update_mix_method(self, mix="random"): self.mix = mix def forward(self, x): if not self.training or not self._activated: return x if random.random() > self.p: return x B, C, W, H = x.size(0), x.size(1), x.size(2), x.size(3) x_view = x.view(B, C, -1) value_x, index_x = torch.sort(x_view) # sort inputs lmda = self.beta.sample((B, 1, 1)) lmda = lmda.to(x.device) if self.mix == "random": # random shuffle perm = torch.randperm(B) elif self.mix == "crossdomain": # split into two halves and swap the order perm = torch.arange(B - 1, -1, -1) # inverse index perm_b, perm_a = perm.chunk(2) perm_b = perm_b[torch.randperm(perm_b.shape[0])] perm_a = perm_a[torch.randperm(perm_a.shape[0])] perm = torch.cat([perm_b, perm_a], 0) else: raise NotImplementedError inverse_index = index_x.argsort(-1) x_view_copy = value_x[perm].gather(-1, inverse_index) new_x = x_view + (x_view_copy - x_view.detach()) * (1-lmda) return new_x.view(B, C, W, H)