Dassl.pytorch/dassl/engine/dg/domain_mix.py (62 lines of code) (raw):

import torch
from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy

__all__ = ["DomainMix"]


@TRAINER_REGISTRY.register()
class DomainMix(TrainerX):
    """DomainMix.

    Dynamic Domain Generalization.

    https://github.com/MetaVisionLab/DDG

    Each training batch is blended with a permuted copy of itself using a
    single mixing coefficient ``lam``; the loss is the ``lam``-weighted sum of
    the cross-entropy against the original labels and against the labels of
    the permuted samples.
    """

    def __init__(self, cfg):
        """Read the DomainMix options from ``cfg.TRAINER.DOMAINMIX``.

        Args:
            cfg: config object exposing ``TRAINER.DOMAINMIX.TYPE``,
                ``TRAINER.DOMAINMIX.ALPHA`` and ``TRAINER.DOMAINMIX.BETA``.
        """
        super().__init__(cfg)
        self.mix_type = cfg.TRAINER.DOMAINMIX.TYPE
        self.alpha = cfg.TRAINER.DOMAINMIX.ALPHA
        self.beta = cfg.TRAINER.DOMAINMIX.BETA
        # ``domain_mix`` treats ``alpha <= 0`` as "mixing disabled" (lam = 1).
        # ``torch.distributions.Beta`` rejects non-positive concentrations
        # when argument validation is on, so building it unconditionally
        # would raise here and make that fallback unreachable.
        self.dist_beta = (
            torch.distributions.Beta(self.alpha, self.beta)
            if self.alpha > 0 else None
        )

    def forward_backward(self, batch):
        """Run one optimisation step on a mixed batch.

        Args:
            batch: mapping with the keys read by ``parse_batch_train``.

        Returns:
            dict with ``"loss"`` (the mixed loss as a Python number) and
            ``"acc"`` (first element of ``compute_accuracy`` evaluated against
            the un-permuted labels only).
        """
        images, label_a, label_b, lam = self.parse_batch_train(batch)
        output = self.model(images)

        loss_a = F.cross_entropy(output, label_a)
        loss_b = F.cross_entropy(output, label_b)
        loss = lam * loss_a + (1 - lam) * loss_b
        # Not defined in this class; presumably inherited from TrainerX.
        self.model_backward_and_update(loss)

        loss_summary = {
            "loss": loss.item(),
            "acc": compute_accuracy(output, label_a)[0].item(),
        }

        # Step the learning-rate schedule once, on the last batch of the epoch.
        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch):
        """Move a batch to ``self.device`` and apply ``domain_mix`` to it.

        Args:
            batch: mapping with ``"img"``, ``"label"`` and ``"domain"``
                tensors.

        Returns:
            Tuple ``(images, target_a, target_b, lam)`` as produced by
            ``domain_mix``.
        """
        images = batch["img"].to(self.device)
        target = batch["label"].to(self.device)
        domain = batch["domain"].to(self.device)
        images, target_a, target_b, lam = self.domain_mix(
            images, target, domain
        )
        return images, target_a, target_b, lam

    def domain_mix(self, x, target, domain):
        """Blend every sample in ``x`` with a partner sample from the batch.

        With ``mix_type == "random"`` the partners are a random permutation
        of the batch. With ``mix_type == "crossdomain"`` (and more than one
        distinct value in ``domain``) each sample's partner is drawn from the
        samples whose domain label differs from its own.

        Args:
            x: input batch; indexed along dim 0 and must have at least two
                dimensions (``x[perm, :]``).
            target: labels, indexed along dim 0 by the same permutation.
            domain: per-sample domain labels, compared element-wise
                (assumed 1-D with one entry per sample -- confirm with the
                data loader).

        Returns:
            Tuple ``(mixed_x, target_a, target_b, lam)`` where
            ``mixed_x = lam * x + (1 - lam) * x[perm]``, ``target_a`` is
            ``target``, ``target_b`` is ``target[perm]`` and ``lam`` is the
            mixing coefficient on ``x.device``.

        Raises:
            NotImplementedError: if ``mix_type`` is neither ``"random"`` nor
                ``"crossdomain"``.
        """
        if self.alpha > 0 and self.dist_beta is not None:
            lam = self.dist_beta.rsample((1, ))
        else:
            # Mixing disabled: the output equals the input batch.
            lam = torch.tensor(1)
        lam = lam.to(x.device)

        # random shuffle
        perm = torch.randperm(x.size(0), dtype=torch.int64, device=x.device)
        if self.mix_type == "crossdomain":
            domain_list = torch.unique(domain)
            # With a single domain there is nothing to cross with, so the
            # plain random permutation above is kept.
            if len(domain_list) > 1:
                for idx in domain_list:
                    in_domain = domain == idx
                    cnt_a = int(in_domain.sum())
                    idx_b = (~in_domain).nonzero().squeeze(-1)
                    cnt_b = idx_b.shape[0]
                    # Uniformly pick ``cnt_a`` partners among the ``cnt_b``
                    # other-domain samples; sample with replacement only when
                    # there are not enough distinct candidates.
                    perm_b = torch.ones(cnt_b).multinomial(
                        num_samples=cnt_a, replacement=cnt_a > cnt_b
                    )
                    perm[in_domain] = idx_b[perm_b]
        elif self.mix_type != "random":
            raise NotImplementedError(
                f"Chooses {'random', 'crossdomain'}, but got {self.mix_type}."
            )
        mixed_x = lam * x + (1 - lam) * x[perm, :]
        target_a, target_b = target, target[perm]
        return mixed_x, target_a, target_b, lam