in train.py
def train(gpu_id, ngpus_per_node, args):
save_dir = args.save_dir
# get global rank (one process per GPU):
rank = args.node_id * ngpus_per_node + gpu_id
print(f"Running on rank {rank}.")
# Initialize DDP:
if args.ddp:
setup(rank, ngpus_per_node, args)
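# setup() is expected to initialize the distributed process group for this global rank (e.g. via torch.distributed.init_process_group); see its definition for the actual rendezvous details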
# initialize device:
device = gpu_id if args.ddp else 'cuda'
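# VOS and NPOS synthesize virtual outliers in feature space instead of loading a real auxiliary OOD dataset, so that branch needs the ID features back from the forward pass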
synthesis_ood_flag = args.ood_aux_dataset in ['VOS', 'NPOS']
require_feats_flag = 'maha' in args.ood_metric
num_classes, train_loader, test_loader, ood_loader, train_sampler, img_num_per_cls_and_ood = build_dataset(args, ngpus_per_node)
img_num_per_cls = img_num_per_cls_and_ood[:num_classes]
model, optimizer, scheduler, num_outputs = build_model(args, num_classes, device, gpu_id)
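# Mahalanobis-style OOD metrics need per-class statistics of penultimate-layer ID features, so keep a running pool of them alongside the model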
if require_feats_flag:
model.id_feat_pool = IDFeatPool(num_classes, sample_num=max(img_num_per_cls),
feat_dim=model.penultimate_layer_dim, device=device)
adjustments = build_prior(args, model, img_num_per_cls, num_classes, num_outputs, device)
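# the adjustments built above are class-prior logit offsets (derived from img_num_per_cls, the long-tailed label distribution) that the loss presumably uses for logit adjustment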
# train:
if args.resume:
ckpt = torch.load(osp.join(args.resume, 'latest.pth'), map_location='cpu')
if is_parallel(model):
# checkpoints store unwrapped weights, so re-add the 'module.' prefix for a DDP/DP-wrapped model
ckpt['model'] = {'module.' + k: v for k, v in ckpt['model'].items()}
model.load_state_dict(ckpt['model'], strict=False)
try:
optimizer.load_state_dict(ckpt['optimizer'])
scheduler.load_state_dict(ckpt['scheduler'])
except Exception:
# optimizer/scheduler state may be missing or incompatible with the current run; keep the freshly built ones
pass
start_epoch = ckpt['epoch']+1
best_overall_acc = ckpt['best_overall_acc']
training_losses = ckpt['training_losses']
test_clean_losses = ckpt['test_clean_losses']
f1s = ckpt['f1s']
overall_accs = ckpt['overall_accs']
many_accs = ckpt['many_accs']
median_accs = ckpt['median_accs']
low_accs = ckpt['low_accs']
else:
training_losses, test_clean_losses = [], []
f1s, overall_accs, many_accs, median_accs, low_accs = [], [], [], [], []
best_overall_acc = 0
start_epoch = 0
fp = open(osp.join(save_dir, 'train_log.txt'), 'a+')
fp_val = open(osp.join(save_dir, 'val_log.txt'), 'a+')
shutil.copyfile('models/base.py', f'{save_dir}/base.py')
for epoch in range(start_epoch, args.epochs):
# reset sampler when using ddp:
if args.ddp:
train_sampler.set_epoch(epoch)
start_time = time.time()
model.train()
training_loss_meter = AverageMeter()
current_lr = scheduler.get_last_lr()
pbar = zip(train_loader, ood_loader)
# if args.ddp and rank == 0:
# pbar = tqdm(pbar, desc=f'Epoch: {epoch:03d}/{args.epochs:03d}', total=len(train_loader))
stop_flag = False
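# each iteration draws one labelled in-distribution batch (as two augmented views per image) and one auxiliary OOD batch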
for batch_idx, ((in_data, labels), (ood_data, _)) in enumerate(pbar):
in_data = torch.cat([in_data[0], in_data[1]], dim=0) # shape=(2*N,C,H,W). Two views of each image.
in_data, labels = in_data.to(device), labels.to(device)
ood_data = ood_data.to(device)
# forward:
if not synthesis_ood_flag and not require_feats_flag:
all_data = torch.cat([in_data, ood_data], dim=0) # shape=(2*Nin+Nout, C, H, W)
in_loss, ood_loss, aux_loss = model(all_data, mode='calc_loss', labels=labels, adjustments=adjustments, args=args)
elif synthesis_ood_flag:
in_loss, ood_loss, aux_loss, id_feats = \
model(in_data, mode='calc_loss', labels=labels, adjustments=adjustments, args=args, ood_data=ood_data, return_features=True)
ood_loader.update(id_feats.detach().clone(), labels)
elif require_feats_flag:
all_data = torch.cat([in_data, ood_data], dim=0) # shape=(2*Nin+Nout, C, H, W)
num_ood = len(ood_data)
in_loss, ood_loss, aux_loss, id_feats = \
model(all_data, mode='calc_loss', labels=labels, adjustments=adjustments, args=args, return_features=True)
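# total objective: in-distribution loss plus Lambda-weighted OOD loss and Lambda2-weighted auxiliary loss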
loss: torch.Tensor = in_loss + args.Lambda * ood_loss + args.Lambda2 * aux_loss
if torch.isnan(loss):
print('Warning: Loss is NaN. Training stopped.')
stop_flag = True
break
if require_feats_flag:
# keep the 2*N in-distribution features (drop the trailing OOD features), one label per augmented view
model.id_feat_pool.update(id_feats[:-num_ood].detach().clone(), torch.cat((labels, labels)))
# backward:
optimizer.zero_grad()
loss.backward()
optimizer.step()
# append:
training_loss_meter.append(loss.item())
if rank == 0 and batch_idx % 100 == 0:
train_str = '%s epoch %d batch %d (train): loss %.4f (%.4f, %.4f, %.4f) | lr %s' % (
datetime.now().strftime("%D %H:%M:%S"),
epoch, batch_idx, loss.item(), in_loss.item(), ood_loss.item(), aux_loss.item(), current_lr)
print(train_str)
fp.write(train_str + '\n')
fp.flush()
if stop_flag:
print('Stopping early; use the checkpoint from epoch', epoch - 1)
break
# lr update:
scheduler.step()
if rank == 0:
# eval on clean set:
model.eval()
test_acc_meter, test_loss_meter = AverageMeter(), AverageMeter()
preds_list, labels_list = [], []
with torch.no_grad():
for data, labels in test_loader:
data, labels = data.to(device), labels.to(device)
logits, features = model(data, return_features=True)
in_logits = de_parallel(model).parse_logits(logits, features, args.ood_metric, logits.shape[0])[0]
pred = in_logits.argmax(dim=1, keepdim=True) # get the index of the max log-probability
loss = F.cross_entropy(in_logits, labels)
test_acc_meter.append((in_logits.argmax(1) == labels).float().mean().item())
test_loss_meter.append(loss.item())
preds_list.append(pred)
labels_list.append(labels)
preds = torch.cat(preds_list, dim=0).detach().cpu().numpy().squeeze()
labels = torch.cat(labels_list, dim=0).detach().cpu().numpy()
overall_acc = (preds == labels).sum().item() / len(labels)
f1 = f1_score(labels, preds, average='macro')
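# shot_acc breaks accuracy down by class frequency (many-, median- and low-shot classes), using img_num_per_cls as the per-class sample counts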
many_acc, median_acc, low_acc, _ = shot_acc(preds, labels, img_num_per_cls, acc_per_cls=True)
test_clean_losses.append(test_loss_meter.avg)
f1s.append(f1)
overall_accs.append(overall_acc)
many_accs.append(many_acc)
median_accs.append(median_acc)
low_accs.append(low_acc)
val_str = '%s epoch %d (test): ACC %.4f (%.4f, %.4f, %.4f) | F1 %.4f | time %s' % \
(datetime.now().strftime("%D %H:%M:%S"), epoch, overall_acc, many_acc, median_acc, low_acc, f1, time.time()-start_time)
print(val_str)
fp_val.write(val_str + '\n')
fp_val.flush()
# save curves:
training_losses.append(training_loss_meter.avg)
save_curve(args, save_dir, training_losses, test_clean_losses,
overall_accs, many_accs, median_accs, low_accs, f1s)
# save best model:
model_state_dict = de_parallel(model).state_dict()
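# only track the best clean-accuracy checkpoint over the final 25% of epochs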
if overall_accs[-1] > best_overall_acc and epoch >= args.epochs * 0.75:
best_overall_acc = overall_accs[-1]
torch.save(model_state_dict, osp.join(save_dir, 'best_clean_acc.pth'))
# save feature pool
if synthesis_ood_flag:
ood_loader.save(osp.join(save_dir, 'id_feats.pth'))
elif require_feats_flag:
model.id_feat_pool.save(osp.join(save_dir, 'id_feats.pth')) # same file format as the synthesized-OOD branch above
# save pth:
torch.save({
'model': model_state_dict,
'optimizer': optimizer.state_dict(),
'scheduler': scheduler.state_dict(),
'epoch': epoch,
'best_overall_acc': best_overall_acc,
'training_losses': training_losses,
'test_clean_losses': test_clean_losses,
'f1s': f1s,
'overall_accs': overall_accs,
'many_accs': many_accs,
'median_accs': median_accs,
'low_accs': low_accs,
},
osp.join(save_dir, 'latest.pth'))
if args.save_epochs > 0 and epoch % args.save_epochs == 0:
torch.save({
'model': model_state_dict,
'optimizer': optimizer.state_dict(),
}, osp.join(save_dir, f'epoch{epoch}.pth'))
if synthesis_ood_flag:
ood_loader.save(osp.join(save_dir, f'id_feats_epoch{epoch}.pth'))
elif require_feats_flag:
model.id_feat_pool.save(osp.join(save_dir, f'id_feats_epoch{epoch}.pth')) # same file format as the synthesized-OOD branch above
# Clean up ddp:
if args.ddp:
cleanup()
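# --- Launch sketch (not part of the original file) -------------------------------
# A minimal sketch of how a train(gpu_id, ngpus_per_node, args) entry point is
# typically launched: one spawned process per local GPU when args.ddp is set,
# a single process otherwise. `parse_args` is a hypothetical stand-in for this
# repo's real argument parsing, and the rendezvous details live inside setup().
def _launch_sketch():
    import torch
    import torch.multiprocessing as mp
    args = parse_args()  # hypothetical helper; assumed to provide node_id, ddp, epochs, ...
    ngpus_per_node = torch.cuda.device_count()
    if args.ddp:
        # mp.spawn invokes train(i, ngpus_per_node, args) with i = 0 .. ngpus_per_node - 1
        mp.spawn(train, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        train(0, ngpus_per_node, args)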