# in main_fixmatch.py
def train(train_loader_x, train_loader_u, model, optimizer, epoch, args):
    """Run one FixMatch semi-supervised training epoch.

    Iterates over the unlabeled loader (which defines the epoch length) while
    cycling the labeled loader as needed. For each step: computes a supervised
    cross-entropy loss on labeled data, derives pseudo-labels from weakly
    augmented unlabeled data, and applies a confidence-masked consistency loss
    on the strongly augmented views.

    Args:
        train_loader_x: DataLoader over labeled data, yielding (images, targets).
        train_loader_u: DataLoader over unlabeled data, yielding
            ((weak_images, strong_images), targets); targets are used only for
            monitoring pseudo-label accuracy, never for the loss.
        model: network whose forward returns (logits_x, logits_u_w, logits_u_s);
            may optionally carry an EMA sub-model (args.eman).
        optimizer: optimizer stepped once per unlabeled batch.
        epoch: current epoch index (0-based).
        args: namespace with epochs, distributed, eman, gpu, warmup_epoch,
            threshold, lambda_u, print_freq.
    """
    batch_time = utils.AverageMeter('Time', ':6.3f')
    data_time = utils.AverageMeter('Data', ':6.3f')
    losses = utils.AverageMeter('Loss', ':.4e')
    losses_x = utils.AverageMeter('Loss_x', ':.4e')
    losses_u = utils.AverageMeter('Loss_u', ':.4e')
    top1_x = utils.AverageMeter('Acc_x@1', ':6.2f')
    top5_x = utils.AverageMeter('Acc_x@5', ':6.2f')
    top1_u = utils.AverageMeter('Acc_u@1', ':6.2f')
    top5_u = utils.AverageMeter('Acc_u@5', ':6.2f')
    mask_info = utils.AverageMeter('Mask', ':6.3f')
    curr_lr = utils.InstantMeter('LR', '')
    progress = utils.ProgressMeter(
        len(train_loader_u),
        [curr_lr, batch_time, data_time, losses, losses_x, losses_u, top1_x, top5_x, top1_u, top5_u, mask_info],
        prefix="Epoch: [{}/{}]\t".format(epoch, args.epochs))

    # The labeled loader is smaller and cycles several times per unlabeled
    # epoch; give its sampler a distinct, monotonically increasing epoch so
    # each pass through it is shuffled differently.
    epoch_x = epoch * math.ceil(len(train_loader_u) / len(train_loader_x))
    if args.distributed:
        print("set epoch={} for labeled sampler".format(epoch_x))
        train_loader_x.sampler.set_epoch(epoch_x)
        print("set epoch={} for unlabeled sampler".format(epoch))
        train_loader_u.sampler.set_epoch(epoch)

    train_iter_x = iter(train_loader_x)
    # switch to train mode
    model.train()
    if args.eman:
        # The EMA branch must stay in eval mode so its BN statistics are not
        # updated by forward passes (they are maintained via momentum updates).
        print("setting the ema model to eval mode")
        if hasattr(model, 'module'):
            model.module.ema.eval()
        else:
            model.ema.eval()

    end = time.time()
    for i, (images_u, targets_u) in enumerate(train_loader_u):
        try:
            images_x, targets_x = next(train_iter_x)
        except StopIteration:
            # Labeled loader exhausted: bump its epoch (reshuffles in the
            # distributed case) and restart it. Was `except Exception`, which
            # would also have silently swallowed real errors from the loader.
            epoch_x += 1
            print("reshuffle train_loader_x at epoch={}".format(epoch_x))
            if args.distributed:
                print("set epoch={} for labeled sampler".format(epoch_x))
                train_loader_x.sampler.set_epoch(epoch_x)
            train_iter_x = iter(train_loader_x)
            images_x, targets_x = next(train_iter_x)

        images_u_w, images_u_s = images_u
        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            images_x = images_x.cuda(args.gpu, non_blocking=True)
            images_u_w = images_u_w.cuda(args.gpu, non_blocking=True)
            images_u_s = images_u_s.cuda(args.gpu, non_blocking=True)
            targets_x = targets_x.cuda(args.gpu, non_blocking=True)
            targets_u = targets_u.cuda(args.gpu, non_blocking=True)

        # warmup learning rate
        if epoch < args.warmup_epoch:
            warmup_step = args.warmup_epoch * len(train_loader_u)
            curr_step = epoch * len(train_loader_u) + i + 1
            lr_schedule.warmup_learning_rate(optimizer, curr_step, warmup_step, args)
        curr_lr.update(optimizer.param_groups[0]['lr'])

        # model forward
        logits_x, logits_u_w, logits_u_s = model(images_x, images_u_w, images_u_s)

        # Pseudo-label: softmax over the weak view, detached so no gradient
        # flows through the pseudo-label path (use non-mutating detach()
        # rather than in-place detach_()).
        pseudo_label = torch.softmax(logits_u_w.detach(), dim=-1)
        max_probs, pseudo_targets_u = torch.max(pseudo_label, dim=-1)
        # Keep only confident pseudo-labels (FixMatch confidence threshold).
        mask = max_probs.ge(args.threshold).float()

        # compute losses: supervised term + masked consistency term
        loss_x = F.cross_entropy(logits_x, targets_x, reduction='mean')
        loss_u = (F.cross_entropy(logits_u_s, pseudo_targets_u, reduction='none') * mask).mean()
        loss = loss_x + args.lambda_u * loss_u

        # measure accuracy and record loss
        losses.update(loss.item())
        losses_x.update(loss_x.item(), images_x.size(0))
        losses_u.update(loss_u.item(), images_u_w.size(0))
        acc1_x, acc5_x = utils.accuracy(logits_x, targets_x, topk=(1, 5))
        top1_x.update(acc1_x[0], logits_x.size(0))
        top5_x.update(acc5_x[0], logits_x.size(0))
        # Accuracy of the weak-view predictions against the (held-out) true
        # unlabeled targets — monitoring only, never used in the loss.
        acc1_u, acc5_u = utils.accuracy(logits_u_w, targets_u, topk=(1, 5))
        top1_u.update(acc1_u[0], logits_u_w.size(0))
        top5_u.update(acc5_u[0], logits_u_w.size(0))
        mask_info.update(mask.mean().item(), mask.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # update the ema model
        if args.eman:
            if hasattr(model, 'module'):
                model.module.momentum_update_ema()
            else:
                model.momentum_update_ema()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)