in main_moco.py [0:0]
def main_worker(gpu, ngpus_per_node, args):
    """Run MoCo pre-training in a single worker process.

    Sets up (optionally distributed) execution, builds the MoCo model and
    its optimizer, optionally resumes from a checkpoint, builds the two-view
    training loader, then trains for epochs ``args.start_epoch`` through
    ``args.epochs - 1``, writing a checkpoint after every epoch.

    Args:
        gpu: Index of the GPU on this node assigned to this process, or
            ``None``. Stored in ``args.gpu``.
        ngpus_per_node: Number of GPUs per node; used to derive the global
            rank and to split data-loader workers across processes.
        args: Parsed command-line namespace. Mutated in place: ``gpu``,
            ``rank``, ``lr``, ``batch_size``, ``workers`` and (on resume)
            ``start_epoch`` may be overwritten.

    Raises:
        NotImplementedError: If CUDA is available but ``args.distributed``
            is false (only DistributedDataParallel is supported).
        ValueError: If ``args.optimizer`` is neither ``'lars'`` nor
            ``'adamw'``.
    """
    args.gpu = gpu

    # With env:// initialization the rank may come from the environment.
    # Resolve it before the print-suppression check below, which would
    # otherwise see rank == -1 and silence every process, including the master.
    if args.distributed and args.dist_url == "env://" and args.rank == -1:
        args.rank = int(os.environ["RANK"])

    # Suppress printing on every process except local GPU 0 of rank 0.
    # At this point args.rank has not yet been turned into a global rank
    # (see below), so exactly one process keeps printing.
    # The stub accepts **kwargs so calls such as print(..., flush=True) do not
    # raise TypeError on silenced processes.
    if args.multiprocessing_distributed and (args.gpu != 0 or args.rank != 0):
        def print_pass(*_args, **_kwargs):
            pass
        builtins.print = print_pass

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)
        torch.distributed.barrier()

    # create model
    print("=> creating model '{}'".format(args.arch))
    if args.arch.startswith('vit'):
        model = moco.builder.MoCo_ViT(
            partial(vits.__dict__[args.arch], stop_grad_conv1=args.stop_grad_conv1),
            args.moco_dim, args.moco_mlp_dim, args.moco_t)
    else:
        model = moco.builder.MoCo_ResNet(
            partial(torchvision_models.__dict__[args.arch], zero_init_residual=True),
            args.moco_dim, args.moco_mlp_dim, args.moco_t)

    # args.lr is specified for a batch size of 256 and scaled linearly with
    # the total batch size. This must happen before batch_size is split per GPU.
    args.lr = args.lr * args.batch_size / 256

    if not torch.cuda.is_available():
        print('using CPU, this will be slow')
    elif args.distributed:
        # apply SyncBN
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have.
            # Workers are split across the node's GPUs, rounding up.
            args.batch_size = int(args.batch_size / args.world_size)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
        # comment out the following line for debugging
        raise NotImplementedError("Only DistributedDataParallel is supported.")
    else:
        # AllGather/rank implementation in this code only supports DistributedDataParallel.
        raise NotImplementedError("Only DistributedDataParallel is supported.")
    print(model)  # print model after SyncBatchNorm

    if args.optimizer == 'lars':
        optimizer = moco.optimizer.LARS(model.parameters(), args.lr,
                                        weight_decay=args.weight_decay,
                                        momentum=args.momentum)
    elif args.optimizer == 'adamw':
        optimizer = torch.optim.AdamW(model.parameters(), args.lr,
                                      weight_decay=args.weight_decay)
    else:
        # Fail fast. Otherwise `optimizer` stays unbound, and the error surfaces
        # later as an UnboundLocalError on resume or in the first training step.
        raise ValueError("Unsupported optimizer '{}': expected 'lars' or 'adamw'."
                         .format(args.optimizer))

    scaler = torch.cuda.amp.GradScaler()
    # Only rank 0 gets a SummaryWriter; every other process passes None to train().
    summary_writer = SummaryWriter() if args.rank == 0 else None

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # NOTE: torch.load unpickles the file; only resume from trusted checkpoints.
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            # Checkpoints store epoch + 1, so training resumes at the next epoch.
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scaler.load_state_dict(checkpoint['scaler'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code: images are read with ImageFolder from <args.data>/train.
    traindir = os.path.join(args.data, 'train')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # follow BYOL's augmentation recipe: https://arxiv.org/abs/2006.07733
    # The two views differ in blur probability (1.0 vs 0.1). Only the second
    # view is solarized (p=0.2).
    augmentation1 = [
        transforms.RandomResizedCrop(224, scale=(args.crop_min, 1.)),
        transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)  # not strengthened
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply([moco.loader.GaussianBlur([.1, 2.])], p=1.0),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ]

    augmentation2 = [
        transforms.RandomResizedCrop(224, scale=(args.crop_min, 1.)),
        transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)  # not strengthened
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply([moco.loader.GaussianBlur([.1, 2.])], p=0.1),
        transforms.RandomApply([moco.loader.Solarize()], p=0.2),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ]

    train_dataset = datasets.ImageFolder(
        traindir,
        moco.loader.TwoCropsTransform(transforms.Compose(augmentation1),
                                      transforms.Compose(augmentation2)))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    # When a sampler is used, the DataLoader itself must not shuffle.
    # drop_last discards a trailing incomplete batch.
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)

    try:
        for epoch in range(args.start_epoch, args.epochs):
            if args.distributed:
                # Give the distributed sampler a different ordering each epoch.
                train_sampler.set_epoch(epoch)

            # train for one epoch
            train(train_loader, model, optimizer, scaler, summary_writer, epoch, args)

            # With multiprocessing distributed training only rank 0 saves.
            # Otherwise every process saves.
            if not args.multiprocessing_distributed or args.rank == 0:
                save_checkpoint({
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer' : optimizer.state_dict(),
                    'scaler': scaler.state_dict(),
                }, is_best=False, filename='checkpoint_%04d.pth.tar' % epoch)
    finally:
        # Close the writer even if training raised.
        if summary_writer is not None:
            summary_writer.close()