in tinynn/util/cifar10.py [0:0]
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast

# DLContext, AverageMeter and compute_accuracy come from tinynn's training
# utilities; they are imported at the top of the full file, so the exact
# module path is omitted in this excerpt.


def train_one_epoch_distill(model, context: DLContext):
    """Train the student model for one epoch with knowledge distillation.

    Expects ``context.custom_args`` to provide 'distill_A' (weight of the
    soft-target term in the combined loss), 'distill_T' (softmax temperature)
    and 'distill_teacher' (the trained teacher model); see the usage sketch
    after this function.

    Args:
        model: Student model
        context (DLContext): The context object
    """
    def _calc_loss(label, label_teacher):
        # `output` (the student logits) is taken from the enclosing loop via
        # closure.
        if isinstance(context.criterion, nn.BCEWithLogitsLoss):
            # BCEWithLogitsLoss expects one-hot targets, so scatter the class
            # indices into a (batch, 10) tensor (CIFAR-10 has 10 classes).
            label.unsqueeze_(1)
            label_onehot = torch.FloatTensor(label.shape[0], 10)
            label_onehot.zero_()
            label_onehot.scatter_(1, label, 1)
            label.squeeze_(1)
            label_onehot = label_onehot.to(device=context.device)
            label = label.to(device=context.device)
            origin_loss = context.criterion(output, label_onehot)
        else:
            label = label.to(device=context.device)
            origin_loss = context.criterion(output, label)
        # Soft-target term: KL divergence between the temperature-scaled
        # student and teacher distributions. The T * T factor compensates for
        # the 1/T^2 gradient scaling introduced by the temperature, keeping
        # both terms on a comparable scale (Hinton et al., 2015).
        distill_loss = (
            F.kl_div(F.log_softmax(output / T, dim=1), F.softmax(label_teacher / T, dim=1), reduction='batchmean')
            * T
            * T
        )
        # Track the weighted hard-label loss separately; .item() detaches it
        # from the graph so the meter stores a plain float.
        avg_origin_losses.update(origin_loss.item() * (1 - A))
        loss = origin_loss * (1 - A) + distill_loss * A
        return loss, label
    A = context.custom_args['distill_A']  # weight of the distillation term
    T = context.custom_args['distill_T']  # softmax temperature
    teacher = context.custom_args['distill_teacher']

    avg_batch_time = AverageMeter()
    avg_data_time = AverageMeter()
    avg_losses = AverageMeter()
    avg_origin_losses = AverageMeter()
    avg_acc = AverageMeter()

    model.to(device=context.device)
    model.train()
    # The teacher only supplies soft targets, so it stays in eval mode.
    teacher.to(device=context.device)
    teacher.eval()

    epoch_start = time.time()
    batch_end = time.time()
    for i, (image, label) in enumerate(context.train_loader):
        if context.max_iteration is not None and context.iteration >= context.max_iteration:
            break

        avg_data_time.update(time.time() - batch_end)
        image = image.to(device=context.device)

        # Clear gradients accumulated by the previous step.
        context.optimizer.zero_grad()

        if context.grad_scaler:
            # Mixed-precision path: run both forward passes and the loss
            # under autocast, then let the scaler handle loss scaling and the
            # optimizer step.
            with autocast():
                output = model(image)
                with torch.no_grad():
                    label_teacher = teacher(image)
                loss, label = _calc_loss(label, label_teacher)
            context.grad_scaler.scale(loss).backward()
            context.grad_scaler.step(context.optimizer)
            context.grad_scaler.update()
        else:
            output = model(image)
            with torch.no_grad():
                label_teacher = teacher(image)
            loss, label = _calc_loss(label, label_teacher)
            loss.backward()
            context.optimizer.step()

        avg_losses.update(loss.item(), image.size(0))
        avg_acc.update(compute_accuracy(output, label), image.size(0))
        avg_batch_time.update(time.time() - batch_end)
        batch_end = time.time()
        if i > 0 and i % context.print_freq == 0:
            # All param groups share one schedule here, so the first group's
            # learning rate is representative.
            current_lr = 0.0
            for param_group in context.optimizer.param_groups:
                current_lr = param_group['lr']
                break
            # The Loss column reports the weighted hard-label loss and the
            # weighted distillation loss separately.
            print(
                f'Epoch:{context.epoch}\t'
                f'Iter:[{i}|{len(context.train_loader)}]\t'
                f'Lr:{current_lr:.8f}\t'
                f'Time:{avg_batch_time.val:.2f}|{time.time() - epoch_start:.2f}\t'
                f'Loss:{avg_origin_losses.val:.5f}|{avg_losses.val - avg_origin_losses.val:.5f}\t'
                f'Accuracy:{avg_acc.val:.3f}'
            )

        # Warmup advances per iteration; the main scheduler takes over once
        # warmup is finished.
        if context.warmup_scheduler is not None and context.warmup_iteration > context.iteration:
            context.warmup_scheduler.step()

        context.iteration += 1

    if context.scheduler and context.warmup_iteration <= context.iteration:
        context.scheduler.step()
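

# Usage sketch (illustrative, not part of the original module): this wires up
# only the DLContext fields that train_one_epoch_distill reads, with
# `student`, `teacher` and `train_loader` as placeholders. It assumes a plain
# DLContext() leaves the remaining bookkeeping fields (iteration, grad_scaler,
# schedulers, print_freq, ...) at usable defaults.
def _example_distill_training(student, teacher, train_loader, num_epochs=10):
    context = DLContext()
    context.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    context.train_loader = train_loader
    context.criterion = nn.CrossEntropyLoss()
    context.optimizer = torch.optim.SGD(student.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    # A = 0.9 and T = 4.0 are common knowledge-distillation settings; the
    # teacher itself is passed through custom_args.
    context.custom_args = {'distill_A': 0.9, 'distill_T': 4.0, 'distill_teacher': teacher}
    for epoch in range(num_epochs):
        context.epoch = epoch
        train_one_epoch_distill(student, context)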