in main_fixmatch.py [0:0]
def _remap_self_pretrained_keys(state_dict, model_prefix, with_ema):
    """Build a state dict holding only the backbone weights of a checkpoint.

    Keys that start with ``model_prefix`` (except those that start with
    ``model_prefix + '.fc'``) are re-rooted at ``main.backbone``; every other
    key is dropped. When ``with_ema`` is true, each kept value is additionally
    cloned under the ``ema.backbone`` root.

    A new dict is returned instead of renaming in place, so a freshly renamed
    entry can never be deleted again when a later original key collides with it.
    """
    fc_prefix = model_prefix + '.fc'
    remapped = {}
    for key, value in state_dict.items():
        if not key.startswith(model_prefix) or key.startswith(fc_prefix):
            continue
        # count=1: swap only the leading prefix, never a later occurrence.
        remapped[key.replace(model_prefix, "main.backbone", 1)] = value
        if with_ema:
            remapped[key.replace(model_prefix, "ema.backbone", 1)] = value.clone()
    return remapped


def _strip_module_tag(state_dict):
    """Return a copy of ``state_dict`` with ``"module."`` removed from its keys.

    Keys that do not contain ``"module."`` are kept unchanged. (Renaming in
    place with ``d[new] = d[old]; del d[old]`` would delete such keys, because
    ``new == old``.)
    """
    # NOTE(review): every "module." occurrence is removed, not only a leading
    # one -- presumably so nested parallel wrappers are unwrapped too; confirm.
    return {key.replace("module.", ""): value for key, value in state_dict.items()}


def main_worker(gpu, ngpus_per_node, args):
    """Run the full training / evaluation job for one worker process.

    Steps, in order: optional process-group initialisation, model creation,
    optional loading of (self-)pretrained weights, device placement and
    parallel wrapping, optimizer creation, optional resume, dataset and loader
    construction, then either a single validation pass (``args.evaluate``) or
    the epoch loop with periodic validation and checkpointing.

    Args:
        gpu: GPU index for this process, or ``None``.
        ngpus_per_node: Number of GPUs per node; used to derive the global
            rank and to split the batch size and worker count per process.
        args: Parsed options namespace. It is mutated here: ``gpu``, ``rank``,
            ``batch_size``, ``workers`` and ``start_epoch`` may be overwritten.

    Side effects:
        Updates the module-level ``best_acc1``; replaces ``builtins.print`` with
        a no-op on non-master processes in multiprocessing-distributed mode;
        sets ``cudnn.benchmark``; writes checkpoints via ``save_checkpoint``.
    """
    global best_acc1
    args.gpu = gpu

    # suppress printing if not master
    if args.multiprocessing_distributed and args.gpu != 0:
        def print_pass(*_args, **_kwargs):
            # Accept keyword arguments too, so print(..., end=..., flush=...)
            # does not raise TypeError on the silenced processes.
            pass
        builtins.print = print_pass

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    # create model
    print("=> creating model '{}' with backbone '{}'".format(args.arch, args.backbone))
    model_func = get_fixmatch_model(args.arch)
    norm = get_norm(args.norm)
    model = model_func(
        backbone_models.__dict__[args.backbone],
        eman=args.eman,
        momentum=args.ema_m,
        norm=norm
    )
    print(model)

    if args.self_pretrained:
        if os.path.isfile(args.self_pretrained):
            print("=> loading checkpoint '{}'".format(args.self_pretrained))
            checkpoint = torch.load(args.self_pretrained, map_location="cpu")

            # rename self pre-trained keys to model.main keys, keeping only the
            # encoder weights up to (not including) the embedding layer
            state_dict = _remap_self_pretrained_keys(
                checkpoint['state_dict'],
                'module.' + args.model_prefix,
                model.ema is not None)

            msg = model.load_state_dict(state_dict, strict=False)
            if msg.missing_keys:
                print("missing keys:\n{}".format('\n'.join(msg.missing_keys)))
            if msg.unexpected_keys:
                print("unexpected keys:\n{}".format('\n'.join(msg.unexpected_keys)))
            print("=> loaded pre-trained model '{}' (epoch {})".format(args.self_pretrained, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.self_pretrained))
    elif args.pretrained:
        if os.path.isfile(args.pretrained):
            print("=> loading pretrained model from '{}'".format(args.pretrained))
            checkpoint = torch.load(args.pretrained, map_location="cpu")
            state_dict = _strip_module_tag(checkpoint['state_dict'])
            model_num_cls = state_dict['fc.weight'].shape[0]
            if model_num_cls != args.cls:
                # if num_cls don't match, remove the last layer
                del state_dict['fc.weight']
                del state_dict['fc.bias']
                msg = model.load_state_dict(state_dict, strict=False)
                assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}, \
                    "missing keys:\n{}".format('\n'.join(msg.missing_keys))
            else:
                model.load_state_dict(state_dict)
            print("=> loaded pre-trained model '{}' (epoch {})".format(args.pretrained, checkpoint['epoch']))
        else:
            print("=> no pretrained model found at '{}'".format(args.pretrained))

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion)
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    # define optimizer
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=args.nesterov)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            # best_acc1 may be from a checkpoint from a different GPU.
            # Checkpoints are written every epoch but validation only runs
            # every eval_freq epochs, so the stored value may be a plain
            # number rather than a tensor; only tensors have .to().
            if args.gpu is not None and torch.is_tensor(best_acc1):
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Supervised Data loading code
    if args.trainindex_x is not None and args.trainindex_u is not None:
        print("load index from {}/{}".format(args.trainindex_x, args.trainindex_u))
        index_info_x = os.path.join(args.data, 'indexes', args.trainindex_x)
        index_info_u = os.path.join(args.data, 'indexes', args.trainindex_u)
        index_info_x = pd.read_csv(index_info_x)
        trainindex_x = index_info_x['Index'].tolist()
        index_info_u = pd.read_csv(index_info_u)
        trainindex_u = index_info_u['Index'].tolist()
        train_dataset_x, train_dataset_u, val_dataset = get_imagenet_ssl(
            args.data, trainindex_x, trainindex_u,
            weak_type=args.weak_type, strong_type=args.strong_type)
    else:
        print("random sampling {} percent of data".format(args.anno_percent * 100))
        train_dataset_x, train_dataset_u, val_dataset = get_imagenet_ssl_random(
            args.data, args.anno_percent, weak_type=args.weak_type, strong_type=args.strong_type)
    print("train_dataset_x:\n{}".format(train_dataset_x))
    print("train_dataset_u:\n{}".format(train_dataset_u))
    print("val_dataset:\n{}".format(val_dataset))

    # Data loading code
    # train_sampler is the sampler *class*; one instance is built per dataset.
    train_sampler = DistributedSampler if args.distributed else RandomSampler

    train_loader_x = DataLoader(
        train_dataset_x,
        sampler=train_sampler(train_dataset_x),
        batch_size=args.batch_size,
        num_workers=args.workers, pin_memory=True, drop_last=True)

    # The unlabeled loader uses a batch args.mu times larger than the labeled one.
    train_loader_u = DataLoader(
        train_dataset_u,
        sampler=train_sampler(train_dataset_u),
        batch_size=args.batch_size * args.mu,
        num_workers=args.workers, pin_memory=True, drop_last=True)

    val_loader = DataLoader(
        val_dataset,
        batch_size=128, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return

    best_epoch = args.start_epoch
    for epoch in range(args.start_epoch, args.epochs):
        # The learning-rate schedule is only applied once warm-up is over.
        if epoch >= args.warmup_epoch:
            lr_schedule.adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader_x, train_loader_u, model, optimizer, epoch, args)

        is_best = False
        if (epoch + 1) % args.eval_freq == 0:
            # evaluate on validation set
            acc1 = validate(val_loader, model, criterion, args)
            # remember best acc@1 and save checkpoint
            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)
            if is_best:
                best_epoch = epoch

        # Only one process per node writes the checkpoint.
        if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                                                    and args.rank % ngpus_per_node == 0):
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }, is_best)

    print('Best Acc@1 {0} @ epoch {1}'.format(best_acc1, best_epoch + 1))