# references/classification/train_quantization.py
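# NOTE: the module-level imports below are reproduced/assumed from the rest of this
# file so that the function reads self-contained; in particular, the exact import
# path for the prototype models namespace (PM) is an assumption and may differ.
import copy
import datetime
import os
import time

import torch
import torch.ao.quantization
import torch.utils.data
import torchvision
import utils  # local helper module in references/classification (mkdir, distributed setup, ...)
from torch import nn
from train import evaluate, load_data, train_one_epoch  # local training/eval helpers

try:
    from torchvision.prototype import models as PM  # assumed nightly-only prototype namespace
except ImportError:
    PM = None
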
def main(args):
    if args.weights and PM is None:
        raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    if args.post_training_quantize and args.distributed:
        raise RuntimeError("Post training quantization example should not be performed in distributed mode")

    # Set backend engine to ensure that quantized model runs on the correct kernels
    if args.backend not in torch.backends.quantized.supported_engines:
        raise RuntimeError("Quantized backend not supported: " + str(args.backend))
    torch.backends.quantized.engine = args.backend

    device = torch.device(args.device)
    torch.backends.cudnn.benchmark = True
    # Data loading code
    print("Loading data")
    train_dir = os.path.join(args.data_path, "train")
    val_dir = os.path.join(args.data_path, "val")

    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True
    )

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.eval_batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True
    )

    print("Creating model", args.model)
    # when training quantized models, we always start from a pre-trained fp32 reference model
    if not args.weights:
        model = torchvision.models.quantization.__dict__[args.model](pretrained=True, quantize=args.test_only)
    else:
        model = PM.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
    model.to(device)
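    # QAT preparation (skipped for --test-only and post-training quantization):
    # fuse Conv/BN/ReLU modules, attach the backend's default QAT qconfig, and let
    # prepare_qat() insert fake-quantization and observer modules so training
    # simulates int8 inference.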
    if not (args.test_only or args.post_training_quantize):
        model.fuse_model()
        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
        torch.ao.quantization.prepare_qat(model, inplace=True)

        if args.distributed and args.sync_bn:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

        optimizer = torch.optim.SGD(
            model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
        )

        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)

    criterion = nn.CrossEntropyLoss()
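    # Keep a handle to the bare (non-DDP-wrapped) module: it is what gets checkpointed
    # and what is deep-copied for int8 conversion during evaluation.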
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1

    if args.post_training_quantize:
        # perform calibration on a small subset of the training dataset
        ds = torch.utils.data.Subset(dataset, indices=list(range(args.batch_size * args.num_calibration_batches)))
        data_loader_calibration = torch.utils.data.DataLoader(
            ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True
        )
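        # Post-training static quantization: switch to eval mode, fuse modules, attach
        # the backend's default static qconfig, and let prepare() insert observers.
        # Running the calibration loader through evaluate() records activation ranges,
        # after which convert() swaps the modules for real int8 kernels.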
        model.eval()
        model.fuse_model()
        model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
        torch.ao.quantization.prepare(model, inplace=True)
        # Calibrate first
        print("Calibrating")
        evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
        torch.ao.quantization.convert(model, inplace=True)
        if args.output_dir:
            print("Saving quantized model")
            if utils.is_main_process():
                torch.save(model.state_dict(), os.path.join(args.output_dir, "quantized_post_train_model.pth"))
        print("Evaluating post-training quantized model")
        evaluate(model, criterion, data_loader_test, device=device)
        return

    if args.test_only:
        evaluate(model, criterion, data_loader_test, device=device)
        return
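    # Quantization-aware training: re-enable observers (activation/weight statistics)
    # and fake-quant modules (simulated int8 rounding) on every submodule before the
    # training loop starts.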
    model.apply(torch.ao.quantization.enable_observer)
    model.apply(torch.ao.quantization.enable_fake_quant)
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        print("Starting training for epoch", epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args)
        lr_scheduler.step()
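        # After the configured number of epochs, freeze the observed quantization
        # ranges and the batch-norm statistics so the fake-quant parameters stop
        # drifting, then evaluate both the QAT model and a converted int8 copy
        # (conversion runs on a CPU deep copy of the unwrapped model).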
        with torch.inference_mode():
            if epoch >= args.num_observer_update_epochs:
                print("Disabling observer for subsequent epochs, epoch = ", epoch)
                model.apply(torch.ao.quantization.disable_observer)
            if epoch >= args.num_batch_norm_update_epochs:
                print("Freezing BN for subsequent epochs, epoch = ", epoch)
                model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
            print("Evaluate QAT model")
            evaluate(model, criterion, data_loader_test, device=device, log_suffix="QAT")
            quantized_eval_model = copy.deepcopy(model_without_ddp)
            quantized_eval_model.eval()
            quantized_eval_model.to(torch.device("cpu"))
            torch.ao.quantization.convert(quantized_eval_model, inplace=True)
            print("Evaluate Quantized model")
            evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu"))

        model.train()
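        # Checkpoint both the QAT state (with fake-quant/observer buffers) and the
        # converted int8 evaluation model from this epoch.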
        if args.output_dir:
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "eval_model": quantized_eval_model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
            print("Saving models after epoch ", epoch)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")