in main_swav.py
def train(train_loader, model, optimizer, epoch, lr_schedule, queue):
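    """Train one epoch: compute soft cluster assignments (codes) with
    Sinkhorn-Knopp and minimize the swapped prediction loss of SwAV."""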
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    model.train()
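    # the feature queue (if enabled) is only folded into the Sinkhorn
    # assignment once it has been completely filled; this flag remembers
    # that state across iterations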
    use_the_queue = False

    end = time.time()
    for it, inputs in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        # update learning rate
        iteration = epoch * len(train_loader) + it
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr_schedule[iteration]
        # normalize the prototypes
        with torch.no_grad():
            w = model.module.prototypes.weight.data.clone()
            w = nn.functional.normalize(w, dim=1, p=2)
            model.module.prototypes.weight.copy_(w)

        # ============ multi-res forward passes ... ============
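        # a single forward pass covers all crops; `embedding` holds the
        # projection features and `output` the prototype scores, both
        # concatenated along the batch axis; bs is the per-crop batch size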
        embedding, output = model(inputs)
        embedding = embedding.detach()
        bs = inputs[0].size(0)

        # ============ swav loss ... ============
        loss = 0
        for i, crop_id in enumerate(args.crops_for_assign):
            with torch.no_grad():
                out = output[bs * crop_id: bs * (crop_id + 1)].detach()
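
                # the queue enlarges the pool of samples over which the
                # equipartition constraint of Sinkhorn-Knopp is enforced,
                # which helps when batches are small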
                # time to use the queue
                if queue is not None:
                    if use_the_queue or not torch.all(queue[i, -1, :] == 0):
                        use_the_queue = True
                        out = torch.cat((torch.mm(
                            queue[i],
                            model.module.prototypes.weight.t()
                        ), out))
                    # fill the queue
                    queue[i, bs:] = queue[i, :-bs].clone()
                    queue[i, :bs] = embedding[crop_id * bs: (crop_id + 1) * bs]
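
                # queued scores were prepended to `out`, so the codes for the
                # current batch are the last bs rows of the Sinkhorn output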
                # get assignments
                q = distributed_sinkhorn(out)[-bs:]
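
            # swapped prediction: the codes q obtained from crop_id are
            # predicted from the softmaxed scores of every other crop v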
            # cluster assignment prediction
            subloss = 0
            for v in np.delete(np.arange(np.sum(args.nmb_crops)), crop_id):
                x = output[bs * v: bs * (v + 1)] / args.temperature
                subloss -= torch.mean(torch.sum(q * F.log_softmax(x, dim=1), dim=1))
            loss += subloss / (np.sum(args.nmb_crops) - 1)
        loss /= len(args.crops_for_assign)

        # ============ backward and optim step ... ============
        optimizer.zero_grad()
        if args.use_fp16:
            with apex.amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
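        # freezing the prototypes for the first iterations stabilizes the
        # early assignments and helps avoid collapse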
        # cancel gradients for the prototypes
        if iteration < args.freeze_prototypes_niters:
            for name, p in model.named_parameters():
                if "prototypes" in name:
                    p.grad = None
        optimizer.step()

        # ============ misc ... ============
        losses.update(loss.item(), inputs[0].size(0))
        batch_time.update(time.time() - end)
        end = time.time()
        if args.rank == 0 and it % 50 == 0:
            logger.info(
                "Epoch: [{0}][{1}]\t"
                "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
                "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
                "Lr: {lr:.4f}".format(
                    epoch,
                    it,
                    batch_time=batch_time,
                    data_time=data_time,
                    loss=losses,
                    # `optimizer` is a LARC wrapper, so the underlying
                    # optimizer's param_groups are reached through .optim
                    lr=optimizer.optim.param_groups[0]["lr"],
                )
            )
    return (epoch, losses.avg), queue
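
# distributed_sinkhorn is defined elsewhere in main_swav.py; the sketch below
# shows the Sinkhorn-Knopp iteration it performs, assuming the hyperparameters
# args.epsilon, args.sinkhorn_iterations and args.world_size, and that
# torch.distributed is imported as dist and initialized.
@torch.no_grad()
def distributed_sinkhorn(out):
    # Q is K-by-B: K prototypes by B local samples
    Q = torch.exp(out / args.epsilon).t()
    B = Q.shape[1] * args.world_size  # total samples across all GPUs
    K = Q.shape[0]  # number of prototypes

    # normalize the whole matrix so it sums to 1 (a transport plan)
    sum_Q = torch.sum(Q)
    dist.all_reduce(sum_Q)
    Q /= sum_Q

    for _ in range(args.sinkhorn_iterations):
        # row step: each prototype receives total mass 1/K (equipartition)
        sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
        dist.all_reduce(sum_of_rows)
        Q /= sum_of_rows
        Q /= K

        # column step: each sample distributes total mass 1/B
        Q /= torch.sum(Q, dim=0, keepdim=True)
        Q /= B

    Q *= B  # columns sum to 1: each row of the result is a soft assignment
    return Q.t()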