in Dassl.pytorch/dassl/engine/da/cdac.py [0:0]
def build_model(self):
    """Build the feature extractor ``F`` and classifier ``C`` for CDAC.

    Computes ``self.num_batches`` and ``self.max_iter`` from
    ``cfg.TRAIN.COUNT_ITER``, then creates each network together with
    its optimizer and a ``LambdaLR`` scheduler driven by
    ``custom_scheduler``, and registers them under the names ``"F"``
    and ``"C"``.

    Reads ``self.cfg``, ``self.train_loader_x``, ``self.train_loader_u``,
    ``self.max_epoch``, ``self.device``, ``self.num_classes`` and
    ``self.lr_multi``; these are expected to be set before this method
    is called.

    Raises:
        ValueError: if ``cfg.TRAIN.COUNT_ITER`` is not one of
            ``"train_x"``, ``"train_u"`` or ``"smaller_one"``.
    """
    cfg = self.cfg

    # Custom LR Scheduler for CDAC: it needs the total number of
    # iterations, which depends on which loader defines an epoch.
    count_iter = cfg.TRAIN.COUNT_ITER
    if count_iter == "train_x":
        self.num_batches = len(self.train_loader_x)
    elif count_iter == "train_u":
        self.num_batches = len(self.train_loader_u)
    elif count_iter == "smaller_one":
        self.num_batches = min(
            len(self.train_loader_x), len(self.train_loader_u)
        )
    else:
        # Fail early with a clear message instead of an AttributeError
        # on the unset ``self.num_batches`` below.
        raise ValueError(
            "Unknown TRAIN.COUNT_ITER: {!r} (expected 'train_x', "
            "'train_u' or 'smaller_one')".format(count_iter)
        )

    self.max_iter = self.max_epoch * self.num_batches
    print("Max Iterations: %d" % self.max_iter)

    print("Building F")
    self.F = SimpleNet(cfg, cfg.MODEL, 0)
    self.F.to(self.device)
    print("# params: {:,}".format(count_num_param(self.F)))
    self.optim_F = build_optimizer(self.F, cfg.OPTIM)
    custom_lr_F = partial(
        custom_scheduler, max_iter=self.max_iter, init_lr=cfg.OPTIM.LR
    )
    self.sched_F = LambdaLR(self.optim_F, custom_lr_F)
    self.register_model("F", self.F, self.optim_F, self.sched_F)

    print("Building C")
    self.C = Prototypes(self.F.fdim, self.num_classes)
    self.C.to(self.device)
    print("# params: {:,}".format(count_num_param(self.C)))
    self.optim_C = build_optimizer(self.C, cfg.OPTIM)

    # Multiply the learning rate of C by lr_multi. This is done before
    # the scheduler is created because LambdaLR records each group's
    # current lr as its base lr at construction time.
    for group_param in self.optim_C.param_groups:
        group_param['lr'] *= self.lr_multi
    custom_lr_C = partial(
        custom_scheduler,
        max_iter=self.max_iter,
        init_lr=cfg.OPTIM.LR * self.lr_multi
    )
    self.sched_C = LambdaLR(self.optim_C, custom_lr_C)
    self.register_model("C", self.C, self.optim_C, self.sched_C)