in main_stica.py [0:0]
def train(train_loader, model, optimizer, epoch, lr_schedule, writer):
    """Run one epoch of multi-crop NCE training with an optional cross-modal loss.

    Args:
        train_loader: iterable yielding ``(video, audio, _, _, _)`` batches;
            ``video`` is a list of clip tensors (the code below assumes
            exactly two clips — see NOTE at the loss computation).
        model: DDP-wrapped network; ``model.module.duration`` is read for
            crop-parameter sampling.
        optimizer: optimizer whose ``lr`` is overwritten each iteration from
            ``lr_schedule``.
        epoch: current epoch index (used to index into ``lr_schedule``).
        lr_schedule: per-iteration learning rates, indexed by
            ``epoch * len(train_loader) + it``.
        writer: tensorboard writer or falsy to disable scalar logging.

    Returns:
        ``(epoch, losses.avg)`` — the epoch index and the average total loss.
    """
    # Put model in train mode
    model.train()
    XE = torch.nn.CrossEntropyLoss()
    # Init logger meters
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    croplosses_meter = AverageMeter()
    avlosses = AverageMeter()
    end = time.time()
    for it, inputs in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        # update learning rate from the precomputed per-iteration schedule
        iteration = epoch * len(train_loader) + it
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr_schedule[iteration]
        # get inputs
        video, audio, _, _, _ = inputs
        audio = audio.cuda(non_blocking=True)
        feat_v_nce_lst = []
        crop_feat_v_nces_lst = []
        feat_a_nce_lst = []
        # FORWARD PASSES — one pass per video clip in the batch
        for i in range(len(video)):
            # get video (video[i:i+1] is a one-element list; cat yields that clip)
            video_input = torch.cat(video[i: i + 1]).cuda(non_blocking=True)
            # get crop params
            # NOTE(review): presumably re-sampled (random) per clip, so this
            # call must stay inside the loop — do not hoist. Confirm against
            # fmcrop_params' definition.
            params = fmcrop_params(
                duration=model.module.duration,
                s_large_crops=args.num_large_crops,
                s_small_crops=args.num_small_crops,
                t_large_crops=args.num_large_tcrops,
                t_small_crops=args.num_small_tcrops,
            )
            # Forward pass
            feat_v_nce, crop_feat_v_nces, feat_a_nce = model(
                video_input, audio, params=params)
            # Save features
            feat_v_nce_lst.append(feat_v_nce)
            crop_feat_v_nces_lst.append(crop_feat_v_nces)
            feat_a_nce_lst.append(feat_a_nce)
        # CROP & LOSS COMPUTATION
        # NOTE(review): indices [0] and [1] assume len(video) == 2 exactly;
        # the forward loop above is general but the loss is not.
        crop_losses, counters = nce_crop_losses_dual(
            feats_v=crop_feat_v_nces_lst[0],
            feats_v2=crop_feat_v_nces_lst[1],
            XE=XE,
            s_large_crops=args.num_large_crops,
            s_small_crops=args.num_small_crops,
            t_large_crops=args.num_large_tcrops,
            t_small_crops=args.num_small_tcrops,
            temp=args.temp,
        )
        loss_crops = sum(crop_losses) / sum(counters)
        # Cross-modal (audio<->video) loss, weighted by cross_modal_alpha.
        # (Merged the two identical `alpha > 0` checks from the original.)
        if args.cross_modal_alpha > 0:
            loss_av = 0.5 * (
                gdt_loss(feat_v_nce_lst[0], feat_a_nce_lst[0], XE) +
                gdt_loss(feat_v_nce_lst[1], feat_a_nce_lst[1], XE)
            )
            loss = (
                (1. - args.cross_modal_alpha) * loss_crops +
                args.cross_modal_alpha * loss_av
            )
        else:
            # keep a tensor so loss_av.item() below is valid either way
            loss_av = torch.tensor(0)
            loss = (1. - args.cross_modal_alpha) * loss_crops
        # BACKWARD AND OPTIMIZER STEP
        optimizer.zero_grad()
        if args.use_fp16:
            # apex AMP loss scaling for mixed-precision training
            with apex.amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()
        # LOGGING
        bs = audio.size(0)
        losses.update(loss.item(), bs)
        avlosses.update(loss_av.item(), bs)
        croplosses_meter.update(loss_crops.item(), bs)
        batch_time.update(time.time() - end)
        end = time.time()
        if args.rank == 0 and it % 50 == 0:
            logger.info(
                "Epoch: [{0}][{1}]\t"
                "Time {batch_time.val:.2f} ({batch_time.avg:.2f})\t"
                "Data {data_time.val:.2f} ({data_time.avg:.2f})\t"
                "Loss {loss.val:.2f} ({loss.avg:.2f})\t"
                "AVLoss {avloss.val:.2f} ({avloss.avg:.2f})\t"
                "CropLoss {closs.val:.2f} ({closs.avg:.2f})\t"
                "Lr: {lr:.4f}".format(
                    epoch,
                    it,
                    batch_time=batch_time,
                    data_time=data_time,
                    loss=losses,
                    avloss=avlosses,
                    closs=croplosses_meter,
                    lr=optimizer.param_groups[0]["lr"],
                )
            )
        # Log onto tensorboard
        if writer:
            log_scalars(writer, loss, loss_crops, crop_losses, loss_av,
                        counters, optimizer, batch_time, data_time, iteration)
        # ============ signal handling ... ============
        # .get() instead of [] so a missing env var (handler never installed)
        # reads as "no signal" rather than raising KeyError every iteration
        if os.environ.get('SIGNAL_RECEIVED') == 'True':
            if args.rank == 0:
                logger.info("Beginning requeue")  # fixed typo: was "reqeue"
            trigger_job_requeue(os.path.join(
                args.dump_path,
                "checkpoint.pth.tar"
            ))
    return epoch, losses.avg