# main.py — training loop (extracted chunk)
def train(train_loader, model, optimizer, epoch, writer, selflabels):
    """Run one training epoch of audio-visual self-labeling (SeLaVi-style).

    For each batch: forward video and audio through ``model``, compute a
    cross-entropy-style loss against the current pseudo-labels, and step the
    optimizer. When the running batch counter crosses the next threshold in
    the global ``sk_schedule``, pseudo-labels are re-optimized via
    Sinkhorn-Knopp clustering over the whole dataset.

    Args:
        train_loader: DataLoader yielding ``(video, audio, _, selected, _)``
            tuples; ``selected`` indexes rows of ``selflabels``.
        model: network returning ``(feat_v, feat_a)`` head outputs.
        optimizer: optimizer whose first param group's lr is logged.
        epoch: current epoch index (used for iteration bookkeeping).
        writer: optional TensorBoard ``SummaryWriter`` (rank 0 only, typically).
        selflabels: tensor of per-sample pseudo-labels, shape
            ``[num_samples, headcount]`` — TODO confirm exact layout.

    Returns:
        ``((epoch, avg_loss), selflabels)`` — the (possibly re-clustered)
        pseudo-labels are returned so the caller can carry them forward.

    NOTE(review): relies on module-level ``args``, ``logger``, ``group``,
    ``sk_schedule`` (descending list of batch thresholds, consumed via
    ``pop()``) and ``sk_counter`` — defined elsewhere in this file.
    """
    global sk_schedule
    global sk_counter
    # Put model in train mode
    model.train()
    # Init logger meters
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    world_size = args.world_size
    dataset_bs = train_loader.batch_size
    end = time.time()
    batches_thusfar = epoch * len(train_loader)
    for it, inputs in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        # ============ Get inputs ... ============
        video, audio, _, selected, _ = inputs
        video, audio = video.cuda(), audio.cuda()
        # ============ Occasional clustering via Sinkhorn-Knopp ... ===========
        if batches_thusfar + it >= sk_schedule[-1]:
            # optimize labels; consume this threshold from the schedule
            with torch.no_grad():
                _ = sk_schedule.pop()
                selflabels = cluster(
                    args, selflabels, train_loader.dataset, model, sk_counter,
                    logger, writer, group,
                    # number of samples seen so far across all workers
                    (batches_thusfar + it) * dataset_bs * world_size
                )
        # ============ forward passes ... ============
        feat_v, feat_a = model(video, audio)
        # ============ SeLaVi loss ... ============
        if args.headcount == 1:
            labels = selflabels[selected, 0]
        else:
            labels = selflabels[selected, :]
        loss_vid = get_loss(feat_v, labels, headcount=args.headcount)
        loss_aud = get_loss(feat_a, labels, headcount=args.headcount)
        # equal-weight average of the two modality losses
        loss = 0.5 * loss_vid + 0.5 * loss_aud
        # ============ backward and optim step ... ============
        optimizer.zero_grad()
        if args.use_fp16:
            # apex amp loss scaling for mixed-precision training
            with apex.amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()
        # ============ misc ... ============
        losses.update(loss.item(), inputs[0].size(0))
        batch_time.update(time.time() - end)
        end = time.time()
        iteration = epoch * len(train_loader) + it
        if args.rank == 0 and it % 50 == 0:
            logger.info(
                "Epoch: [{0}][{1}]\t"
                "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
                "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
                "Lr: {lr:.4f}".format(
                    epoch,
                    it,
                    batch_time=batch_time,
                    data_time=data_time,
                    loss=losses,
                    lr=optimizer.param_groups[0]["lr"],
                )
            )
            # Log onto tensorboard
            if writer:
                writer.add_scalar(
                    'loss/iter', loss.item(), iteration)
                writer.add_scalar(
                    'lr/iter', optimizer.param_groups[0]["lr"], iteration)
                writer.add_scalar(
                    'batch_time/iter', batch_time.avg, iteration)
                writer.add_scalar(
                    'data_time/iter', data_time.avg, iteration)
        # ============ signal handling ... ============
        # cluster-scheduler preemption: checkpoint and requeue the job
        if os.environ['SIGNAL_RECEIVED'] == 'True':
            if args.rank == 0:
                logger.info("Beginning requeue")  # fixed typo: was "reqeue"
                trigger_job_requeue(
                    os.path.join(args.dump_path, "checkpoint.pth.tar"))
            dist.barrier()
    torch.cuda.empty_cache()
    return (epoch, losses.avg), selflabels