# scripts/train_imagenet.py
# Copyright (c) Facebook, Inc. and its affiliates.
import argparse
import logging
import os
import shutil
import time
import models
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from dataset.sampler import TestDistributedSampler
from imagenet import config, utils
from inplace_abn import ABN
from modules import SingleGPU
from tensorboardX import SummaryWriter
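
# Command-line interface. A typical multi-GPU launch (example only, assuming
# torch.distributed.launch, which sets WORLD_SIZE and passes --local_rank) is:
#   python -m torch.distributed.launch --nproc_per_node=8 \
#       scripts/train_imagenet.py config.json /path/to/imagenet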
parser = argparse.ArgumentParser(description="PyTorch ImageNet Training")
parser.add_argument("config", metavar="CONFIG_FILE", help="path to configuration file")
parser.add_argument("data", metavar="DIR", help="path to dataset")
parser.add_argument(
"-j",
"--workers",
default=2,
type=int,
metavar="N",
help="number of data loading workers (default: 2)",
)
parser.add_argument(
"--print-freq",
"-p",
default=10,
type=int,
metavar="N",
help="print frequency (default: 10)",
)
parser.add_argument(
"--resume",
default="",
type=str,
metavar="PATH",
help="path to latest checkpoint (default: none)",
)
parser.add_argument(
"-e",
"--evaluate",
dest="evaluate",
action="store_true",
help="evaluate model on validation set",
)
parser.add_argument("--local_rank", default=0, type=int, help="process rank on node")
parser.add_argument(
"--dist-backend", default="nccl", type=str, help="distributed backend"
)
parser.add_argument(
"--log-dir",
type=str,
default=".",
metavar="PATH",
help="output directory for Tensorboard log",
)
parser.add_argument(
"--log-hist", action="store_true", help="log histograms of the weights"
)
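
# Globals shared between main() and the helper functions below; they are
# populated in main() and init_logger().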
best_prec1 = 0
args = None
conf = None
tb = None
logger = None
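
# Create a per-rank file logger in log_dir; rank 0 additionally logs to the console.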
def init_logger(rank, log_dir):
global logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
handler = logging.FileHandler(os.path.join(log_dir, "training_{}.log".format(rank)))
formatter = logging.Formatter("%(asctime)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
if rank == 0:
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger.addHandler(handler)
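
# Entry point: parse arguments, set up (optional) distributed training, build the
# model, data loaders and optimizer, then run the train/validate loop.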
def main():
global args, best_prec1, logger, conf, tb
args = parser.parse_args()
torch.cuda.set_device(args.local_rank)
    try:
        world_size = int(os.environ["WORLD_SIZE"])
        distributed = world_size > 1
    except (KeyError, ValueError):
        # WORLD_SIZE is unset (or malformed) when running a single process
        distributed = False
        world_size = 1
if distributed:
dist.init_process_group(backend=args.dist_backend, init_method="env://")
rank = 0 if not distributed else dist.get_rank()
init_logger(rank, args.log_dir)
tb = SummaryWriter(args.log_dir) if rank == 0 else None
# Load configuration
conf = config.load_config(args.config)
# Create model
model_params = utils.get_model_params(conf["network"])
model = models.__dict__["net_" + conf["network"]["arch"]](**model_params)
model.cuda()
if distributed:
model = torch.nn.parallel.DistributedDataParallel(
model, device_ids=[args.local_rank], output_device=args.local_rank
)
else:
model = SingleGPU(model)
# define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss().cuda()
optimizer, scheduler = utils.create_optimizer(conf["optimizer"], model)
# optionally resume from a checkpoint
if args.resume:
if os.path.isfile(args.resume):
logger.info("=> loading checkpoint '{}'".format(args.resume))
            # Map to CPU first so every rank can restore without allocating on GPU 0
            checkpoint = torch.load(args.resume, map_location="cpu")
args.start_epoch = checkpoint["epoch"]
best_prec1 = checkpoint["best_prec1"]
model.load_state_dict(checkpoint["state_dict"])
optimizer.load_state_dict(checkpoint["optimizer"])
logger.info(
"=> loaded checkpoint '{}' (epoch {})".format(
args.resume, checkpoint["epoch"]
)
)
        else:
            logger.warning("=> no checkpoint found at '{}'".format(args.resume))
            # Fall back to training from scratch so the epoch loop still starts
            args.start_epoch = 0
    else:
        init_weights(model)
        args.start_epoch = 0
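    # Let cuDNN benchmark convolution algorithms for the fixed input size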
cudnn.benchmark = True
# Data loading code
traindir = os.path.join(args.data, "train")
valdir = os.path.join(args.data, "val")
train_transforms, val_transforms = utils.create_transforms(conf["input"])
train_dataset = datasets.ImageFolder(traindir, transforms.Compose(train_transforms))
if distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
else:
train_sampler = None
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=conf["optimizer"]["batch_size"] // world_size,
shuffle=(train_sampler is None),
num_workers=args.workers,
pin_memory=True,
sampler=train_sampler,
)
val_dataset = datasets.ImageFolder(valdir, transforms.Compose(val_transforms))
val_loader = torch.utils.data.DataLoader(
val_dataset,
batch_size=conf["optimizer"]["batch_size"] // world_size,
shuffle=False,
num_workers=args.workers,
pin_memory=True,
sampler=TestDistributedSampler(val_dataset),
)
if args.evaluate:
utils.validate(
val_loader,
model,
criterion,
print_freq=args.print_freq,
tb=tb,
logger=logger.info,
)
return
for epoch in range(args.start_epoch, conf["optimizer"]["schedule"]["epochs"]):
if distributed:
train_sampler.set_epoch(epoch)
# train for one epoch
train(train_loader, model, criterion, optimizer, scheduler, epoch)
# evaluate on validation set
prec1 = utils.validate(
val_loader,
model,
criterion,
it=epoch * len(train_loader),
print_freq=args.print_freq,
tb=tb,
logger=logger.info,
)
# remember best prec@1 and save checkpoint
is_best = prec1 > best_prec1
best_prec1 = max(prec1, best_prec1)
if rank == 0:
save_checkpoint(
{
"epoch": epoch + 1,
"arch": conf["network"]["arch"],
"state_dict": model.state_dict(),
"best_prec1": best_prec1,
"optimizer": optimizer.state_dict(),
},
is_best,
args.log_dir,
)
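
# Run one training epoch: step the LR scheduler, average loss/accuracy across
# processes, and write progress to the log and to TensorBoard (rank 0 only).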
def train(train_loader, model, criterion, optimizer, scheduler, epoch):
global logger, conf, tb
batch_time = utils.AverageMeter()
data_time = utils.AverageMeter()
losses = utils.AverageMeter()
top1 = utils.AverageMeter()
top5 = utils.AverageMeter()
if conf["optimizer"]["schedule"]["mode"] == "epoch":
scheduler.step(epoch)
# switch to train mode
model.train()
end = time.time()
for i, (input, target) in enumerate(train_loader):
if conf["optimizer"]["schedule"]["mode"] == "step":
scheduler.step(i + epoch * len(train_loader))
# measure data loading time
data_time.update(time.time() - end)
target = target.cuda(non_blocking=True)
# compute output
output = model(input)
loss = criterion(output, target)
# compute gradient and do SGD step
optimizer.zero_grad()
loss.backward()
if conf["optimizer"]["clip"] != 0.0:
nn.utils.clip_grad_norm(model.parameters(), conf["optimizer"]["clip"])
optimizer.step()
# measure accuracy and record loss
with torch.no_grad():
output = output.detach()
loss = loss.detach() * target.shape[0]
prec1, prec5 = utils.accuracy_sum(output, target, topk=(1, 5))
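            # Sum the batch statistics across processes so the meters track global averages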
count = target.new_tensor([target.shape[0]], dtype=torch.long)
if dist.is_initialized():
dist.all_reduce(count, dist.ReduceOp.SUM)
for meter, val in (losses, loss), (top1, prec1), (top5, prec5):
if dist.is_initialized():
dist.all_reduce(val, dist.ReduceOp.SUM)
val /= count.item()
meter.update(val.item(), count.item())
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if i % args.print_freq == 0:
logger.info(
"Epoch: [{0}][{1}/{2}]\t"
"Time {batch_time.val:.3f} ({batch_time.avg:.3f}) \t"
"Data {data_time.val:.3f} ({data_time.avg:.3f}) \t"
"Loss {loss.val:.4f} ({loss.avg:.4f}) \t"
"Prec@1 {top1.val:.3f} ({top1.avg:.3f}) \t"
"Prec@5 {top5.val:.3f} ({top5.avg:.3f})".format(
epoch,
i,
len(train_loader),
batch_time=batch_time,
data_time=data_time,
loss=losses,
top1=top1,
top5=top5,
)
)
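            # TensorBoard logging happens only on rank 0, which owns the writer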
if not dist.is_initialized() or dist.get_rank() == 0:
tb.add_scalar("train/loss", losses.val, i + epoch * len(train_loader))
tb.add_scalar(
"train/lr", scheduler.get_lr()[0], i + epoch * len(train_loader)
)
tb.add_scalar("train/top1", top1.val, i + epoch * len(train_loader))
tb.add_scalar("train/top5", top5.val, i + epoch * len(train_loader))
if args.log_hist and i % 10 == 0:
for name, param in model.named_parameters():
if name.find("fc") != -1 or name.find("bn_out") != -1:
tb.add_histogram(
name,
param.clone().cpu().data.numpy(),
i + epoch * len(train_loader),
)
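
# Write the latest checkpoint to log_dir and copy it to model_best.pth.tar
# whenever it improves on the best top-1 accuracy so far.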
def save_checkpoint(state, is_best, log_dir):
filepath = os.path.join(log_dir, "checkpoint.pth.tar")
torch.save(state, filepath)
if is_best:
shutil.copyfile(filepath, os.path.join(log_dir, "model_best.pth.tar"))
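
# Initialise Conv2d, BatchNorm/ABN and Linear layers according to the
# weight_init / activation settings in the network configuration.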
def init_weights(model):
global conf
for name, m in model.named_modules():
if isinstance(m, nn.Conv2d):
init_fn = getattr(nn.init, conf["network"]["weight_init"] + "_")
if (
conf["network"]["weight_init"].startswith("xavier")
or conf["network"]["weight_init"] == "orthogonal"
):
gain = conf["network"]["weight_gain_multiplier"]
if (
conf["network"]["activation"] == "relu"
or conf["network"]["activation"] == "elu"
):
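                    # nn.init.calculate_gain has no ELU entry, so the ReLU gain is used here as well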
gain *= nn.init.calculate_gain("relu")
elif conf["network"]["activation"] == "leaky_relu":
gain *= nn.init.calculate_gain(
"leaky_relu", conf["network"]["activation_param"]
)
init_fn(m.weight, gain)
elif conf["network"]["weight_init"].startswith("kaiming"):
if (
conf["network"]["activation"] == "relu"
or conf["network"]["activation"] == "elu"
):
init_fn(m.weight, 0)
else:
init_fn(m.weight, conf["network"]["activation_param"])
if hasattr(m, "bias") and m.bias is not None:
nn.init.constant_(m.bias, 0.0)
        elif isinstance(m, (nn.BatchNorm2d, ABN)):
nn.init.constant_(m.weight, 1.0)
nn.init.constant_(m.bias, 0.0)
elif isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight, 0.1)
nn.init.constant_(m.bias, 0.0)
if __name__ == "__main__":
main()