in pycls/core/benchmark.py
import torch
import torch.cuda.amp as amp

import pycls.core.net as net
from pycls.core.config import cfg
from pycls.core.timer import Timer


def compute_time_train(model, loss_fun):
    """Computes precise model forward + backward time using dummy data."""
    # Use train mode so timing reflects training behavior (dropout active, BN uses batch stats)
    model.train()
    # Generate a dummy mini-batch (per-GPU batch size) and copy data to GPU
    im_size, batch_size = cfg.TRAIN.IM_SIZE, int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS)
    inputs = torch.rand(batch_size, 3, im_size, im_size).cuda(non_blocking=False)
    labels = torch.zeros(batch_size, dtype=torch.int64).cuda(non_blocking=False)
    labels_one_hot = net.smooth_one_hot_labels(labels)
    # Cache BatchNorm2D running stats, which the dummy forward passes would otherwise corrupt
    bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
    bn_stats = [[bn.running_mean.clone(), bn.running_var.clone()] for bn in bns]
    # Create a GradScaler for mixed precision training
    scaler = amp.GradScaler(enabled=cfg.TRAIN.MIXED_PRECISION)
    # Time forward and backward passes separately, discarding the warmup iterations
    fw_timer, bw_timer = Timer(), Timer()
    total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
    for cur_iter in range(total_iter):
        # Reset the timers after the warmup phase
        if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
            fw_timer.reset()
            bw_timer.reset()
        # Forward
        fw_timer.tic()
        with amp.autocast(enabled=cfg.TRAIN.MIXED_PRECISION):
            preds = model(inputs)
            loss = loss_fun(preds, labels_one_hot)
        # CUDA kernels launch asynchronously; synchronize so toc() measures completed work
        torch.cuda.synchronize()
        fw_timer.toc()
        # Backward (gradients accumulate across iterations; only timing matters here)
        bw_timer.tic()
        scaler.scale(loss).backward()
        torch.cuda.synchronize()
        bw_timer.toc()
    # Restore BatchNorm2D running stats
    for bn, (mean, var) in zip(bns, bn_stats):
        bn.running_mean, bn.running_var = mean, var
    return fw_timer.average_time, bw_timer.average_time
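
The Timer above comes from pycls.core.timer. A minimal sketch consistent with how it is used here (reset/tic/toc plus an average_time attribute), assuming simple wall-clock accumulation:

import time

class Timer:
    """Accumulates wall-clock time across tic/toc calls (minimal sketch)."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Clear accumulated time and call count
        self.total_time, self.calls, self.average_time = 0.0, 0, 0.0

    def tic(self):
        # Mark the start of a timed region
        self.start_time = time.time()

    def toc(self):
        # Accumulate elapsed time and update the running average
        self.total_time += time.time() - self.start_time
        self.calls += 1
        self.average_time = self.total_time / self.calls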
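
net.smooth_one_hot_labels produces the soft targets the loss consumes. A sketch of the standard label-smoothing construction, with num_classes and the smoothing factor eps made explicit parameters here (the actual helper presumably reads these from cfg):

def smooth_one_hot_labels(labels, num_classes=1000, eps=0.1):
    """Integer labels -> smoothed one-hot targets (sketch)."""
    # Every class receives eps / num_classes probability mass;
    # the target class receives the remaining mass
    neg = eps / num_classes
    pos = 1.0 - eps + neg
    one_hot = torch.full((labels.shape[0], num_classes), neg, device=labels.device)
    return one_hot.scatter_(1, labels.long().unsqueeze(1), pos)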
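
Hypothetical usage sketch, assuming cfg has been loaded and a CUDA device is available; the torchvision model and soft_cross_entropy loss below are illustrative stand-ins, not part of pycls:

import torch.nn.functional as F
import torchvision.models as models

def soft_cross_entropy(preds, labels_one_hot):
    # Cross-entropy against soft (smoothed one-hot) targets, averaged over the batch
    return (-labels_one_hot * F.log_softmax(preds, dim=1)).sum(dim=1).mean()

model = models.resnet18(num_classes=cfg.MODEL.NUM_CLASSES).cuda()
fw_time, bw_time = compute_time_train(model, soft_cross_entropy)
print("fw: {:.1f} ms  bw: {:.1f} ms".format(fw_time * 1e3, bw_time * 1e3))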