in main_swav.py [0:0]
def main():
    """Entry point: set up distributed training and run the epoch loop.

    Parses the command line into the module-level ``args``, then builds the
    multi-crop dataset with its distributed loader, the model, the
    LARC-wrapped SGD optimizer and the per-iteration learning-rate schedule.
    ``restart_from_checkpoint`` is then called on
    ``<dump_path>/checkpoint.pth.tar`` (presumably a no-op when the file is
    absent -- confirm against that helper), and ``train`` is invoked once per
    remaining epoch. After each epoch rank 0 writes a checkpoint and every
    rank writes its own queue file.
    """
    global args
    args = parser.parse_args()
    init_distributed_mode(args)
    fix_random_seeds(args.seed)
    logger, training_stats = initialize_exp(args, "epoch", "loss")

    # build data
    train_dataset = MultiCropDataset(
        args.data_path,
        args.size_crops,
        args.nmb_crops,
        args.min_scale_crops,
        args.max_scale_crops,
    )
    sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True,
    )
    logger.info("Building data done with {} images loaded.".format(len(train_dataset)))

    # build model
    model = _build_model(args)
    if args.rank == 0:
        # the model dump is only logged once, by the first process
        logger.info(model)
    logger.info("Building model done.")

    # build optimizer
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.base_lr,
        momentum=0.9,
        weight_decay=args.wd,
    )
    optimizer = LARC(optimizer=optimizer, trust_coefficient=0.001, clip=False)
    lr_schedule = _build_lr_schedule(args, len(train_loader))
    logger.info("Building optimizer done.")

    # init mixed precision
    if args.use_fp16:
        model, optimizer = apex.amp.initialize(model, optimizer, opt_level="O1")
        logger.info("Initializing mixed precision done.")

    # wrap model
    model = nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.gpu_to_work_on],
    )

    # optionally resume from a checkpoint; the starting epoch is read back
    # from ``to_restore`` after the call and stays 0 if it was not modified
    to_restore = {"epoch": 0}
    restart_from_checkpoint(
        os.path.join(args.dump_path, "checkpoint.pth.tar"),
        run_variables=to_restore,
        state_dict=model,
        optimizer=optimizer,
        amp=apex.amp,
    )
    start_epoch = to_restore["epoch"]

    # build the queue: each rank keeps its own file, keyed by its rank
    queue_path = os.path.join(args.dump_path, "queue" + str(args.rank) + ".pth")
    queue = _load_queue(queue_path)
    # round the queue length down to a multiple of the global batch size
    # (batch_size * world_size)
    args.queue_length -= args.queue_length % (args.batch_size * args.world_size)

    cudnn.benchmark = True

    for epoch in range(start_epoch, args.epochs):
        # train the network for one epoch
        logger.info("============ Starting epoch %i ... ============" % epoch)

        # set sampler
        train_loader.sampler.set_epoch(epoch)

        # optionally starts a queue: allocated once, the first time an epoch
        # at or past ``epoch_queue_starts`` is reached without a restored queue
        if args.queue_length > 0 and epoch >= args.epoch_queue_starts and queue is None:
            queue = torch.zeros(
                len(args.crops_for_assign),
                args.queue_length // args.world_size,
                args.feat_dim,
            ).cuda()

        # train the network
        scores, queue = train(train_loader, model, optimizer, epoch, lr_schedule, queue)
        training_stats.update(scores)

        # save checkpoints (rank 0 only)
        if args.rank == 0:
            _save_checkpoint(args, model, optimizer, epoch)
        # every rank persists its own queue
        if queue is not None:
            torch.save({"queue": queue}, queue_path)


def _build_model(args):
    """Instantiate ``args.arch``, convert its batch-norm layers if requested, and move it to the GPU."""
    model = resnet_models.__dict__[args.arch](
        normalize=True,
        hidden_mlp=args.hidden_mlp,
        output_dim=args.feat_dim,
        nmb_prototypes=args.nmb_prototypes,
    )
    # synchronize batch norm layers; any other ``sync_bn`` value leaves the
    # model's batch-norm layers unchanged
    if args.sync_bn == "pytorch":
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    elif args.sync_bn == "apex":
        # with apex syncbn we sync bn per group because it speeds up computation
        # compared to global syncbn
        process_group = apex.parallel.create_syncbn_process_group(args.syncbn_process_group_size)
        model = apex.parallel.convert_syncbn_model(model, process_group=process_group)
    # copy model to GPU
    return model.cuda()


def _build_lr_schedule(args, iters_per_epoch):
    """Return one learning rate per iteration: linear warmup followed by cosine decay to ``final_lr``."""
    warmup_iters = iters_per_epoch * args.warmup_epochs
    decay_iters = iters_per_epoch * (args.epochs - args.warmup_epochs)
    warmup_lr_schedule = np.linspace(args.start_warmup, args.base_lr, warmup_iters)
    # half-cosine from base_lr (t = 0) towards final_lr over the decay iterations
    cosine_lr_schedule = np.array([
        args.final_lr
        + 0.5 * (args.base_lr - args.final_lr) * (1 + math.cos(math.pi * t / decay_iters))
        for t in np.arange(decay_iters)
    ])
    return np.concatenate((warmup_lr_schedule, cosine_lr_schedule))


def _load_queue(queue_path):
    """Return the queue stored at ``queue_path`` on the current GPU, or ``None`` if the file is missing."""
    if not os.path.isfile(queue_path):
        return None
    # Deserialize on CPU and then move to the current device, exactly like a
    # freshly allocated queue. Without ``map_location`` the tensor would be
    # restored onto the device index it was saved from, which may not exist or
    # may not be this process's device after a restart with a different GPU
    # layout.
    return torch.load(queue_path, map_location="cpu")["queue"].cuda()


def _save_checkpoint(args, model, optimizer, epoch):
    """Write the rolling checkpoint and, on selected epochs, copy it into ``args.dump_checkpoints``."""
    checkpoint_path = os.path.join(args.dump_path, "checkpoint.pth.tar")
    save_dict = {
        # stored as the next epoch to run
        "epoch": epoch + 1,
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    if args.use_fp16:
        save_dict["amp"] = apex.amp.state_dict()
    torch.save(save_dict, checkpoint_path)
    # keep a copy every ``checkpoint_freq`` epochs and for the last epoch
    if epoch % args.checkpoint_freq == 0 or epoch == args.epochs - 1:
        shutil.copyfile(
            checkpoint_path,
            os.path.join(args.dump_checkpoints, "ckp-" + str(epoch) + ".pth"),
        )