Dassl.pytorch/dassl/engine/dg/ddaig.py

import torch
from torch.nn import functional as F

from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.modeling import build_network
from dassl.engine.trainer import SimpleNet


@TRAINER_REGISTRY.register()
class DDAIG(TrainerX):
    """Deep Domain-Adversarial Image Generation.

    https://arxiv.org/abs/2003.06054.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.lmda = cfg.TRAINER.DDAIG.LMDA  # perturbation weight passed to G
        self.clamp = cfg.TRAINER.DDAIG.CLAMP  # whether to clamp perturbed images
        self.clamp_min = cfg.TRAINER.DDAIG.CLAMP_MIN
        self.clamp_max = cfg.TRAINER.DDAIG.CLAMP_MAX
        self.warmup = cfg.TRAINER.DDAIG.WARMUP  # epochs before F trains on perturbed data
        self.alpha = cfg.TRAINER.DDAIG.ALPHA  # mixing weight for clean vs. perturbed loss

    def build_model(self):
        cfg = self.cfg

        # F: label classifier
        print("Building F")
        self.F = SimpleNet(cfg, cfg.MODEL, self.num_classes)
        self.F.to(self.device)
        print("# params: {:,}".format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model("F", self.F, self.optim_F, self.sched_F)

        # D: domain classifier over the source domains
        print("Building D")
        self.D = SimpleNet(cfg, cfg.MODEL, self.num_source_domains)
        self.D.to(self.device)
        print("# params: {:,}".format(count_num_param(self.D)))
        self.optim_D = build_optimizer(self.D, cfg.OPTIM)
        self.sched_D = build_lr_scheduler(self.optim_D, cfg.OPTIM)
        self.register_model("D", self.D, self.optim_D, self.sched_D)

        # G: network that generates the image perturbation
        print("Building G")
        self.G = build_network(cfg.TRAINER.DDAIG.G_ARCH, verbose=cfg.VERBOSE)
        self.G.to(self.device)
        print("# params: {:,}".format(count_num_param(self.G)))
        self.optim_G = build_optimizer(self.G, cfg.OPTIM)
        self.sched_G = build_lr_scheduler(self.optim_G, cfg.OPTIM)
        self.register_model("G", self.G, self.optim_G, self.sched_G)

    def forward_backward(self, batch):
        input, label, domain = self.parse_batch_train(batch)

        #############
        # Update G
        #############
        input_p = self.G(input, lmda=self.lmda)
        if self.clamp:
            input_p = torch.clamp(
                input_p, min=self.clamp_min, max=self.clamp_max
            )

        loss_g = 0
        # Minimize label loss
        loss_g += F.cross_entropy(self.F(input_p), label)
        # Maximize domain loss
        loss_g -= F.cross_entropy(self.D(input_p), domain)
        self.model_backward_and_update(loss_g, "G")

        # Perturb data with new G
        with torch.no_grad():
            input_p = self.G(input, lmda=self.lmda)
            if self.clamp:
                input_p = torch.clamp(
                    input_p, min=self.clamp_min, max=self.clamp_max
                )

        #############
        # Update F
        #############
        loss_f = F.cross_entropy(self.F(input), label)
        if (self.epoch + 1) > self.warmup:
            # After warmup, mix the losses on clean and perturbed images
            loss_fp = F.cross_entropy(self.F(input_p), label)
            loss_f = (1.0 - self.alpha) * loss_f + self.alpha * loss_fp
        self.model_backward_and_update(loss_f, "F")

        #############
        # Update D
        #############
        loss_d = F.cross_entropy(self.D(input), domain)
        self.model_backward_and_update(loss_d, "D")

        loss_summary = {
            "loss_g": loss_g.item(),
            "loss_f": loss_f.item(),
            "loss_d": loss_d.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def model_inference(self, input):
        return self.F(input)
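# ---------------------------------------------------------------------------
# Self-contained sketch (not part of the original file): a toy version of the
# generator objective computed in forward_backward() above. G is updated so
# the perturbed image keeps its class label (low loss under F) while fooling
# the domain classifier (high loss under D). ToyPerturber, label_net and
# domain_net below are hypothetical stand-ins for G, F and D, not Dassl APIs;
# they only exist to make the loss arithmetic concrete.
if __name__ == "__main__":
    import torch.nn as nn

    class ToyPerturber(nn.Module):
        """Adds a bounded learnable perturbation: x + lmda * tanh(conv(x))."""

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)

        def forward(self, x, lmda=0.3):
            return x + lmda * torch.tanh(self.conv(x))

    G_toy = ToyPerturber()
    label_net = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 4))  # stands in for F
    domain_net = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 2))  # stands in for D

    x = torch.randn(5, 3, 8, 8)  # batch of 5 toy images
    y = torch.randint(0, 4, (5,))  # class labels (4 classes)
    d = torch.randint(0, 2, (5,))  # domain labels (2 source domains)

    x_p = G_toy(x, lmda=0.3)
    # Same sign pattern as loss_g above: + label loss, - domain loss.
    loss_g = F.cross_entropy(label_net(x_p), y) - F.cross_entropy(domain_net(x_p), d)
    loss_g.backward()
    # In the trainer, model_backward_and_update(loss_g, "G") steps only G's
    # optimizer, so F and D are unchanged by this loss even though gradients
    # also reach their parameters.
    print("toy loss_g: {:.4f}".format(loss_g.item()))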