# main_linear.py
#
# Imports inferred from this excerpt's usage: `datasets` and `utils` are the
# repo's local modules; train, validate, save_checkpoint, sanity_check, and
# adjust_learning_rate are defined elsewhere in this file.
import json
import os
import random
import warnings

import timm
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torchvision.transforms as transforms

import datasets
import utils

best_acc1 = 0  # module-level best top-1 accuracy, referenced via `global best_acc1` below

def main(args):
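    """Linear probe: train a linear classifier on top of a frozen pre-trained visual encoder."""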
    utils.init_distributed_mode(args)

    global best_acc1

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
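
    # the linear classifier of the timm model lives in its 'head' attribute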
    linear_keyword = 'head'
    if os.path.isfile(args.pretrained):
        print("=> loading checkpoint '{}'".format(args.pretrained))
        if args.gpu is None:
            checkpoint = torch.load(args.pretrained)
        else:
            # Map model to be loaded to specified single gpu.
            loc = 'cuda:{}'.format(args.gpu)
            checkpoint = torch.load(args.pretrained, map_location=loc)
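
        # checkpoint keys come from a DDP-wrapped model whose image encoder is the 'visual' submodule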
        visual_keyword = 'module.visual.'

        # rename CLIP pre-trained keys
        state_dict = checkpoint['state_dict']
        for k in list(state_dict.keys()):
            # retain only the visual-encoder weights, excluding its linear head
            if k.startswith(visual_keyword) and not k.startswith(visual_keyword + linear_keyword):
                # remove prefix
                state_dict[k[len(visual_keyword):]] = state_dict[k]
            # delete renamed or unused k
            del state_dict[k]
    else:
        raise Exception('Missing pretrained model checkpoint: {}'.format(args.pretrained))

    # create model
    print("=> creating model '{}'".format(args.arch))
    model = timm.models.create_model(args.arch, num_classes=1000)

    args.start_epoch = 0
    msg = model.load_state_dict(state_dict, strict=False)
    assert set(msg.missing_keys) == {"%s.weight" % linear_keyword, "%s.bias" % linear_keyword}

    # freeze all layers but the last fc
    for name, param in model.named_parameters():
        if name not in ['%s.weight' % linear_keyword, '%s.bias' % linear_keyword]:
            param.requires_grad = False
    # init the fc layer
    getattr(model, linear_keyword).weight.data.normal_(mean=0.0, std=0.01)
    getattr(model, linear_keyword).bias.data.zero_()
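
    # linear lr scaling rule: base lr * (per-process batch size) / 256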
    init_lr = args.lr * int(args.batch_size / utils.get_world_size()) / 256
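    # split data-loading workers across processes (ceiling division)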
    args.workers = int((args.workers + utils.get_world_size() - 1) / utils.get_world_size())

    model.cuda(args.gpu)
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    # optimize only the linear classifier
    parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
    assert len(parameters) == 2  # weight, bias

    optimizer = torch.optim.SGD(parameters, init_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    cudnn.benchmark = True

    # Data loading code
    cwd = os.path.dirname(os.path.realpath(__file__))
    with open(os.path.join(cwd, 'dataset_catalog.json')) as f:
        catalog = json.load(f)
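
    # standard ImageNet channel statistics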
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        lambda x: x.convert('RGB'),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        lambda x: x.convert('RGB'),
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = datasets.get_downstream_dataset(catalog, args.dataset, is_train=True, transform=train_transform)
    val_dataset = datasets.get_downstream_dataset(catalog, args.dataset, is_train=False, transform=val_transform)
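
    # in distributed mode, shard the training set across processes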
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=256, shuffle=False,
        num_workers=args.workers, pin_memory=True)
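
    # evaluate-only mode: run a single validation pass and exit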
    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return

    print(args)

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, init_lr, epoch, args)

        # train for one epoch
        train_stats = train(train_loader, model, criterion, optimizer, epoch, args)
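
        # validate and checkpoint only every args.eval_freq epochs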
        if (epoch + 1) % args.eval_freq != 0:
            continue

        # evaluate on validation set
        val_stats = validate(val_loader, model, criterion, args)
        acc1 = val_stats['acc1']

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        if utils.is_main_process():  # only the first GPU saves checkpoint
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.output_dir)
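            # after the first epoch, verify the frozen backbone still matches the pre-trained weights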
            if epoch == args.start_epoch:
                sanity_check(model.state_dict(), args.pretrained, linear_keyword, visual_keyword)

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     **{f'test_{k}': v for k, v in val_stats.items()},
                     'epoch': epoch}

        if utils.is_main_process():
            with open(os.path.join(args.output_dir, 'linear_{}_lr={}_log.txt'.format(args.dataset, args.lr)), 'a') as f:
                f.write(json.dumps(log_stats) + '\n')