in src/transformers/utils/video_action_recognition.py [0:0]
def test_classification(cfg, data_loader, model, criterion, epoch, writer=None):
    """Evaluate ``model`` on ``data_loader`` using multi-view testing.

    Every batch must contain exactly one video. Its temporal axis is cut into
    consecutive clips of ``cfg.CONFIG.DATA.CLIP_LEN`` frames; each clip is run
    through ``model`` followed by a softmax over dim 1, and the per-clip
    outputs are averaged into a single prediction that is scored against the
    target. Running statistics are printed every
    ``cfg.CONFIG.LOG.DISPLAY_FREQ`` steps on the process whose
    ``cfg.DDP_CONFIG.GPU_WORLD_RANK`` is 0.

    Args:
        cfg: Configuration object. Reads ``CONFIG.DATA.CLIP_LEN``,
            ``CONFIG.LOG.DISPLAY_FREQ`` and ``DDP_CONFIG.GPU_WORLD_RANK``.
        data_loader: Sized iterable yielding items whose element 0 is the
            sample tensor and element 1 is the target tensor.
        model: Network to evaluate; it is switched to eval mode.
        criterion: Callable ``criterion(outputs, targets)`` returning the loss.
        epoch: Epoch index, used only in the log output.
        writer: Unused; accepted for interface compatibility.

    Returns:
        Tuple ``(top1.avg, top5.avg, losses.avg)`` taken from the running
        meters.

    Raises:
        AssertionError: If a batch contains more than one sample.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    postprocess = torch.nn.Softmax(dim=1)
    clip_len = cfg.CONFIG.DATA.CLIP_LEN  # loop-invariant, look it up once

    end = time.time()
    # Evaluation never backpropagates, so do not let autograd record the
    # per-clip forward passes; this keeps GPU memory usage down.
    with torch.no_grad():
        for step, data in enumerate(data_loader):
            data_time.update(time.time() - end)
            samples = data[0].cuda(non_blocking=True)
            targets = data[1].cuda(non_blocking=True)

            # The shape of `samples` is like BxCxTxHxW, where B is 1, C is 3 (RGB).
            # T = test_num_segment * test_num_crop * clip_len,
            # e.g., T = 10 * 3 * 8 for vidtr_s_8x8.
            # H and W are just crop size, e.g., 224 or 256 for most of the time.
            assert samples.size(0) == 1, 'batch_size during multiview test must be set to 1 due to limited GPU memory'

            # One forward pass per clip along the temporal axis (dim 2).
            out_list = [
                postprocess(model(samples[:, :, start:start + clip_len, :, :]))
                for start in range(0, samples.size(2), clip_len)
            ]
            # Stack the per-clip outputs and average them into one prediction.
            outputs = torch.cat(out_list, dim=0)
            outputs = torch.mean(outputs, dim=0, keepdim=True)

            # NOTE(review): `criterion` receives averaged softmax probabilities,
            # not logits. If it is a logit-based loss (e.g. CrossEntropyLoss)
            # the reported value is not a true cross-entropy -- confirm intent.
            loss = criterion(outputs, targets)
            acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))

            # presumably aggregates the metrics across distributed ranks --
            # verify against the reduce_tensor helper.
            acc1 = reduce_tensor(acc1)
            acc5 = reduce_tensor(acc5)
            loss = reduce_tensor(loss)

            num_targets = targets.size(0)
            losses.update(loss.item(), num_targets)
            top1.update(acc1.item(), num_targets)
            top5.update(acc5.item(), num_targets)

            batch_time.update(time.time() - end)
            end = time.time()

            if cfg.DDP_CONFIG.GPU_WORLD_RANK == 0 and step % cfg.CONFIG.LOG.DISPLAY_FREQ == 0:
                print('----Testing----')
                print('Epoch: [{0}][{1}/{2}]'.format(epoch, step + 1, len(data_loader)))
                memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
                print('memory used: {memory:.0f}MB'.format(memory=memory_used))
                print('data_time: {data_time:.3f}, batch time: {batch_time:.3f}'.format(
                    data_time=data_time.val,
                    batch_time=batch_time.val))
                print('loss: {loss:.5f}'.format(loss=losses.avg))
                print('Top-1 accuracy: {top1_acc:.2f}%, Top-5 accuracy: {top5_acc:.2f}%'.format(
                    top1_acc=top1.avg,
                    top5_acc=top5.avg))

    return top1.avg, top5.avg, losses.avg