in scripts/train_imagenet.py [0:0]
def main():
    global args, best_prec1, logger, conf, tb
    args = parser.parse_args()
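    # Pin this process to its GPU and detect a distributed launch via the
    # WORLD_SIZE environment variable set by the launcher (e.g. torch.distributed.launch).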
    torch.cuda.set_device(args.local_rank)
    try:
        world_size = int(os.environ["WORLD_SIZE"])
        distributed = world_size > 1
    except (KeyError, ValueError):
        distributed = False
        world_size = 1
    if distributed:
        dist.init_process_group(backend=args.dist_backend, init_method="env://")
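    # Derive the process rank; TensorBoard summaries are written by rank 0 only.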
    rank = 0 if not distributed else dist.get_rank()
    init_logger(rank, args.log_dir)
    tb = SummaryWriter(args.log_dir) if rank == 0 else None
    # Load configuration
    conf = config.load_config(args.config)
    # Create model
    model_params = utils.get_model_params(conf["network"])
    model = models.__dict__["net_" + conf["network"]["arch"]](**model_params)
    model.cuda()
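    # Wrap the model: DistributedDataParallel when running multi-process,
    # otherwise the plain single-GPU wrapper (SingleGPU) used by this script.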
    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank
        )
    else:
        model = SingleGPU(model)
    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer, scheduler = utils.create_optimizer(conf["optimizer"], model)
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint["epoch"]
            best_prec1 = checkpoint["best_prec1"]
            model.load_state_dict(checkpoint["state_dict"])
            optimizer.load_state_dict(checkpoint["optimizer"])
            logger.info(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint["epoch"]
                )
            )
        else:
            logger.warning("=> no checkpoint found at '{}'".format(args.resume))
    else:
        init_weights(model)
        args.start_epoch = 0
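    # Let cuDNN benchmark and cache the fastest convolution algorithms,
    # which helps when input sizes are fixed across iterations.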
    cudnn.benchmark = True
    # Data loading code
    traindir = os.path.join(args.data, "train")
    valdir = os.path.join(args.data, "val")
    train_transforms, val_transforms = utils.create_transforms(conf["input"])
    train_dataset = datasets.ImageFolder(traindir, transforms.Compose(train_transforms))
    if distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None
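    # The configured batch size is global: each process loads batch_size // world_size samples.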
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=conf["optimizer"]["batch_size"] // world_size,
        shuffle=(train_sampler is None),
        num_workers=args.workers,
        pin_memory=True,
        sampler=train_sampler,
    )
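    # Validation is sharded across processes by the custom TestDistributedSampler.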
    val_dataset = datasets.ImageFolder(valdir, transforms.Compose(val_transforms))
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=conf["optimizer"]["batch_size"] // world_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
        sampler=TestDistributedSampler(val_dataset),
    )
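    # Evaluation-only mode: run a single validation pass and exit.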
    if args.evaluate:
        utils.validate(
            val_loader,
            model,
            criterion,
            print_freq=args.print_freq,
            tb=tb,
            logger=logger.info,
        )
        return
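    # Main training loop: one pass over the training set per epoch, followed by validation.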
    for epoch in range(args.start_epoch, conf["optimizer"]["schedule"]["epochs"]):
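        # Reseed the distributed sampler so each epoch sees a different shuffle.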
        if distributed:
            train_sampler.set_epoch(epoch)
        # train for one epoch
        train(train_loader, model, criterion, optimizer, scheduler, epoch)
        # evaluate on validation set
        prec1 = utils.validate(
            val_loader,
            model,
            criterion,
            it=epoch * len(train_loader),
            print_freq=args.print_freq,
            tb=tb,
            logger=logger.info,
        )
        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
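        # Checkpoints are written by rank 0 only.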
        if rank == 0:
            save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "arch": conf["network"]["arch"],
                    "state_dict": model.state_dict(),
                    "best_prec1": best_prec1,
                    "optimizer": optimizer.state_dict(),
                },
                is_best,
                args.log_dir,
            )