# eval_linear.py
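# Assumed imports for this excerpt (the real ones sit at the top of the file;
# the src.* paths below are a best guess from the helper names used in main, and
# parser, RegLog, train, validate_network are defined elsewhere in this file):
import os

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms

import src.resnet50 as resnet_models
from src.utils import (
    fix_random_seeds,
    init_distributed_mode,
    initialize_exp,
    restart_from_checkpoint,
)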
def main():
global args, best_acc
args = parser.parse_args()
init_distributed_mode(args)
fix_random_seeds(args.seed)
logger, training_stats = initialize_exp(
args, "epoch", "loss", "prec1", "prec5", "loss_val", "prec1_val", "prec5_val"
)
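# the metric names above become columns in the stats file; they line up with
# the scores returned by train() and validate_network() each epoch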
# build data
train_dataset = datasets.ImageFolder(os.path.join(args.data_path, "train"))
val_dataset = datasets.ImageFolder(os.path.join(args.data_path, "val"))
tr_normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]  # standard ImageNet statistics
)
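# training uses random-resized-crop + horizontal-flip augmentation;
# evaluation uses a deterministic resize + center crop, normalized identically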
train_dataset.transform = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
tr_normalize,
])
val_dataset.transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
tr_normalize,
])
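# shard the training set across processes; set_epoch() below reshuffles it every epoch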
sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(
train_dataset,
sampler=sampler,
batch_size=args.batch_size,
num_workers=args.workers,
pin_memory=True,
)
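# no DistributedSampler here: every process evaluates the full validation set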
val_loader = torch.utils.data.DataLoader(
val_dataset,
batch_size=args.batch_size,
num_workers=args.workers,
pin_memory=True,
)
logger.info("Building data done")
# build model
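# output_dim=0 with eval_mode=True presumably makes the backbone return trunk
# features rather than projection-head outputs; RegLog is the linear head
# mapping those features to the 1000 ImageNet classes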
model = resnet_models.__dict__[args.arch](output_dim=0, eval_mode=True)
linear_classifier = RegLog(1000, args.arch, args.global_pooling, args.use_bn)
# convert batch norm layers (if any)
linear_classifier = nn.SyncBatchNorm.convert_sync_batchnorm(linear_classifier)
# model to gpu
model = model.cuda()
linear_classifier = linear_classifier.cuda()
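# only the trainable linear head is wrapped in DDP; the backbone is frozen,
# so it has no gradients to synchronize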
linear_classifier = nn.parallel.DistributedDataParallel(
linear_classifier,
device_ids=[args.gpu_to_work_on],
find_unused_parameters=True,
)
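# keep the backbone in eval mode so its BatchNorm statistics stay frozen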
model.eval()
# load weights
if os.path.isfile(args.pretrained):
state_dict = torch.load(args.pretrained, map_location="cuda:" + str(args.gpu_to_work_on))
if "state_dict" in state_dict:
state_dict = state_dict["state_dict"]
# strip the "module." prefix added by DistributedDataParallel when saving
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
for k, v in model.state_dict().items():
if k not in list(state_dict):
logger.info('key "{}" could not be found in provided state dict'.format(k))
elif state_dict[k].shape != v.shape:
logger.info('key "{}" is of different shape in model and provided state dict'.format(k))
state_dict[k] = v
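# missing or mismatched keys were replaced by the model's own (random) weights
# above, so load_state_dict(strict=False) only has to skip the unexpected
# entries (e.g. projection-head weights) left in the checkpoint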
msg = model.load_state_dict(state_dict, strict=False)
logger.info("Load pretrained model with msg: {}".format(msg))
else:
logger.info("No pretrained weights found => training with random weights")
# set optimizer
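# only the linear classifier's parameters are optimized; the frozen backbone
# never receives updates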
optimizer = torch.optim.SGD(
linear_classifier.parameters(),
lr=args.lr,
nesterov=args.nesterov,
momentum=0.9,
weight_decay=args.wd,
)
# set scheduler
if args.scheduler_type == "step":
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer, args.decay_epochs, gamma=args.gamma
)
elif args.scheduler_type == "cosine":
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, args.epochs, eta_min=args.final_lr
)
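# note: any other scheduler_type would leave `scheduler` undefined; argparse
# presumably restricts the choices to "step" and "cosine"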
# Optionally resume from a checkpoint
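# restart_from_checkpoint is presumably a no-op when no checkpoint exists;
# otherwise it restores the given states in place and fills run_variables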
to_restore = {"epoch": 0, "best_acc": 0.}
restart_from_checkpoint(
os.path.join(args.dump_path, "checkpoint.pth.tar"),
run_variables=to_restore,
state_dict=linear_classifier,
optimizer=optimizer,
scheduler=scheduler,
)
start_epoch = to_restore["epoch"]
best_acc = to_restore["best_acc"]
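# let cudnn autotune convolution algorithms for the fixed 224x224 input size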
cudnn.benchmark = True
for epoch in range(start_epoch, args.epochs):
# train the network for one epoch
logger.info("============ Starting epoch %i ... ============" % epoch)
# set samplers
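# set_epoch changes the sampler's shuffling seed so each epoch sees a new order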
train_loader.sampler.set_epoch(epoch)
scores = train(model, linear_classifier, optimizer, train_loader, epoch)
scores_val = validate_network(val_loader, model, linear_classifier)
training_stats.update(scores + scores_val)
scheduler.step()
# save checkpoint
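# only rank 0 writes the checkpoint, avoiding concurrent writes to the same file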
if args.rank == 0:
save_dict = {
"epoch": epoch + 1,
"state_dict": linear_classifier.state_dict(),
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(),
"best_acc": best_acc,
}
torch.save(save_dict, os.path.join(args.dump_path, "checkpoint.pth.tar"))
logger.info("Training of the supervised linear classifier on frozen features completed.\n"
"Top-1 test accuracy: {acc:.1f}".format(acc=best_acc))