in main.py [0:0]
def main(args):
    utils.init_distributed_mode(args)
    global best_acc1
    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    # create model
    print("=> creating model: {}".format(args.model))
    model = getattr(models, args.model)(ssl_mlp_dim=args.ssl_mlp_dim, ssl_emb_dim=args.ssl_emb_dim)
    model.cuda(args.gpu)
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], bucket_cap_mb=200)
    # define loss function (criterion) and optimizer
    criterion = models.get_loss(args.model, args.ssl_temp, args.ssl_scale).cuda(args.gpu)
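    # split parameters into weight-decay / no-weight-decay groups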
    p_wd, p_non_wd = [], []
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue  # frozen weights
        if p.ndim < 2 or 'bias' in n or 'ln' in n or 'bn' in n:
            p_non_wd.append(p)
        else:
            p_wd.append(p)
    optim_params = [{"params": p_wd, "weight_decay": args.wd},
                    {"params": p_non_wd, "weight_decay": 0}]
    optimizer = torch.optim.AdamW(optim_params, lr=args.lr, betas=args.betas,
                                  eps=args.eps, weight_decay=args.wd)
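    # gradient scaler for mixed-precision training (a no-op when AMP is disabled)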
    scaler = amp.GradScaler(enabled=not args.disable_amp)
    # optionally resume from a checkpoint (takes precedence over autoresume)
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading resume checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            epoch = checkpoint['epoch'] if 'epoch' in checkpoint else 0
            args.start_epoch = epoch
            result = model.load_state_dict(checkpoint['state_dict'], strict=False)
            print(result)
            if 'optimizer' in checkpoint:
                optimizer.load_state_dict(checkpoint['optimizer'])
            if 'scaler' in checkpoint:
                scaler.load_state_dict(checkpoint['scaler'])
            best_acc1 = checkpoint['best_acc1']
            print("=> loaded resume checkpoint '{}' (epoch {})"
                  .format(args.resume, epoch))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        # auto-resume from latest checkpoint in output directory
        latest = os.path.join(args.output_dir, 'checkpoint.pt')
        if os.path.isfile(latest):
            print("=> loading latest checkpoint '{}'".format(latest))
            latest_checkpoint = torch.load(latest, map_location='cpu')
            args.start_epoch = latest_checkpoint['epoch']
            model.load_state_dict(latest_checkpoint['state_dict'])
            optimizer.load_state_dict(latest_checkpoint['optimizer'])
            scaler.load_state_dict(latest_checkpoint['scaler'])
            best_acc1 = latest_checkpoint['best_acc1']
            print("=> loaded latest checkpoint '{}' (epoch {})"
                  .format(latest, latest_checkpoint['epoch']))
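    # let cudnn benchmark convolution algorithms for the fixed 224x224 input size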
    cudnn.benchmark = True
    # Data loading code
    print("=> creating dataset")
    tokenizer = SimpleTokenizer()
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.5, 1.0)),
        transforms.ToTensor(),
        normalize
    ])
    val_transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize
    ])
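    # training set comes from the dataset catalog; ImageNet val is used for zero-shot evaluation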
    train_dataset = datasets.get_dataset(train_transform, tokenizer, args)
    cwd = os.path.dirname(os.path.realpath(__file__))
    with open(os.path.join(cwd, 'dataset_catalog.json')) as f:
        root = json.load(f)['imagenet']['path']
    val_dataset = ImageFolder(os.path.join(root, 'val'), val_transform)
    # dist eval resamples data to pad uneven batch sizes
    # make sure num_samples = 0 mod num_gpus for exact acc
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
    else:
        train_sampler = None
        val_sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=(val_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=val_sampler, drop_last=False)
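    # evaluation-only mode: run zero-shot classification once and exit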
    if args.evaluate:
        if args.model.startswith('SIMCLR'):
            print('zero-shot evaluation not supported with ssl-only model.')
            return
        zero_stats = validate_zeroshot(val_loader, model, tokenizer, args)
        if utils.is_main_process():
            with open(os.path.join(args.output_dir, 'eval_log.txt'), 'a') as f:
                f.write(json.dumps(zero_stats) + '\n')
        return
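    # per-iteration cosine learning-rate schedule with linear warmup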
    lr_schedule = utils.cosine_scheduler(args.lr, args.lr_end, args.epochs,
        len(train_loader) // args.update_freq, warmup_epochs=args.warmup_epochs,
        start_warmup_value=args.lr_start)
    if utils.is_main_process() and args.wandb:
        wandb_id = os.path.split(args.output_dir)[-1]
        wandb.init(project='slip', id=wandb_id, config=args, resume='allow')
    print(args)
    print("=> beginning training")
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        # train for one epoch
        train_stats = train(train_loader, model, criterion, optimizer, scaler, epoch, lr_schedule, args)
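        # run zero-shot evaluation and checkpointing only every eval_freq epochs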
        if (epoch + 1) % args.eval_freq != 0:
            continue
        if args.model.startswith('SIMCLR'):
            val_stats = {'acc1': -1}
            acc1 = -1
        else:
            val_stats = validate_zeroshot(val_loader, model, tokenizer, args)
            acc1 = val_stats['acc1']
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
        print("=> saving checkpoint")
        utils.save_on_master({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scaler': scaler.state_dict(),
            'best_acc1': best_acc1,
            'args': args,
        }, is_best, args.output_dir)
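        # merge train and validation metrics into a single record for logging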
        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     **{f'test_{k}': v for k, v in val_stats.items()},
                     'epoch': epoch}
        if utils.is_main_process():
            if args.wandb:
                wandb.log(log_stats)
            with open(os.path.join(args.output_dir, 'log.txt'), 'a') as f:
                f.write(json.dumps(log_stats) + '\n')