in main.py [0:0]
def main(local_rank, args):
    """Run evaluation or training for one process on one GPU.

    Optionally initializes distributed training, seeds the RNGs, builds the
    datasets, model, criterion and dataloaders, and then either evaluates
    (``args.test_only``) or resumes from a checkpoint and trains.

    Args:
        local_rank: Index of the CUDA device used by this process. In
            multi-GPU mode it is also passed as the global rank.
        args: Parsed run configuration. Attributes read here: ``ngpus``,
            ``dist_url``, ``seed``, ``test_only``, ``batchsize_per_gpu``,
            ``dataset_num_workers`` and ``checkpoint_dir``. ``start_epoch``
            is written before training starts.

    Raises:
        AssertionError: If training is requested and ``args.checkpoint_dir``
            is ``None``.
    """
    if args.ngpus > 1:
        print(
            "Initializing Distributed Training. This is in BETA mode and hasn't been tested thoroughly. Use at your own risk :)"
        )
        print(
            "To get the maximum speed-up consider reducing evaluations on val set by setting --eval_every_epoch to greater than 50"
        )
        # NOTE(review): the global rank is set equal to the local rank, which
        # presumably assumes a single-node job -- confirm before multi-node use.
        init_distributed(
            local_rank,
            global_rank=local_rank,
            world_size=args.ngpus,
            dist_url=args.dist_url,
            dist_backend="nccl",
        )

    print(f"Called with args: {args}")
    torch.cuda.set_device(local_rank)

    # Offset the base seed by the process rank so every process draws a
    # different random stream.
    seed = args.seed + get_rank()
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    datasets, dataset_config = build_dataset(args)
    model, _ = build_model(args, dataset_config)
    model = model.cuda(local_rank)
    # Handle taken before the DDP wrapping; it is what the optimizer and the
    # checkpoint resume logic below operate on.
    model_no_ddp = model

    if is_distributed():
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank]
        )

    criterion = build_criterion(args, dataset_config)
    criterion = criterion.cuda(local_rank)

    dataloaders = {}
    dataset_splits = ["test"] if args.test_only else ["train", "test"]
    for split in dataset_splits:
        dataset = datasets[split]
        # Only the training split is shuffled.
        shuffle = split == "train"
        if is_distributed():
            sampler = DistributedSampler(dataset, shuffle=shuffle)
        elif shuffle:
            sampler = torch.utils.data.RandomSampler(dataset)
        else:
            sampler = torch.utils.data.SequentialSampler(dataset)

        dataloaders[split] = DataLoader(
            dataset,
            sampler=sampler,
            batch_size=args.batchsize_per_gpu,
            num_workers=args.dataset_num_workers,
            worker_init_fn=my_worker_init_fn,
        )
        # The sampler is stored next to its loader under "<split>_sampler";
        # presumably so the training loop can reach it each epoch -- confirm
        # against do_train.
        dataloaders[split + "_sampler"] = sampler

    if args.test_only:
        criterion = None  # faster evaluation
        test_model(args, model, model_no_ddp, criterion, dataset_config, dataloaders)
        return

    # Explicit raise instead of `assert` so the check is not stripped when
    # Python runs with -O; the exception type and message are unchanged.
    if args.checkpoint_dir is None:
        raise AssertionError("Please specify a checkpoint dir using --checkpoint_dir")

    # Only the primary process creates the directory. exist_ok=True already
    # makes this a no-op when the directory exists.
    if is_primary():
        os.makedirs(args.checkpoint_dir, exist_ok=True)

    optimizer = build_optimizer(args, model_no_ddp)
    loaded_epoch, best_val_metrics = resume_if_possible(
        args.checkpoint_dir, model_no_ddp, optimizer
    )
    # Continue with the epoch after the one stored in the checkpoint.
    args.start_epoch = loaded_epoch + 1
    do_train(
        args,
        model,
        model_no_ddp,
        optimizer,
        criterion,
        dataset_config,
        dataloaders,
        best_val_metrics,
    )