in src/snn_fine_tune.py [0:0]
def main(args):
    # -- META
    model_name = args['meta']['model_name']
    load_checkpoint = args['meta']['load_checkpoint']
    copy_data = args['meta']['copy_data']
    output_dim = args['meta']['output_dim']
    use_pred_head = args['meta']['use_pred_head']
    use_fp16 = args['meta']['use_fp16']
    device = torch.device(args['meta']['device'])
    torch.cuda.set_device(device)

    # -- DATA
    unlabeled_frac = args['data']['unlabeled_frac']
    label_smoothing = args['data']['label_smoothing']
    normalize = args['data']['normalize']
    root_path = args['data']['root_path']
    image_folder = args['data']['image_folder']
    dataset_name = args['data']['dataset']
    subset_path = args['data']['subset_path']
    unique_classes = args['data']['unique_classes_per_rank']
    data_seed = args['data']['data_seed']
    # -- CRITERION
    classes_per_batch = args['criterion']['classes_per_batch']
    supervised_views = args['criterion']['supervised_views']
    batch_size = args['criterion']['supervised_batch_size']
    temperature = args['criterion']['temperature']
    # -- OPTIMIZATION
    wd = float(args['optimization']['weight_decay'])
    num_epochs = args['optimization']['epochs']
    use_lars = args['optimization']['use_lars']
    warmup = args['optimization']['warmup']
    start_lr = args['optimization']['start_lr']
    ref_lr = args['optimization']['lr']
    final_lr = args['optimization']['final_lr']
    momentum = args['optimization']['momentum']
    nesterov = args['optimization']['nesterov']

    # -- LOGGING
    folder = args['logging']['folder']
    tag = args['logging']['write_tag']
    r_file_enc = args['logging']['pretrain_path']
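    # NOTE: the lookups above assume a nested config (e.g. a parsed YAML file)
    # with the following sections and keys; the shape is inferred from the
    # accesses above and the values are purely illustrative:
    #
    #   meta:         model_name, load_checkpoint, copy_data, output_dim,
    #                 use_pred_head, use_fp16, device
    #   data:         unlabeled_frac, label_smoothing, normalize, root_path,
    #                 image_folder, dataset, subset_path,
    #                 unique_classes_per_rank, data_seed
    #   criterion:    classes_per_batch, supervised_views,
    #                 supervised_batch_size, temperature
    #   optimization: weight_decay, epochs, use_lars, warmup, start_lr, lr,
    #                 final_lr, momentum, nesterov
    #   logging:      folder, write_tag, pretrain_path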
    # -- log/checkpointing paths
    r_enc_path = os.path.join(folder, r_file_enc)
    w_enc_path = os.path.join(folder, f'{tag}-fine-tune-SNN.pth.tar')
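    # r_enc_path points at the pre-trained encoder weights to fine-tune from;
    # w_enc_path is where the fine-tuned checkpoint is (re)written whenever the
    # validation accuracy improves (see the save logic at the bottom).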
    # -- init distributed
    world_size, rank = init_distributed()
    logger.info(f'initialized rank/world-size: {rank}/{world_size}')

    # -- init loss
    suncet = init_suncet_loss(
        num_classes=classes_per_batch,
        batch_size=batch_size*supervised_views,
        world_size=world_size,
        rank=rank,
        temperature=temperature,
        device=device)
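    # The loss sees every supervised view of each image in the batch, which is
    # why its effective batch size is batch_size*supervised_views.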
    labels_matrix = make_labels_matrix(
        num_classes=classes_per_batch,
        s_batch_size=batch_size,
        world_size=world_size,
        device=device,
        unique_classes=unique_classes,
        smoothing=label_smoothing)
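    # labels_matrix holds the (label-smoothed) targets for one supervised view
    # of the class-stratified batch; it is replicated once per view inside the
    # training loop below.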
    # -- make data transforms
    transform, init_transform = make_transforms(
        dataset_name=dataset_name,
        subset_path=subset_path,
        unlabeled_frac=unlabeled_frac,
        training=True,
        split_seed=data_seed,
        basic_augmentations=True,
        normalize=normalize)
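    # basic_augmentations=True selects the lighter augmentation pipeline used
    # for fine-tuning; subset_path/unlabeled_frac/split_seed determine which
    # labeled subset of the data is drawn.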
    (data_loader,
     dist_sampler) = init_data(
         dataset_name=dataset_name,
         transform=transform,
         init_transform=init_transform,
         supervised_views=supervised_views,
         u_batch_size=None,
         stratify=True,
         s_batch_size=batch_size,
         classes_per_batch=classes_per_batch,
         unique_classes=unique_classes,
         world_size=world_size,
         rank=rank,
         root_path=root_path,
         image_folder=image_folder,
         training=True,
         copy_data=copy_data)
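    # u_batch_size=None: no unlabeled stream is used during fine-tuning, and
    # stratify=True makes the sampler build each batch from classes_per_batch
    # distinct classes so the targets in labels_matrix line up with the images.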
    # -- rough estimate of labeled imgs per class used to set the number of
    #    fine-tuning iterations
    imgs_per_class = (int(1300*(1.-unlabeled_frac)) if 'imagenet' in dataset_name
                      else int(5000*(1.-unlabeled_frac)))
    dist_sampler.set_inner_epochs(imgs_per_class//batch_size)
    ipe = len(data_loader)
    logger.info(f'initialized data-loader (ipe {ipe})')
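    # ipe = iterations per epoch; it is used below to size the learning-rate
    # schedule built by init_model.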
    # -- init model and optimizer
    scaler = torch.cuda.amp.GradScaler(enabled=use_fp16)
    encoder, optimizer, scheduler = init_model(
        device=device,
        training=True,
        r_enc_path=r_enc_path,
        iterations_per_epoch=ipe,
        world_size=world_size,
        start_lr=start_lr,
        ref_lr=ref_lr,
        num_epochs=num_epochs,
        output_dim=output_dim,
        model_name=model_name,
        warmup_epochs=warmup,
        use_pred_head=use_pred_head,
        use_fp16=use_fp16,
        wd=wd,
        final_lr=final_lr,
        momentum=momentum,
        nesterov=nesterov,
        use_lars=use_lars)
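    # Judging by its arguments, the returned scheduler warms the learning rate
    # up from start_lr to ref_lr over `warmup` epochs and then decays it toward
    # final_lr; it is stepped once per iteration in the training loop below.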
    best_acc, val_top1 = None, None
    start_epoch = 0

    # -- load checkpoint
    if load_checkpoint:
        encoder, optimizer, scaler, scheduler, start_epoch, best_acc = load_from_path(
            r_path=w_enc_path,
            encoder=encoder,
            opt=optimizer,
            scaler=scaler,
            sched=scheduler,
            device=device,
            use_fp16=use_fp16,
            ckp=True)
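    # When resuming, load_from_path restores the optimizer, AMP scaler and LR
    # scheduler state along with the weights, and training continues from
    # start_epoch with the previously best validation accuracy.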
    for epoch in range(start_epoch, num_epochs):

        def train_step():
            # needed so the epoch-level checkpointing below sees the new value
            nonlocal val_top1

            # -- update distributed-data-loader epoch
            dist_sampler.set_epoch(epoch)
            for i, data in enumerate(data_loader):
                # data = (view_1, ..., view_k, target); the dataset targets are
                # unused, the loss targets come from labels_matrix instead
                imgs = torch.cat([s.to(device) for s in data[:-1]], 0)
                labels = torch.cat([labels_matrix for _ in range(supervised_views)])
                with torch.cuda.amp.autocast(enabled=use_fp16):
                    optimizer.zero_grad()
                    z = encoder(imgs)
                    loss = suncet(z, labels)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                if i % log_freq == 0:  # log_freq is a module-level constant
                    logger.info('[%d, %5d] (loss: %.3f)' % (epoch + 1, i, loss))

            # -- end-of-epoch validation pass
            with torch.no_grad():
                with nostdout():
                    val_top1, _ = val_run(
                        pretrained=copy.deepcopy(encoder),
                        subset_path=subset_path,
                        unlabeled_frac=unlabeled_frac,
                        dataset_name=dataset_name,
                        root_path=root_path,
                        image_folder=image_folder,
                        use_pred=use_pred_head,
                        normalize=normalize,
                        split_seed=data_seed)
            logger.info('[%d] (val: %.3f%%)' % (epoch + 1, val_top1))

        train_step()
        # -- logging/checkpointing
        if (rank == 0) and ((best_acc is None) or (best_acc < val_top1)):
            best_acc = val_top1
            save_dict = {
                'encoder': encoder.state_dict(),
                'opt': optimizer.state_dict(),
                'sched': scheduler.state_dict(),
                'epoch': epoch + 1,
                'unlabel_prob': unlabeled_frac,
                'world_size': world_size,
                'batch_size': batch_size,
                'best_top1_acc': best_acc,
                'lr': ref_lr,
                'amp': scaler.state_dict()
            }
            torch.save(save_dict, w_enc_path)
            logger.info('[%d] (best-val: %.3f%%)' % (epoch + 1, best_acc))
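

# -- Illustrative only: `main` expects the nested parameter dict sketched near
#    the top (e.g. parsed from a YAML config). A hypothetical invocation, with
#    a made-up config path, would look like:
#
#      import yaml
#      with open('configs/suncet_fine_tune.yaml', 'r') as f:
#          params = yaml.load(f, Loader=yaml.FullLoader)
#      main(params)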