in main-avid.py [0:0]
def run_phase(phase, loader, model, optimizer, criterion, epoch, args, cfg, logger, tb_writter):
    """Run one pass over ``loader`` for the given phase and log the results.

    When ``phase == 'train'`` the model is put in training mode and an
    optimizer step is taken after every batch. For any other phase the
    model is put in eval mode and the forward pass runs under
    ``torch.no_grad()``; the optimizer is never touched.

    Args:
        phase: Phase name. Compared against ``'train'`` and used as a prefix
            for log lines and TensorBoard tags.
        loader: Sized iterable yielding dict-like samples with the keys
            ``'frames'``, ``'audio'`` and ``'index'``; each value must
            support ``.cuda(...)``.
        model: Callable as ``model(video, audio)`` returning a
            ``(video_emb, audio_emb)`` pair; must expose ``.train(bool)``.
        optimizer: Only used when ``phase == 'train'``.
        criterion: Callable as ``criterion(video_emb, audio_emb, index)``
            returning ``(loss, loss_debug)``. ``loss_debug`` is iterated by
            key and every value must support ``.item()``.
        epoch: Epoch number, used for the log header and TensorBoard steps.
        args: Namespace providing ``gpu``, ``distributed`` and (when
            distributed) ``world_size``.
        cfg: Mapping providing ``'print_freq'``.
        logger: Object with an ``add_line`` method.
        tb_writter: Object with an ``add_scalar`` method, or ``None`` to
            skip TensorBoard logging.

    Note:
        Relies on ``utils``, ``time`` and ``torch`` being available in the
        enclosing module scope.
    """
    from utils import metrics_utils

    is_train = phase == 'train'
    num_batches = len(loader)

    logger.add_line('\n{}: Epoch {}'.format(phase, epoch))

    # Meters tracked for this phase: the two timers use a window of 100
    # updates, the loss meter uses the AverageMeter default.
    batch_time = metrics_utils.AverageMeter('Time', ':6.3f', window_size=100)
    data_time = metrics_utils.AverageMeter('Data', ':6.3f', window_size=100)
    loss_meter = metrics_utils.AverageMeter('Loss', ':.3e')
    progress = utils.logger.ProgressMeter(
        num_batches,
        [batch_time, data_time, loss_meter],
        phase=phase,
        epoch=epoch,
        logger=logger,
        tb_writter=tb_writter,
    )

    # Training mode only for the 'train' phase, eval mode otherwise.
    model.train(is_train)

    # Fall back to device 0 when no GPU was explicitly selected.
    gpu_id = 0 if args.gpu is None else args.gpu

    def to_gpu(tensor):
        # Asynchronous host-to-device copy onto the selected GPU.
        return tensor.cuda(gpu_id, non_blocking=True)

    tick = time.time()
    for batch_idx, batch in enumerate(loader):
        # Time spent waiting for the loader to produce this batch.
        data_time.update(time.time() - tick)

        # Pull the modalities out of the sample, then move them to the GPU.
        video, audio, index = batch['frames'], batch['audio'], batch['index']
        video = to_gpu(video)
        audio = to_gpu(audio)
        index = to_gpu(index)

        # Forward pass; gradients are only tracked while training.
        if is_train:
            video_emb, audio_emb = model(video, audio)
        else:
            with torch.no_grad():
                video_emb, audio_emb = model(video, audio)

        # Loss, weighted in the running average by the batch size.
        loss, loss_debug = criterion(video_emb, audio_emb, index)
        loss_meter.update(loss.item(), video.size(0))

        # Parameter update (training only).
        if is_train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Total time for this iteration, then restart the clock.
        batch_time.update(time.time() - tick)
        tick = time.time()

        # Report every `print_freq` batches, plus the first and last batch.
        seen = batch_idx + 1
        if seen % cfg['print_freq'] == 0 or batch_idx == 0 or seen == num_batches:
            progress.display(seen)
            if tb_writter is not None:
                global_step = epoch * num_batches + batch_idx
                for key in loss_debug:
                    tb_writter.add_scalar(
                        '{}-batch/{}'.format(phase, key),
                        loss_debug[key].item(),
                        global_step,
                    )

    # In distributed runs, merge the meters across processes and print the
    # combined averages.
    if args.distributed:
        progress.synchronize_meters(args.gpu)
        progress.display(num_batches * args.world_size)

    if tb_writter is None:
        return

    # Per-epoch averages of every meter.
    for meter in progress.meters:
        tb_writter.add_scalar('{}-epoch/{}'.format(phase, meter.name), meter.avg, epoch)