in main_byol.py [0:0]
def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    # suppress printing if not master
    if args.multiprocessing_distributed and args.gpu != 0:
        def print_pass(*args):
            pass
        builtins.print = print_pass

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    # create model
    print("=> creating model '{}' with backbone '{}'".format(args.arch, args.backbone))
    model_func = get_byol_model(args.arch)
    norm_layer = get_norm(args.norm)
    model = model_func(
        backbone_models.__dict__[args.backbone],
        dim=args.byol_dim,
        m=args.byol_m,
        hid_dim=args.hid_dim,
        norm_layer=norm_layer,
        num_neck_mlp=args.num_neck_mlp,
    )
    print(model)
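    # optionally warm-start from an existing checkpoint; strict=False below lets
    # partial (e.g. backbone-only) state dicts load, with any missing or unexpected
    # keys reported instead of raising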
    if args.pretrained:
        if os.path.isfile(args.pretrained):
            print("=> loading pretrained model from '{}'".format(args.pretrained))
            state_dict = torch.load(args.pretrained, map_location="cpu")['state_dict']
            # strip the "module." prefix left by DistributedDataParallel, if present
            for k in list(state_dict.keys()):
                new_key = k.replace("module.", "")
                if new_key != k:
                    state_dict[new_key] = state_dict.pop(k)
            msg = model.load_state_dict(state_dict, strict=False)
            print("=> loaded pretrained model from '{}'".format(args.pretrained))
            if len(msg.missing_keys) > 0:
                print("missing keys: {}".format(msg.missing_keys))
            if len(msg.unexpected_keys) > 0:
                print("unexpected keys: {}".format(msg.unexpected_keys))
        else:
            print("=> no pretrained model found at '{}'".format(args.pretrained))

    if args.distributed:
        # For multiprocessing distributed, the DistributedDataParallel constructor
        # should always set the single device scope; otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        model = torch.nn.DataParallel(model).cuda()

    # define optimizer
    params = collect_params(model, exclude_bias_and_bn=True)
    optimizer = LARS(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
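    # Note: `exclude_bias_and_bn=True` above follows the usual LARS/BYOL recipe of
    # leaving biases and BatchNorm parameters without weight decay (and typically
    # without LARS adaptation). A sketch of the parameter groups such a helper is
    # commonly assumed to return (the actual collect_params may differ):
    #
    #   [{'params': weight_params},                      # decayed, LARS-scaled
    #    {'params': bias_and_bn_params,
    #     'weight_decay': 0, 'lars_exclude': True}]      # no decay, no LARS scaling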

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            if 'best_acc1' in checkpoint:
                best_acc1 = checkpoint['best_acc1']
                # if args.gpu is not None:
                #     # best_acc1 may be from a checkpoint from a different GPU
                #     best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True
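    # cudnn.benchmark above lets cuDNN autotune convolution algorithms, which pays
    # off here because the input resolution stays fixed throughout training.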

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')

    transform1, transform2 = data_transforms.get_byol_tranforms()
    train_dataset = datasets.ImageFolder(
        traindir,
        data_transforms.TwoCropsTransform(transform1, transform2))
    print("train_dataset:\n{}".format(train_dataset))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)

    val_loader_base = torch.utils.data.DataLoader(
        custom_datasets.ImageFolderWithPercent(
            traindir,
            data_transforms.get_transforms("DefaultVal"),
            percent=args.nn_mem_percent
        ),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)
    val_loader_query = torch.utils.data.DataLoader(
        custom_datasets.ImageFolderWithPercent(
            valdir,
            data_transforms.get_transforms("DefaultVal"),
            percent=args.nn_query_percent
        ),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)
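    # The two loaders feed ss_validate's retrieval-style evaluation: judging by the
    # names, val_loader_base embeds a fraction (nn_mem_percent) of the training images
    # as a feature memory bank, and val_loader_query classifies a fraction
    # (nn_query_percent) of the validation images by nearest neighbour against it.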

    if args.evaluate:
        ss_validate(val_loader_base, val_loader_query, model, args)
        return

    best_epoch = args.start_epoch
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
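        # learning-rate decay is skipped during the first args.warmup_epoch epochs;
        # the warm-up ramp itself is assumed to be handled inside lr_schedule or
        # train(), which this file does not show.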
        if epoch >= args.warmup_epoch:
            lr_schedule.adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, optimizer, epoch, args)

        is_best = False
        if (epoch + 1) % args.eval_freq == 0:
            acc1 = ss_validate(val_loader_base, val_loader_query, model, args)
            # remember best acc@1 and save checkpoint
            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)
            if is_best:
                best_epoch = epoch
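        # in multiprocessing mode, only the first process on each node
        # (rank % ngpus_per_node == 0) writes checkpoints, so each file is
        # saved once rather than once per GPU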
        if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                                                    and args.rank % ngpus_per_node == 0):
            utils.save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }, is_best=is_best, epoch=epoch, args=args)

    print('Best Acc@1 {0} @ epoch {1}'.format(best_acc1, best_epoch + 1))