def train_one_epoch_distill()

in tinynn/util/cifar10.py


def train_one_epoch_distill(model, context: DLContext):
    """Train the model for one epoch with distilling

    Args:
        model: Student model
        context (DLContext): The context object
    """

    def _calc_loss(label, label_teacher):
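        # Combine the hard-label loss with the soft distillation loss.
        # `output` (the student logits) is taken from the enclosing scope.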
        if isinstance(context.criterion, nn.BCEWithLogitsLoss):
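            # BCEWithLogitsLoss expects one-hot targets, so scatter the class
            # indices into a 10-way (CIFAR-10) one-hot tensor first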
            label.unsqueeze_(1)
            label_onehot = torch.FloatTensor(label.shape[0], 10)
            label_onehot.zero_()
            label_onehot.scatter_(1, label, 1)
            label.squeeze_(1)
            label_onehot = label_onehot.to(device=context.device)
            label = label.to(device=context.device)
            origin_loss = context.criterion(output, label_onehot)
        else:
            label = label.to(device=context.device)
            origin_loss = context.criterion(output, label)

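        # Soft-label term: KL divergence between the temperature-softened student
        # and teacher distributions, rescaled by T * T so the gradient magnitude of
        # the softened term stays comparable to the hard-label loss (Hinton et al.,
        # "Distilling the Knowledge in a Neural Network")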
        distill_loss = (
            F.kl_div(F.log_softmax(output / T, dim=1), F.softmax(label_teacher / T, dim=1), reduction='batchmean')
            * T
            * T
        )

        # Record the weighted hard-label loss as a plain float so the meter does not retain the autograd graph
        avg_origin_losses.update(origin_loss.item() * (1 - A))
        loss = origin_loss * (1 - A) + distill_loss * A

        return loss, label

    A = context.custom_args['distill_A']  # weight of the soft (distillation) loss
    T = context.custom_args['distill_T']  # softmax temperature applied to both logit distributions
    teacher = context.custom_args['distill_teacher']  # pretrained teacher model providing the soft targets

    avg_batch_time = AverageMeter()
    avg_data_time = AverageMeter()
    avg_losses = AverageMeter()
    avg_origin_losses = AverageMeter()
    avg_acc = AverageMeter()

    model.to(device=context.device)
    model.train()

    teacher.to(device=context.device)
    teacher.eval()

    epoch_start = time.time()
    batch_end = time.time()
    for i, (image, label) in enumerate(context.train_loader):

        if context.max_iteration is not None and context.iteration >= context.max_iteration:
            break

        avg_data_time.update(time.time() - batch_end)
        image = image.to(device=context.device)

        # Clear gradients accumulated in the previous iteration
        context.optimizer.zero_grad()

        if context.grad_scaler:
            # Mixed-precision path: run both forward passes under autocast, then
            # perform the scaled backward pass and optimizer step outside of it
            with autocast():
                output = model(image)
                with torch.no_grad():
                    label_teacher = teacher(image)
                loss, label = _calc_loss(label, label_teacher)

            context.grad_scaler.scale(loss).backward()
            context.grad_scaler.step(context.optimizer)
            context.grad_scaler.update()
        else:
            # Full-precision path; the teacher only provides soft targets,
            # so its forward pass runs without gradient tracking
            output = model(image)
            with torch.no_grad():
                label_teacher = teacher(image)
            loss, label = _calc_loss(label, label_teacher)
            loss.backward()
            context.optimizer.step()

        avg_losses.update(loss.item(), image.size(0))
        avg_acc.update(compute_accuracy(output, label), image.size(0))
        avg_batch_time.update(time.time() - batch_end)
        batch_end = time.time()

        if i > 0 and i % context.print_freq == 0:
            current_lr = 0.0
            for param_group in context.optimizer.param_groups:
                current_lr = param_group['lr']
                break
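            # The log line splits the running loss into its weighted hard-label part
            # (avg_origin_losses) and the weighted distillation part (the remainder)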
            print(
                f'Epoch:{context.epoch}\t'
                f'Iter:[{i}|{len(context.train_loader)}]\t'
                f'Lr:{current_lr:.8f}\t'
                f'Time:{avg_batch_time.val:.2f}|{time.time() - epoch_start:.2f}\t'
                f'Loss:{avg_origin_losses.val:.5f}|{avg_losses.val - avg_origin_losses.val:.5f}\t'
                f'Accuracy:{avg_acc.val:.3f}'
            )

        if context.warmup_scheduler is not None and context.warmup_iteration > context.iteration:
            context.warmup_scheduler.step()

        context.iteration += 1

    if context.scheduler and context.warmup_iteration <= context.iteration:
        context.scheduler.step()
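
Example usage: the distillation weight, temperature and teacher model are all passed in through context.custom_args. The snippet below is a minimal, self-contained sketch of how the function might be driven; the DLContext import path (tinynn.util.train_util), its no-argument constructor, the toy models and data, and the concrete hyper-parameter values (A = 0.9, T = 4.0) are illustrative assumptions, not taken from this file.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from tinynn.util.train_util import DLContext              # assumed import path
from tinynn.util.cifar10 import train_one_epoch_distill

# Toy stand-ins for the real CIFAR-10 student/teacher and data loader
student = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
teacher = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))  # normally a pretrained model

images = torch.randn(256, 3, 32, 32)
labels = torch.randint(0, 10, (256,))
loader = DataLoader(TensorDataset(images, labels), batch_size=32, shuffle=True)

context = DLContext()
context.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
context.train_loader = loader
context.criterion = nn.CrossEntropyLoss()
context.optimizer = torch.optim.SGD(student.parameters(), lr=0.1, momentum=0.9)
context.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(context.optimizer, T_max=10)
context.warmup_scheduler = None
context.warmup_iteration = 0
context.grad_scaler = None
context.max_iteration = None
context.iteration = 0
context.print_freq = 4
context.custom_args = {
    'distill_A': 0.9,            # weight of the soft (distillation) loss
    'distill_T': 4.0,            # softmax temperature
    'distill_teacher': teacher,
}

for epoch in range(10):
    context.epoch = epoch
    train_one_epoch_distill(student, context)

With distill_A set to 0 the soft term drops out of the total loss and the loop reduces to ordinary supervised training; the T * T factor applied in _calc_loss compensates for the 1/T² scaling of the softened gradients, keeping the two loss terms on a comparable scale.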