Dassl.pytorch/dassl/engine/dg/daeldg.py (130 lines of code) (raw):

import torch import torch.nn as nn from dassl.data import DataManager from dassl.optim import build_optimizer, build_lr_scheduler from dassl.utils import count_num_param from dassl.engine import TRAINER_REGISTRY, TrainerX from dassl.metrics import compute_accuracy from dassl.engine.trainer import SimpleNet from dassl.data.transforms import build_transform from dassl.modeling.ops.utils import create_onehot class Experts(nn.Module): def __init__(self, n_source, fdim, num_classes): super().__init__() self.linears = nn.ModuleList( [nn.Linear(fdim, num_classes) for _ in range(n_source)] ) self.softmax = nn.Softmax(dim=1) def forward(self, i, x): x = self.linears[i](x) x = self.softmax(x) return x @TRAINER_REGISTRY.register() class DAELDG(TrainerX): """Domain Adaptive Ensemble Learning. DG version: only use labeled source data. https://arxiv.org/abs/2003.07325. """ def __init__(self, cfg): super().__init__(cfg) n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE if n_domain <= 0: n_domain = self.num_source_domains self.split_batch = batch_size // n_domain self.n_domain = n_domain self.conf_thre = cfg.TRAINER.DAELDG.CONF_THRE def check_cfg(self, cfg): assert cfg.DATALOADER.TRAIN_X.SAMPLER == "RandomDomainSampler" assert len(cfg.TRAINER.DAELDG.STRONG_TRANSFORMS) > 0 def build_data_loader(self): cfg = self.cfg tfm_train = build_transform(cfg, is_train=True) custom_tfm_train = [tfm_train] choices = cfg.TRAINER.DAELDG.STRONG_TRANSFORMS tfm_train_strong = build_transform(cfg, is_train=True, choices=choices) custom_tfm_train += [tfm_train_strong] dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train) self.train_loader_x = dm.train_loader_x self.train_loader_u = dm.train_loader_u self.val_loader = dm.val_loader self.test_loader = dm.test_loader self.num_classes = dm.num_classes self.num_source_domains = dm.num_source_domains self.lab2cname = dm.lab2cname def build_model(self): cfg = self.cfg print("Building F") self.F = SimpleNet(cfg, cfg.MODEL, 0) self.F.to(self.device) print("# params: {:,}".format(count_num_param(self.F))) self.optim_F = build_optimizer(self.F, cfg.OPTIM) self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM) self.register_model("F", self.F, self.optim_F, self.sched_F) fdim = self.F.fdim print("Building E") self.E = Experts(self.num_source_domains, fdim, self.num_classes) self.E.to(self.device) print("# params: {:,}".format(count_num_param(self.E))) self.optim_E = build_optimizer(self.E, cfg.OPTIM) self.sched_E = build_lr_scheduler(self.optim_E, cfg.OPTIM) self.register_model("E", self.E, self.optim_E, self.sched_E) def forward_backward(self, batch): parsed_data = self.parse_batch_train(batch) input, input2, label, domain = parsed_data input = torch.split(input, self.split_batch, 0) input2 = torch.split(input2, self.split_batch, 0) label = torch.split(label, self.split_batch, 0) domain = torch.split(domain, self.split_batch, 0) domain = [d[0].item() for d in domain] loss_x = 0 loss_cr = 0 acc = 0 feat = [self.F(x) for x in input] feat2 = [self.F(x) for x in input2] for feat_i, feat2_i, label_i, i in zip(feat, feat2, label, domain): cr_s = [j for j in domain if j != i] # Learning expert pred_i = self.E(i, feat_i) loss_x += (-label_i * torch.log(pred_i + 1e-5)).sum(1).mean() expert_label_i = pred_i.detach() acc += compute_accuracy(pred_i.detach(), label_i.max(1)[1])[0].item() # Consistency regularization cr_pred = [] for j in cr_s: pred_j = self.E(j, feat2_i) pred_j = pred_j.unsqueeze(1) cr_pred.append(pred_j) cr_pred = torch.cat(cr_pred, 1) cr_pred = cr_pred.mean(1) loss_cr += ((cr_pred - expert_label_i)**2).sum(1).mean() loss_x /= self.n_domain loss_cr /= self.n_domain acc /= self.n_domain loss = 0 loss += loss_x loss += loss_cr self.model_backward_and_update(loss) loss_summary = { "loss_x": loss_x.item(), "acc": acc, "loss_cr": loss_cr.item() } if (self.batch_idx + 1) == self.num_batches: self.update_lr() return loss_summary def parse_batch_train(self, batch): input = batch["img"] input2 = batch["img2"] label = batch["label"] domain = batch["domain"] label = create_onehot(label, self.num_classes) input = input.to(self.device) input2 = input2.to(self.device) label = label.to(self.device) return input, input2, label, domain def model_inference(self, input): f = self.F(input) p = [] for k in range(self.num_source_domains): p_k = self.E(k, f) p_k = p_k.unsqueeze(1) p.append(p_k) p = torch.cat(p, 1) p = p.mean(1) return p