in main_deepclusterv2.py [0:0]
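# NOTE: snippet from main_deepclusterv2.py; it relies on module-level context
# defined elsewhere in the script: `args` (parsed command-line flags), `logger`,
# the `time` and `torch.nn as nn` imports, the AverageMeter utility, and the
# cluster_memory() helper. `optimizer` is assumed to be the LARC-wrapped SGD
# built in main(), which is why the log line below reads the lr through
# `optimizer.optim`.
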
def train(loader, model, optimizer, epoch, schedule, local_memory_index, local_memory_embeddings):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    model.train()
    cross_entropy = nn.CrossEntropyLoss(ignore_index=-100)
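
    # cluster the memory bank of embeddings to get one set of cluster
    # assignments per prototype head; these act as the epoch's pseudo-labels.
    # Samples left without an assignment are marked -100, which the
    # cross-entropy above skips via ignore_index.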
    assignments = cluster_memory(model, local_memory_index, local_memory_embeddings, len(loader.dataset))
    logger.info('Clustering for epoch {} done.'.format(epoch))
    end = time.time()
    start_idx = 0
    for it, (idx, inputs) in enumerate(loader):
        # measure data loading time
        data_time.update(time.time() - end)

        # update learning rate
        iteration = epoch * len(loader) + it
        for param_group in optimizer.param_groups:
            param_group["lr"] = schedule[iteration]

        # ============ multi-res forward passes ... ============
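        # the multi-crop forward concatenates all crops into one batch:
        # `emb` holds the projected embeddings and `output` the prototype
        # scores, one tensor per prototype head; `emb` is detached since
        # clustering and the memory bank need no gradient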
        emb, output = model(inputs)
        emb = emb.detach()
        bs = inputs[0].size(0)

        # ============ deepcluster-v2 loss ... ============
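        # every crop is classified into the clusters found above; the
        # assignment of each sample in `idx` is repeated for all crops so
        # that the different views of an image share the same pseudo-label,
        # and the loss is averaged over the prototype heads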
        loss = 0
        for h in range(len(args.nmb_prototypes)):
            scores = output[h] / args.temperature
            targets = assignments[h][idx].repeat(sum(args.nmb_crops)).cuda(non_blocking=True)
            loss += cross_entropy(scores, targets)
        loss /= len(args.nmb_prototypes)
        # ============ backward and optim step ... ============
        optimizer.zero_grad()
        loss.backward()
        # cancel gradients for the prototypes: they stay frozen for the
        # first freeze_prototypes_niters iterations to stabilize training
        if iteration < args.freeze_prototypes_niters:
            for name, p in model.named_parameters():
                if "prototypes" in name:
                    p.grad = None
        optimizer.step()
        # ============ update memory banks ... ============
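        # save the batch indices and the embeddings of the assignment crops
        # so that next epoch's clustering can reuse them without an extra
        # pass over the dataset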
        local_memory_index[start_idx : start_idx + bs] = idx
        for i, crop_idx in enumerate(args.crops_for_assign):
            local_memory_embeddings[i][start_idx : start_idx + bs] = \
                emb[crop_idx * bs : (crop_idx + 1) * bs]
        start_idx += bs
        # ============ misc ... ============
        losses.update(loss.item(), inputs[0].size(0))
        batch_time.update(time.time() - end)
        end = time.time()
        if args.rank == 0 and it % 50 == 0:
            logger.info(
                "Epoch: [{0}][{1}]\t"
                "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
                "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
                "Lr: {lr:.4f}".format(
                    epoch,
                    it,
                    batch_time=batch_time,
                    data_time=data_time,
                    loss=losses,
                    lr=optimizer.optim.param_groups[0]["lr"],
                )
            )
    return (epoch, losses.avg), local_memory_index, local_memory_embeddings
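
# A minimal sketch of how train() is driven from the epoch loop in main(),
# assuming the rest of this script (init_memory() fills the embedding banks,
# lr_schedule is the per-iteration learning-rate array indexed above).
# Illustrative wiring only, not a verbatim excerpt:
local_memory_index, local_memory_embeddings = init_memory(train_loader, model)
for epoch in range(start_epoch, args.epochs):
    train_loader.sampler.set_epoch(epoch)
    scores, local_memory_index, local_memory_embeddings = train(
        train_loader, model, optimizer, epoch, lr_schedule,
        local_memory_index, local_memory_embeddings,
    )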