in main.py [0:0]
def main_worker(gpu, args):
    """Run distributed Barlow Twins training for one process on one GPU.

    Joins the NCCL process group, builds the model and LARS optimizer,
    resumes from ``checkpoint.pth`` when that file exists, and then trains
    with automatic mixed precision. The rank-0 process additionally logs
    statistics (stdout and ``stats.txt``), writes a checkpoint after every
    epoch, and finally saves the backbone weights to ``resnet50.pth``.

    Args:
        gpu: Index of the local CUDA device used by this process. It is also
            added to ``args.rank`` to form this process's global rank.
        args: Run configuration. This function reads ``rank``, ``dist_url``,
            ``world_size``, ``checkpoint_dir``, ``weight_decay``, ``data``,
            ``batch_size``, ``workers``, ``epochs`` and ``print_freq``.
            ``args.rank`` is modified in place.

    Raises:
        ValueError: If ``args.batch_size`` is not divisible by
            ``args.world_size``.
    """
    # args.rank arrives as a base rank; adding the local GPU index makes it
    # the global rank of this process.
    args.rank += gpu
    torch.distributed.init_process_group(
        backend='nccl', init_method=args.dist_url,
        world_size=args.world_size, rank=args.rank)

    is_main = args.rank == 0
    checkpoint_path = args.checkpoint_dir / 'checkpoint.pth'

    # Only rank 0 opens the stats file; the try/finally guarantees the handle
    # is closed even when training raises part-way through.
    stats_file = None
    try:
        if is_main:
            args.checkpoint_dir.mkdir(parents=True, exist_ok=True)
            # buffering=1 -> line buffered, so each record is flushed as written.
            stats_file = open(args.checkpoint_dir / 'stats.txt', 'a', buffering=1)
            command_line = ' '.join(sys.argv)
            print(command_line)
            print(command_line, file=stats_file)

        torch.cuda.set_device(gpu)
        torch.backends.cudnn.benchmark = True

        model = BarlowTwins(args).cuda(gpu)
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

        # Split parameters into two optimizer groups: 1-D tensors go to the
        # "biases" group, everything else to the "weights" group. The group
        # order matters: the logging below reads param_groups[0] as the
        # weights learning rate and param_groups[1] as the biases one.
        param_weights = []
        param_biases = []
        for param in model.parameters():
            if param.ndim == 1:
                param_biases.append(param)
            else:
                param_weights.append(param)
        parameters = [{'params': param_weights}, {'params': param_biases}]

        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
        # lr=0 is only an initial value; adjust_learning_rate() is called
        # before every optimizer step (presumably it sets the per-group lr --
        # confirm against its definition).
        optimizer = LARS(parameters, lr=0, weight_decay=args.weight_decay,
                         weight_decay_filter=True,
                         lars_adaptation_filter=True)

        # automatically resume from checkpoint if it exists
        if checkpoint_path.is_file():
            ckpt = torch.load(checkpoint_path, map_location='cpu')
            start_epoch = ckpt['epoch']
            model.load_state_dict(ckpt['model'])
            optimizer.load_state_dict(ckpt['optimizer'])
        else:
            start_epoch = 0

        dataset = torchvision.datasets.ImageFolder(args.data / 'train', Transform())
        sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        # Explicit check instead of `assert`, which is stripped under
        # `python -O` and would let the floor division below silently shrink
        # the effective global batch size.
        if args.batch_size % args.world_size != 0:
            raise ValueError(
                f'batch_size ({args.batch_size}) must be divisible by '
                f'world_size ({args.world_size})')
        per_device_batch_size = args.batch_size // args.world_size
        loader = torch.utils.data.DataLoader(
            dataset, batch_size=per_device_batch_size, num_workers=args.workers,
            pin_memory=True, sampler=sampler)

        start_time = time.time()
        scaler = torch.cuda.amp.GradScaler()
        for epoch in range(start_epoch, args.epochs):
            # Re-seed the sampler so each epoch uses a different shuffle.
            sampler.set_epoch(epoch)
            # `step` is a global step counter that continues across epochs
            # (and therefore across resumes).
            steps = enumerate(loader, start=epoch * len(loader))
            for step, ((y1, y2), _) in steps:
                y1 = y1.cuda(gpu, non_blocking=True)
                y2 = y2.cuda(gpu, non_blocking=True)
                adjust_learning_rate(args, optimizer, loader, step)
                optimizer.zero_grad()
                with torch.cuda.amp.autocast():
                    # Call the module rather than .forward() so that
                    # Module.__call__ (and any registered hooks) runs.
                    loss = model(y1, y2)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                if is_main and step % args.print_freq == 0:
                    stats = dict(epoch=epoch, step=step,
                                 lr_weights=optimizer.param_groups[0]['lr'],
                                 lr_biases=optimizer.param_groups[1]['lr'],
                                 loss=loss.item(),
                                 time=int(time.time() - start_time))
                    stats_line = json.dumps(stats)
                    print(stats_line)
                    print(stats_line, file=stats_file)
            if is_main:
                # save checkpoint; `epoch + 1` is the epoch to resume from
                state = dict(epoch=epoch + 1, model=model.state_dict(),
                             optimizer=optimizer.state_dict())
                torch.save(state, checkpoint_path)
        if is_main:
            # save final model: only the backbone of the unwrapped module
            torch.save(model.module.backbone.state_dict(),
                       args.checkpoint_dir / 'resnet50.pth')
    finally:
        if stats_file is not None:
            stats_file.close()