tinynn/util/cifar10.py (242 lines of code) (raw):

import time
import typing

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.cuda.amp import autocast

from tinynn.util.train_util import AverageMeter, DLContext


def get_dataloader(
    data_path: str,
    img_size: int = 224,
    batch_size: int = 128,
    worker: int = 4,
    distributed: bool = False,
    download: bool = False,
    mean: tuple = (0.4914, 0.4822, 0.4465),
    std: tuple = (0.2023, 0.1994, 0.2010),
) -> typing.Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
    """Constructs the dataloaders for training and validating

    Args:
        data_path (str): The path of the dataset
        img_size (int, optional): The size of the image. Defaults to 224.
        batch_size (int, optional): The batch size of the dataloader. Defaults to 128.
        worker (int, optional): The number of workers. Defaults to 4.
        distributed (bool, optional): Whether to use DDP. Defaults to False.
        download (bool, optional): Whether to download the dataset. Defaults to False.
        mean (tuple, optional): Normalize mean
        std (tuple, optional): Normalize std

    Returns:
        typing.Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]: The dataloaders for training and \
            validating
    """

    train_dataset = torchvision.datasets.CIFAR10(
        root=data_path,
        train=True,
        download=download,
        transform=transforms.Compose(
            [
                # Augmentations are applied on the 32x32 CIFAR images before resizing
                transforms.RandomCrop(32, padding=4),
                transforms.Resize(img_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
        ),
    )

    if distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset=train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        # DistributedSampler already shuffles; shuffling must be off when a sampler is given
        shuffle=(train_sampler is None),
        sampler=train_sampler,
        num_workers=worker,
        pin_memory=True,
    )

    val_dataset = torchvision.datasets.CIFAR10(
        root=data_path,
        train=False,
        download=False,
        transform=transforms.Compose(
            [transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize(mean, std)]
        ),
    )

    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=worker, pin_memory=True
    )

    return train_loader, val_loader


def compute_accuracy(output, target):
    """Returns the top-1 accuracy (percentage in [0, 100]) of `output` logits w.r.t. `target` labels"""
    output = output.argmax(dim=1)
    acc = torch.sum(target == output).item()
    acc = acc / output.size(0) * 100
    return acc


def train_one_epoch(model, context: DLContext):
    """Train the model for one epoch

    Args:
        model: The model to be trained
        context (DLContext): The context object
    """

    def _calc_loss(label):
        # BCEWithLogitsLoss needs a one-hot target; other criteria take class indices directly.
        if isinstance(context.criterion, nn.BCEWithLogitsLoss):
            label.unsqueeze_(1)
            label_onehot = torch.FloatTensor(label.shape[0], 10)
            label_onehot.zero_()
            label_onehot.scatter_(1, label, 1)
            label.squeeze_(1)
            label_onehot = label_onehot.to(device=context.device)
            label = label.to(device=context.device)
            loss = context.criterion(output, label_onehot)
        else:
            label = label.to(device=context.device)
            loss = context.criterion(output, label)
        return loss, label

    avg_batch_time = AverageMeter()
    avg_data_time = AverageMeter()
    avg_losses = AverageMeter()
    avg_acc = AverageMeter()

    model.to(device=context.device)
    model.train()

    epoch_start = time.time()
    batch_end = time.time()
    for i, (image, label) in enumerate(context.train_loader):
        if context.max_iteration is not None and context.iteration >= context.max_iteration:
            break

        avg_data_time.update(time.time() - batch_end)
        image = image.to(device=context.device)

        context.optimizer.zero_grad()

        if context.grad_scaler:
            # Mixed-precision path: scale the loss to avoid fp16 gradient underflow
            with autocast():
                output = model(image)
                loss, label = _calc_loss(label)
            context.grad_scaler.scale(loss).backward()
            context.grad_scaler.step(context.optimizer)
            context.grad_scaler.update()
        else:
            output = model(image)
            loss, label = _calc_loss(label)
            loss.backward()
            context.optimizer.step()

        avg_losses.update(loss.item(), image.size(0))
        avg_batch_time.update(time.time() - batch_end)
        avg_acc.update(compute_accuracy(output, label), image.size(0))
        batch_end = time.time()

        if i > 0 and i % context.print_freq == 0:
            current_lr = 0.0
            for param_group in context.optimizer.param_groups:
                current_lr = param_group['lr']
                break

            print(
                f'Epoch:{context.epoch}\t'
                f'Iter:[{i}|{len(context.train_loader)}]\t'
                f'Lr:{current_lr:.8f}\t'
                f'Time:{avg_batch_time.val:.2f}|{time.time() - epoch_start:.2f}\t'
                f'Loss:{avg_losses.val:.5f}\t'
                f'Accuracy:{avg_acc.val:.3f}'
            )

        # Warmup scheduler steps every iteration until warmup_iteration is reached
        if context.warmup_scheduler is not None and context.warmup_iteration > context.iteration:
            context.warmup_scheduler.step()

        context.iteration += 1

        # schedule per iteration
        if context.iter_scheduler and context.warmup_iteration <= context.iteration:
            context.iter_scheduler.step()

    # schedule per epoch
    if context.scheduler and context.warmup_iteration <= context.iteration:
        context.scheduler.step()


def train_one_epoch_distill(model, context: DLContext):
    """Train the model for one epoch with distilling

    Args:
        model: Student model
        context (DLContext): The context object. `context.custom_args` must provide
            `distill_A` (weight of the distillation loss), `distill_T` (softmax
            temperature) and `distill_teacher` (the teacher model).
    """

    def _calc_loss(label, label_teacher):
        # BCEWithLogitsLoss needs a one-hot target; other criteria take class indices directly.
        if isinstance(context.criterion, nn.BCEWithLogitsLoss):
            label.unsqueeze_(1)
            label_onehot = torch.FloatTensor(label.shape[0], 10)
            label_onehot.zero_()
            label_onehot.scatter_(1, label, 1)
            label.squeeze_(1)
            label_onehot = label_onehot.to(device=context.device)
            label = label.to(device=context.device)
            origin_loss = context.criterion(output, label_onehot)
        else:
            label = label.to(device=context.device)
            origin_loss = context.criterion(output, label)

        # Hinton-style distillation: KL between temperature-softened student and
        # teacher distributions, scaled by T^2 to keep gradient magnitudes comparable.
        distill_loss = (
            F.kl_div(F.log_softmax(output / T, dim=1), F.softmax(label_teacher / T, dim=1), reduction='batchmean')
            * T
            * T
        )
        # FIX: record a detached float, not the live tensor — storing the tensor kept
        # every batch's autograd graph alive for the whole epoch.
        avg_origin_losses.update((origin_loss * (1 - A)).item())
        loss = origin_loss * (1 - A) + distill_loss * A
        return loss, label

    A = context.custom_args['distill_A']
    T = context.custom_args['distill_T']
    teacher = context.custom_args['distill_teacher']

    avg_batch_time = AverageMeter()
    avg_data_time = AverageMeter()
    avg_losses = AverageMeter()
    avg_origin_losses = AverageMeter()
    avg_acc = AverageMeter()

    model.to(device=context.device)
    model.train()

    teacher.to(device=context.device)
    teacher.eval()

    epoch_start = time.time()
    batch_end = time.time()
    for i, (image, label) in enumerate(context.train_loader):
        if context.max_iteration is not None and context.iteration >= context.max_iteration:
            break

        avg_data_time.update(time.time() - batch_end)
        image = image.to(device=context.device)

        # FIX: gradients were never cleared here (unlike `train_one_epoch`), so they
        # accumulated across iterations and corrupted the update direction.
        context.optimizer.zero_grad()

        if context.grad_scaler:
            # Mixed-precision path: scale the loss to avoid fp16 gradient underflow
            with autocast():
                output = model(image)
                with torch.no_grad():
                    label_teacher = teacher(image)
                loss, label = _calc_loss(label, label_teacher)
            context.grad_scaler.scale(loss).backward()
            context.grad_scaler.step(context.optimizer)
            context.grad_scaler.update()
        else:
            output = model(image)
            with torch.no_grad():
                label_teacher = teacher(image)
            loss, label = _calc_loss(label, label_teacher)
            loss.backward()
            context.optimizer.step()

        avg_losses.update(loss.item(), image.size(0))
        avg_acc.update(compute_accuracy(output, label), image.size(0))
        avg_batch_time.update(time.time() - batch_end)
        batch_end = time.time()

        if i > 0 and i % context.print_freq == 0:
            current_lr = 0.0
            for param_group in context.optimizer.param_groups:
                current_lr = param_group['lr']
                break

            print(
                f'Epoch:{context.epoch}\t'
                f'Iter:[{i}|{len(context.train_loader)}]\t'
                f'Lr:{current_lr:.8f}\t'
                f'Time:{avg_batch_time.val:.2f}|{time.time() - epoch_start:.2f}\t'
                f'Loss:{avg_origin_losses.val:.5f}|{avg_losses.val - avg_origin_losses.val:.5f}\t'
                f'Accuracy:{avg_acc.val:.3f}'
            )

        # Warmup scheduler steps every iteration until warmup_iteration is reached
        if context.warmup_scheduler is not None and context.warmup_iteration > context.iteration:
            context.warmup_scheduler.step()

        context.iteration += 1

    # schedule per epoch
    # NOTE(review): unlike `train_one_epoch`, no per-iteration `iter_scheduler` step is
    # performed here — presumably intentional for distillation; confirm before relying on it.
    if context.scheduler and context.warmup_iteration <= context.iteration:
        context.scheduler.step()


def validate(model, context: DLContext) -> float:
    """Retrieves the accuracy the model via validation

    Args:
        model: The model to be validated
        context (DLContext): The context object

    Returns:
        float: Accuracy of the model
    """

    model.to(device=context.device)
    model.eval()

    avg_batch_time = AverageMeter()
    avg_acc = AverageMeter()

    with torch.no_grad():
        end = time.time()
        for i, (image, label) in enumerate(context.val_loader):
            image = image.to(device=context.device)
            label = label.to(device=context.device)

            output = model(image)
            avg_acc.update(compute_accuracy(output, label), image.size(0))

            # measure elapsed time
            avg_batch_time.update(time.time() - end)
            end = time.time()

            if i % 10 == 0:
                print(
                    f'Test: [{i}/{len(context.val_loader)}]\tTime {avg_batch_time.avg:.5f}\tAcc@1 {avg_acc.avg:.5f}\t'
                )

    print(f'Validation Acc@1 {avg_acc.avg:.3f}')
    return avg_acc.avg


def calibrate(model, context: DLContext):
    """Calibrates the fake-quantized model

    Args:
        model: The model to be validated
        context (DLContext): The context object
    """

    model.to(device=context.device)
    model.eval()

    avg_batch_time = AverageMeter()

    with torch.no_grad():
        end = time.time()
        for i, (image, _) in enumerate(context.train_loader):
            if context.max_iteration is not None and i >= context.max_iteration:
                break

            image = image.to(device=context.device)
            # Forward pass only: observers collect activation statistics as a side effect
            model(image)

            # measure elapsed time
            avg_batch_time.update(time.time() - end)
            end = time.time()

            if i % 10 == 0:
                print(f'Calibrate: [{i}/{len(context.train_loader)}]\tTime {avg_batch_time.avg:.5f}\t')

            context.iteration += 1