tinynn/util/cifar10.py (242 lines of code) (raw):
import time
import typing
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.cuda.amp import autocast
from tinynn.util.train_util import AverageMeter, DLContext
def get_dataloader(
    data_path: str,
    img_size: int = 224,
    batch_size: int = 128,
    worker: int = 4,
    distributed: bool = False,
    download: bool = False,
    mean: tuple = (0.4914, 0.4822, 0.4465),
    std: tuple = (0.2023, 0.1994, 0.2010),
) -> typing.Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
    """Build the CIFAR-10 training and validation dataloaders

    Training images are randomly cropped (32x32 with a padding of 4), resized to `img_size`,
    randomly flipped horizontally, converted to tensors and normalized. Validation images are
    only resized, converted to tensors and normalized.

    Args:
        data_path (str): Root directory of the dataset
        img_size (int, optional): Target size passed to `transforms.Resize`. Defaults to 224.
        batch_size (int, optional): Batch size used by both dataloaders. Defaults to 128.
        worker (int, optional): Number of dataloader workers. Defaults to 4.
        distributed (bool, optional): When set, the training loader uses a `DistributedSampler`
            instead of shuffling. Defaults to False.
        download (bool, optional): Whether the training split may be downloaded. Defaults to False.
        mean (tuple, optional): Per-channel mean for `transforms.Normalize`
        std (tuple, optional): Per-channel std for `transforms.Normalize`

    Returns:
        typing.Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]: The training
        dataloader followed by the validation dataloader
    """

    train_steps = [
        transforms.RandomCrop(32, padding=4),
        transforms.Resize(img_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
    val_steps = [
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]

    # The training split is created first; it is the only one that honours `download`
    train_dataset = torchvision.datasets.CIFAR10(
        root=data_path,
        train=True,
        download=download,
        transform=transforms.Compose(train_steps),
    )

    # A sampler and `shuffle=True` are mutually exclusive, so shuffle only without a sampler
    train_sampler = (
        torch.utils.data.distributed.DistributedSampler(dataset=train_dataset) if distributed else None
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=(train_sampler is None),
        sampler=train_sampler,
        num_workers=worker,
        pin_memory=True,
    )

    # The validation split is never downloaded here
    val_dataset = torchvision.datasets.CIFAR10(
        root=data_path,
        train=False,
        download=False,
        transform=transforms.Compose(val_steps),
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=worker,
        pin_memory=True,
    )

    return train_loader, val_loader
def compute_accuracy(output, target):
    """Compute the top-1 accuracy (in percent) of one batch

    Args:
        output: Model scores; the predicted class is the argmax along dim 1
        target: Ground-truth class indices, compared elementwise with the predictions

    Returns:
        float: Percentage of predictions that match `target`, or 0.0 for an empty batch
    """

    batch_size = output.size(0)
    if batch_size == 0:
        # Nothing to score; avoids a ZeroDivisionError below
        return 0.0

    predictions = output.argmax(dim=1)
    correct = torch.sum(target == predictions).item()
    return correct / batch_size * 100
def train_one_epoch(model, context: DLContext):
    """Run one training epoch over `context.train_loader`

    Stops early once `context.iteration` reaches `context.max_iteration` (when it is set).
    Uses mixed precision when `context.grad_scaler` is set. The warmup scheduler and the
    per-iteration scheduler are stepped inside the loop, the per-epoch scheduler after it.

    Args:
        model: The model to be trained
        context (DLContext): The context object
    """

    device = context.device
    loader = context.train_loader

    def _calc_loss(output, label):
        """Return the loss for `output` together with `label` moved to the device"""
        if not isinstance(context.criterion, nn.BCEWithLogitsLoss):
            label = label.to(device=device)
            return context.criterion(output, label), label

        # BCEWithLogitsLoss is fed one-hot float targets (10 classes)
        label_onehot = torch.FloatTensor(label.shape[0], 10).zero_()
        label_onehot.scatter_(1, label.unsqueeze(1), 1)
        label_onehot = label_onehot.to(device=device)
        label = label.to(device=device)
        return context.criterion(output, label_onehot), label

    batch_time_meter = AverageMeter()
    data_time_meter = AverageMeter()
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    model.to(device=device)
    model.train()

    epoch_start = time.time()
    batch_end = time.time()
    for step, (image, label) in enumerate(loader):
        if context.max_iteration is not None and context.iteration >= context.max_iteration:
            break

        data_time_meter.update(time.time() - batch_end)
        image = image.to(device=device)

        context.optimizer.zero_grad()

        scaler = context.grad_scaler
        if not scaler:
            # Full-precision path
            output = model(image)
            loss, label = _calc_loss(output, label)
            loss.backward()
            context.optimizer.step()
        else:
            # Mixed-precision path: forward under autocast, backward through the scaler
            with autocast():
                output = model(image)
                loss, label = _calc_loss(output, label)
            scaler.scale(loss).backward()
            scaler.step(context.optimizer)
            scaler.update()

        num_samples = image.size(0)
        loss_meter.update(loss.item(), num_samples)
        batch_time_meter.update(time.time() - batch_end)
        acc_meter.update(compute_accuracy(output, label), num_samples)
        batch_end = time.time()

        if step > 0 and step % context.print_freq == 0:
            # Report the learning rate of the first parameter group (0.0 when there is none)
            current_lr = next((group['lr'] for group in context.optimizer.param_groups), 0.0)
            print(
                f'Epoch:{context.epoch}\t'
                f'Iter:[{step}|{len(loader)}]\t'
                f'Lr:{current_lr:.8f}\t'
                f'Time:{batch_time_meter.val:.2f}|{time.time() - epoch_start:.2f}\t'
                f'Loss:{loss_meter.val:.5f}\t'
                f'Accuracy:{acc_meter.val:.3f}'
            )

        if context.warmup_scheduler is not None and context.warmup_iteration > context.iteration:
            context.warmup_scheduler.step()

        context.iteration += 1

        # schedule per iteration
        if context.iter_scheduler and context.warmup_iteration <= context.iteration:
            context.iter_scheduler.step()

    # schedule per epoch
    if context.scheduler and context.warmup_iteration <= context.iteration:
        context.scheduler.step()
def train_one_epoch_distill(model, context: DLContext):
    """Train the student model for one epoch with knowledge distillation

    The loss of each batch is `origin_loss * (1 - A) + distill_loss * A`, where `origin_loss`
    is `context.criterion` applied to the student output and `distill_loss` is the KL divergence
    between the temperature-softened student and teacher distributions, scaled by `T * T`.
    `A`, `T` and the teacher model are read from `context.custom_args` under the keys
    `distill_A`, `distill_T` and `distill_teacher`.

    Gradients are cleared before every step. The warmup scheduler and the per-iteration
    scheduler are stepped inside the loop and the per-epoch scheduler after it, matching
    `train_one_epoch`.

    Args:
        model: Student model
        context (DLContext): The context object
    """

    def _calc_loss(output, label, label_teacher):
        """Return the combined loss together with `label` moved to the device"""
        if isinstance(context.criterion, nn.BCEWithLogitsLoss):
            # BCEWithLogitsLoss is fed one-hot float targets (10 classes)
            label_onehot = torch.FloatTensor(label.shape[0], 10).zero_()
            label_onehot.scatter_(1, label.unsqueeze(1), 1)
            label_onehot = label_onehot.to(device=context.device)
            label = label.to(device=context.device)
            origin_loss = context.criterion(output, label_onehot)
        else:
            label = label.to(device=context.device)
            origin_loss = context.criterion(output, label)

        distill_loss = (
            F.kl_div(F.log_softmax(output / T, dim=1), F.softmax(label_teacher / T, dim=1), reduction='batchmean')
            * T
            * T
        )

        weighted_origin_loss = origin_loss * (1 - A)
        # Record a plain float: storing the tensor would keep its autograd graph referenced
        # by the meter after the step has finished
        avg_origin_losses.update(weighted_origin_loss.item())

        loss = weighted_origin_loss + distill_loss * A
        return loss, label

    A = context.custom_args['distill_A']
    T = context.custom_args['distill_T']
    teacher = context.custom_args['distill_teacher']

    avg_batch_time = AverageMeter()
    avg_data_time = AverageMeter()
    avg_losses = AverageMeter()
    avg_origin_losses = AverageMeter()
    avg_acc = AverageMeter()

    model.to(device=context.device)
    model.train()

    # The teacher only provides targets, so it stays in eval mode and runs without gradients
    teacher.to(device=context.device)
    teacher.eval()

    epoch_start = time.time()
    batch_end = time.time()
    for i, (image, label) in enumerate(context.train_loader):
        if context.max_iteration is not None and context.iteration >= context.max_iteration:
            break

        avg_data_time.update(time.time() - batch_end)
        image = image.to(device=context.device)

        # Clear the gradients of the previous step; otherwise they accumulate across batches
        context.optimizer.zero_grad()

        if context.grad_scaler:
            # Mixed-precision path: forward under autocast, backward through the scaler
            with autocast():
                output = model(image)
                with torch.no_grad():
                    label_teacher = teacher(image)
                loss, label = _calc_loss(output, label, label_teacher)
            context.grad_scaler.scale(loss).backward()
            context.grad_scaler.step(context.optimizer)
            context.grad_scaler.update()
        else:
            output = model(image)
            with torch.no_grad():
                label_teacher = teacher(image)
            loss, label = _calc_loss(output, label, label_teacher)
            loss.backward()
            context.optimizer.step()

        avg_losses.update(loss.item(), image.size(0))
        avg_acc.update(compute_accuracy(output, label), image.size(0))
        avg_batch_time.update(time.time() - batch_end)
        batch_end = time.time()

        if i > 0 and i % context.print_freq == 0:
            # Report the learning rate of the first parameter group (0.0 when there is none)
            current_lr = next((group['lr'] for group in context.optimizer.param_groups), 0.0)
            print(
                f'Epoch:{context.epoch}\t'
                f'Iter:[{i}|{len(context.train_loader)}]\t'
                f'Lr:{current_lr:.8f}\t'
                f'Time:{avg_batch_time.val:.2f}|{time.time() - epoch_start:.2f}\t'
                f'Loss:{avg_origin_losses.val:.5f}|{avg_losses.val - avg_origin_losses.val:.5f}\t'
                f'Accuracy:{avg_acc.val:.3f}'
            )

        if context.warmup_scheduler is not None and context.warmup_iteration > context.iteration:
            context.warmup_scheduler.step()

        context.iteration += 1

        # schedule per iteration
        if context.iter_scheduler and context.warmup_iteration <= context.iteration:
            context.iter_scheduler.step()

    # schedule per epoch
    if context.scheduler and context.warmup_iteration <= context.iteration:
        context.scheduler.step()
def validate(model, context: DLContext) -> float:
    """Evaluate the model on `context.val_loader` and report its top-1 accuracy

    Progress is printed every 10 batches and the final accuracy once at the end.

    Args:
        model: The model to be validated
        context (DLContext): The context object

    Returns:
        float: Accuracy of the model
    """

    device = context.device
    loader = context.val_loader

    model.to(device=device)
    model.eval()

    time_meter = AverageMeter()
    acc_meter = AverageMeter()

    with torch.no_grad():
        last_tick = time.time()
        for step, (image, label) in enumerate(loader):
            image = image.to(device=device)
            label = label.to(device=device)

            output = model(image)
            acc_meter.update(compute_accuracy(output, label), image.size(0))

            # measure elapsed time
            time_meter.update(time.time() - last_tick)
            last_tick = time.time()

            if step % 10 != 0:
                continue
            print(f'Test: [{step}/{len(loader)}]\tTime {time_meter.avg:.5f}\tAcc@1 {acc_meter.avg:.5f}\t')

    print(f'Validation Acc@1 {acc_meter.avg:.3f}')
    return acc_meter.avg
def calibrate(model, context: DLContext):
    """Calibrate the fake-quantized model by running forward passes on training data

    Feeds batches of `context.train_loader` through the model in eval mode without gradients.
    Stops after `context.max_iteration` batches when it is set, and increments
    `context.iteration` once per processed batch.

    Args:
        model: The model to be calibrated
        context (DLContext): The context object
    """

    device = context.device
    loader = context.train_loader

    model.to(device=device)
    model.eval()

    time_meter = AverageMeter()

    with torch.no_grad():
        last_tick = time.time()
        for step, (image, _) in enumerate(loader):
            # The limit is checked against the batch index of this call, not `context.iteration`
            if context.max_iteration is not None and step >= context.max_iteration:
                break

            # Only the forward pass matters here; the output is discarded
            model(image.to(device=device))

            # measure elapsed time
            time_meter.update(time.time() - last_tick)
            last_tick = time.time()

            if step % 10 == 0:
                print(f'Calibrate: [{step}/{len(loader)}]\tTime {time_meter.avg:.5f}\t')

            context.iteration += 1