in main_moco.py [0:0]
def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    # suppress printing if not master
    if args.multiprocessing_distributed and args.gpu != 0:
        def print_pass(*args):
            pass
        builtins.print = print_pass
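    # Illustrative note (not in the original file): main_worker is the per-process
    # entry point. In the usual MoCo-style launcher it is spawned once per GPU, e.g.
    # (sketch, assuming a standard main()/mp.spawn setup):
    #
    #     import torch.multiprocessing as mp
    #     mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    #
    # so `gpu` is this process's local rank, and only local rank 0 keeps print().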
    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)
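    # Illustrative note: with init_method="env://", torch.distributed reads the
    # rendezvous info from environment variables, e.g. (hypothetical values):
    #
    #     MASTER_ADDR=127.0.0.1 MASTER_PORT=29500 WORLD_SIZE=8 RANK=0 python main_moco.py ...
    #
    # otherwise a tcp://host:port URL can be supplied through the dist-url argument.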
    # create model
    print("=> creating model '{}' with backbone '{}'".format(args.arch, args.backbone))
    model_func = get_moco_model(args.arch)
    norm = get_norm(args.norm)
    model = model_func(
        backbone_models.__dict__[args.backbone],
        args.moco_dim, args.moco_k, args.moco_m, args.moco_t, args.num_mlp, norm)
    print(model)
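    # For reference (MoCo v2 conventions, not asserted from this repo, whose defaults
    # may differ): moco_dim is the projection output dimension (typically 128),
    # moco_k the negative queue size (typically 65536), moco_m the key-encoder
    # momentum (typically 0.999), moco_t the softmax temperature (typically 0.2),
    # and num_mlp the number of layers in the projection head.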
    if args.pretrained:
        if os.path.isfile(args.pretrained):
            print("=> loading pretrained model from '{}'".format(args.pretrained))
            state_dict = torch.load(args.pretrained, map_location="cpu")['state_dict']
            for k in list(state_dict.keys()):
                new_key = k.replace("module.", "")
                state_dict[new_key] = state_dict[k]
                del state_dict[k]
            msg = model.load_state_dict(state_dict, strict=False)
            print("=> loaded pretrained model from '{}'".format(args.pretrained))
            if len(msg.missing_keys) > 0:
                print("missing keys: {}".format(msg.missing_keys))
            if len(msg.unexpected_keys) > 0:
                print("unexpected keys: {}".format(msg.unexpected_keys))
        else:
            print("=> no pretrained model found at '{}'".format(args.pretrained))
    if args.distributed:
        # For multiprocessing distributed, the DistributedDataParallel constructor
        # should always set the single device scope; otherwise
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
        # single-GPU, non-distributed mode; the raise below is left commented out
        # so this path can be used for debugging
        # raise NotImplementedError("Only DistributedDataParallel is supported.")
    else:
        model = torch.nn.DataParallel(model).cuda()
        # AllGather implementation (batch shuffle, queue update, etc.) in
        # this code only supports DistributedDataParallel.
        # raise NotImplementedError("Only DistributedDataParallel is supported.")
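    # Illustrative arithmetic (assumed values): with a global batch size of 256 and
    # 8 GPUs per node, each DistributedDataParallel process above gets
    # int(256 / 8) = 32 samples per step, and the data-loader worker count is split
    # the same way, rounded up.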
    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
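    # Note (MoCo convention, assuming the model's forward returns (logits, labels) as
    # in the reference implementation): the InfoNCE objective is computed as plain
    # cross-entropy over 1 positive plus K negative logits, with the positive at
    # index 0, which is why an ordinary nn.CrossEntropyLoss suffices here.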
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            if 'best_acc1' in checkpoint:
                best_acc1 = checkpoint['best_acc1']
                # if args.gpu is not None:
                #     # best_acc1 may be from a checkpoint from a different GPU
                #     best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')

    train_transform = data_transforms.get_transforms("MoCoV2")
    train_dataset = datasets.ImageFolder(
        traindir,
        data_transforms.TwoCropsTransform(train_transform, train_transform))
    print("train_dataset:\n{}".format(train_dataset))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)
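    # Illustrative note (assuming TwoCropsTransform follows the MoCo reference): each
    # training sample yields two independently augmented crops of the same image, so a
    # batch from train_loader looks like ([im_q, im_k], target), the query and key
    # views consumed by the momentum-contrast model; drop_last=True keeps every batch
    # the same size, which the fixed-size queue update relies on.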
    val_loader_base = torch.utils.data.DataLoader(
        datasets.ImageFolderWithPercent(
            traindir,
            data_transforms.get_transforms("DefaultVal"),
            percent=args.nn_mem_percent
        ),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    val_loader_query = torch.utils.data.DataLoader(
        datasets.ImageFolderWithPercent(
            valdir,
            data_transforms.get_transforms("DefaultVal"),
            percent=args.nn_query_percent
        ),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)
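    # Illustrative note (inferred from the argument names, not asserted from this
    # file): ss_validate appears to run a nearest-neighbour style self-supervised
    # evaluation, using features from an nn_mem_percent fraction of the training set
    # (val_loader_base) as the memory bank and an nn_query_percent fraction of the
    # validation set (val_loader_query) as queries.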
    if args.evaluate:
        ss_validate(val_loader_base, val_loader_query, model, args)
        return

    best_epoch = args.start_epoch
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        if epoch >= args.warmup_epoch:
            lr_schedule.adjust_learning_rate_with_min(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args)

        is_best = False
        if (epoch + 1) % args.eval_freq == 0:
            acc1 = ss_validate(val_loader_base, val_loader_query, model, args)
            # remember best acc@1 and save checkpoint
            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)
            if is_best:
                best_epoch = epoch

        if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            utils.save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }, is_best=is_best, epoch=epoch, args=args)

    print('Best Acc@1 {0} @ epoch {1}'.format(best_acc1, best_epoch + 1))
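    # Illustrative note: checkpoints are written only by processes whose global rank
    # is a multiple of ngpus_per_node (one writer per node, i.e. rank 0 in the common
    # single-node case); save_checkpoint is assumed to also keep a separate copy of
    # the best-performing weights when is_best is True.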