# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse, os, time
import os.path as osp
from sklearn.metrics import f1_score
import shutil
from datetime import datetime
from tqdm import tqdm
import yaml
import torch
import torch.nn.functional as F
import torch.distributed as dist
from utils.utils import create_dir, set_random_seed, AverageMeter, save_curve, is_parallel, de_parallel
from utils.ltr_metrics import shot_acc
from utils.common import build_dataset, build_model, build_prior
from models.feat_pool import IDFeatPool
# to prevent PIL error from reading large images:
# See https://github.com/eriklindernoren/PyTorch-YOLOv3/issues/162#issuecomment-491115265
# or https://stackoverflow.com/questions/12984426/pil-ioerror-image-file-truncated-with-big-images
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
def get_args_parser():
# Training settings
parser = argparse.ArgumentParser(description='PASCL for OOD detection in long-tailed recognition')
parser.add_argument('--gpu', default='2')
parser.add_argument('--num_workers', '--cpus', type=int, default=16, help='number of threads for data loader')
parser.add_argument('--data_root_path', '--drp', default='./data', help='data root path')
parser.add_argument('--dataset', '--ds', default='cifar10', choices=['cifar10', 'cifar100', 'imagenet', 'waterbird'])
parser.add_argument('--ood_aux_dataset', '--ood_ds', default='TinyImages', choices=['TinyImages', 'VOS', 'NPOS', 'CIFAR', 'Texture'])
parser.add_argument('--id_class_number', type=int, default=1000, help='for ImageNet subset')
parser.add_argument('--model', '--md', default='ResNet18', choices=['ResNet18', 'ResNet34', 'ResNet50'], help='which model to use')
parser.add_argument('--imbalance_ratio', '--rho', default=0.01, type=float)
parser.add_argument('--seed', default=None, type=int, help='random seed')
# training params:
parser.add_argument('--batch_size', '-b', type=int, default=256, help='input batch size for training')
parser.add_argument('--test_batch_size', '--tb', type=int, default=1000, help='input batch size for testing')
parser.add_argument('--epochs', '-e', type=int, default=200, help='number of epochs to train')
    parser.add_argument('--save_epochs', type=int, default=-1, help='save a checkpoint every this many epochs (-1 to disable)')
parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
parser.add_argument('--wd', type=float, default=5e-4, help='weight decay')
parser.add_argument('--momentum', '-m', type=float, default=0.9, help='Momentum.')
parser.add_argument('--decay_epochs', '--de', default=[60,80], nargs='+', type=int, help='milestones for multisteps lr decay')
parser.add_argument('--opt', default='adam', choices=['sgd', 'adam'], help='which optimizer to use')
parser.add_argument('--decay', default='cos', choices=['cos', 'multisteps'], help='which lr decay method to use')
parser.add_argument('--Lambda', default=0.5, type=float, help='OE loss term tradeoff hyper-parameter')
parser.add_argument('--Lambda2', default=0.1, type=float, help='Contrastive loss term tradeoff hyper-parameter')
parser.add_argument('--T', default=0.07, type=float, help='Temperature in NT-Xent loss (contrastive loss)')
    parser.add_argument('--k', default=0.4, type=float, help='bottom k fraction of classes are treated as tail classes')
parser.add_argument('--num_ood_samples', default=30000, type=float, help='Number of OOD samples to use.')
    # optimization params:
    parser.add_argument('--logit_adjust', '--tau', default=0., type=float, help='logit adjustment strength tau (0 disables logit adjustment)')
parser.add_argument('--ood_metric', default='oe', choices=['oe', 'bkg_c', 'energy', 'bin_disc', 'mc_disc', 'maha',
'ada_bin_disc', 'ada_oe', 'ada_energy', 'ada_pascl', 'ada_maha'], help='OOD training metric')
    parser.add_argument('--aux_ood_loss', default='none', choices=['none', 'pascl', 'simclr'], help='Auxiliary (e.g., feature-level) OOD training loss')
    parser.add_argument('--early-stop', action='store_true', default=True, dest='early_stop', help='If true, early stop when lambda does not change')
parser.add_argument('--no-early-stop', action='store_false', dest='early_stop')
parser.add_argument('--w_beta', default=1.0, type=float)
parser.add_argument('--t_beta', default=2.0, type=float)
#
    parser.add_argument('--timestamp', action='store_true', help='If true, append a timestamp to the experiment string')
parser.add_argument('--resume', type=str, default='', help='Resume from pre-trained models')
    parser.add_argument('--save_root_path', '--srp', default='./runs', help='save root path')
# ddp
parser.add_argument('--ddp', action='store_true', help='If true, use distributed data parallel')
    parser.add_argument('--ddp_backend', '--ddpbed', default='nccl', choices=['nccl', 'gloo', 'mpi'], help='distributed backend to use')
parser.add_argument('--num_nodes', default=1, type=int, help='Number of nodes')
parser.add_argument('--node_id', default=0, type=int, help='Node ID')
parser.add_argument('--dist_url', default='tcp://localhost:23456', type=str, help='url used to set up distributed training')
args = parser.parse_args()
assert args.k>0, "When args.k==0, it is just the OE baseline."
if args.dataset == 'imagenet':
# adjust learning rate:
args.lr *= args.batch_size / 256. # linearly scaled to batch size
return args
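# Example invocations (illustrative only; adjust devices, paths, and hyper-parameters to your setup):
#   python train.py --ds cifar10 --ood_ds TinyImages --md ResNet18 --rho 0.01 --gpu 0
#   python train.py --ds imagenet --md ResNet50 --ddp --num_nodes 1 --dist_url tcp://localhost:23456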
def create_save_path(args, _mkdir=True):
# mkdirs:
decay_str = args.decay
if args.decay == 'multisteps':
decay_str += '-'.join(map(str, args.decay_epochs))
opt_str = args.opt
if args.opt == 'sgd':
opt_str += '-m%s' % args.momentum
opt_str = 'e%d-b%d-%s-lr%s-wd%s-%s' % (args.epochs, args.batch_size, opt_str, args.lr, args.wd, decay_str)
reweighting_fn_str = 'sign'
    # NOTE: args.aux_prior_type is not defined in get_args_parser above, so fall back to 'none' if absent.
    aux_prior_type = getattr(args, 'aux_prior_type', 'none')
    loss_str = '%s-Lambda%s-Lambda2%s-T%s-%s' % \
        (args.ood_metric + '-' + aux_prior_type + '-' + args.aux_ood_loss,
         args.Lambda, args.Lambda2, args.T, reweighting_fn_str)
if args.imbalance_ratio < 1:
if args.logit_adjust > 0:
lt_method = 'LA%s' % args.logit_adjust
else:
lt_method = 'none'
loss_str = lt_method + '-' + loss_str
loss_str += '-k%s'% (args.k)
exp_str = '%s_%s' % (opt_str, loss_str)
if args.timestamp:
        exp_str += '_%s' % datetime.now().strftime("%Y%m%d%H%M%S")
dataset_str = '%s-%s-OOD%d' % (args.dataset, args.imbalance_ratio, args.num_ood_samples) if 'imagenet' not in args.dataset else '%s%d-lt' % (args.dataset, args.id_class_number)
save_dir = osp.join(args.save_root_path, dataset_str, args.model, exp_str)
if _mkdir:
create_dir(save_dir)
print('Saving to %s' % save_dir)
return save_dir
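# The resulting directory layout is roughly (illustrative):
#   <save_root_path>/<dataset>-<imbalance_ratio>-OOD<num_ood_samples>/<model>/<opt_str>_<loss_str>
# e.g. ./runs/cifar10-0.01-OOD30000/ResNet18/e200-b256-adam-lr0.001-wd0.0005-cos_...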
def setup(rank, ngpus_per_node, args):
# initialize the process group
world_size = ngpus_per_node * args.num_nodes
dist.init_process_group(args.ddp_backend, init_method=args.dist_url, rank=rank, world_size=world_size)
def cleanup():
dist.destroy_process_group()
def train(gpu_id, ngpus_per_node, args):
save_dir = args.save_dir
    # get global rank (process id):
rank = args.node_id * ngpus_per_node + gpu_id
print(f"Running on rank {rank}.")
# Initializes ddp:
if args.ddp:
setup(rank, ngpus_per_node, args)
    # initialize device:
device = gpu_id if args.ddp else 'cuda'
synthesis_ood_flag = args.ood_aux_dataset in ['VOS', 'NPOS']
require_feats_flag = 'maha' in args.ood_metric
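    # synthesis_ood_flag: outliers are synthesized in feature space (VOS/NPOS) instead of being
    # loaded from an auxiliary image dataset; require_feats_flag: the chosen OOD metric is
    # Mahalanobis-style and needs a running pool of ID penultimate-layer features.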
num_classes, train_loader, test_loader, ood_loader, train_sampler, img_num_per_cls_and_ood = build_dataset(args, ngpus_per_node)
img_num_per_cls = img_num_per_cls_and_ood[:num_classes]
model, optimizer, scheduler, num_outputs = build_model(args, num_classes, device, gpu_id)
if require_feats_flag:
model.id_feat_pool = IDFeatPool(num_classes, sample_num=max(img_num_per_cls),
feat_dim=model.penultimate_layer_dim, device=device)
adjustments = build_prior(args, model, img_num_per_cls, num_classes, num_outputs, device)
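    # 'adjustments' holds per-class logit offsets derived from the label prior (cf. the
    # --logit_adjust/--tau flag); see build_prior for the exact form used.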
# train:
if args.resume:
# ckpt = torch.load(osp.join(save_dir, 'latest.pth'), map_location='cpu')
ckpt = torch.load(osp.join(args.resume, 'latest.pth'), map_location='cpu')
if is_parallel(model):
ckpt['model'] = {'module.' + k: v for k, v in ckpt['model'].items()}
model.load_state_dict(ckpt['model'], strict=False)
try:
optimizer.load_state_dict(ckpt['optimizer'])
scheduler.load_state_dict(ckpt['scheduler'])
        except Exception:
            # optimizer/scheduler states may be missing or incompatible with the current run
            pass
start_epoch = ckpt['epoch']+1
best_overall_acc = ckpt['best_overall_acc']
training_losses = ckpt['training_losses']
test_clean_losses = ckpt['test_clean_losses']
f1s = ckpt['f1s']
overall_accs = ckpt['overall_accs']
many_accs = ckpt['many_accs']
median_accs = ckpt['median_accs']
low_accs = ckpt['low_accs']
else:
training_losses, test_clean_losses = [], []
f1s, overall_accs, many_accs, median_accs, low_accs = [], [], [], [], []
best_overall_acc = 0
start_epoch = 0
# print('Resume Done.')
fp = open(osp.join(save_dir, 'train_log.txt'), 'a+')
fp_val = open(osp.join(save_dir, 'val_log.txt'), 'a+')
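    # keep a snapshot of the model definition alongside the run outputs for reproducibility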
shutil.copyfile('models/base.py', f'{save_dir}/base.py')
for epoch in range(start_epoch, args.epochs):
# reset sampler when using ddp:
if args.ddp:
train_sampler.set_epoch(epoch)
start_time = time.time()
model.train()
training_loss_meter = AverageMeter()
current_lr = scheduler.get_last_lr()
pbar = zip(train_loader, ood_loader)
# if args.ddp and rank == 0:
# pbar = tqdm(pbar, desc=f'Epoch: {epoch:03d}/{args.epochs:03d}', total=len(train_loader))
stop_flag = False
for batch_idx, ((in_data, labels), (ood_data, _)) in enumerate(pbar):
in_data = torch.cat([in_data[0], in_data[1]], dim=0) # shape=(2*N,C,H,W). Two views of each image.
in_data, labels = in_data.to(device), labels.to(device)
ood_data = ood_data.to(device)
# forward:
if not synthesis_ood_flag and not require_feats_flag:
                all_data = torch.cat([in_data, ood_data], dim=0) # shape=(2*Nin+Nout,C,H,W)
in_loss, ood_loss, aux_loss = model(all_data, mode='calc_loss', labels=labels, adjustments=adjustments, args=args)
elif synthesis_ood_flag:
in_loss, ood_loss, aux_loss, id_feats = \
model(in_data, mode='calc_loss', labels=labels, adjustments=adjustments, args=args, ood_data=ood_data, return_features=True)
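                # the VOS/NPOS loader synthesizes virtual outliers from ID features, so refresh
                # its feature pool with the features of the current batch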
ood_loader.update(id_feats.detach().clone(), labels)
elif require_feats_flag:
                all_data = torch.cat([in_data, ood_data], dim=0) # shape=(2*Nin+Nout,C,H,W)
num_ood = len(ood_data)
in_loss, ood_loss, aux_loss, id_feats = \
model(all_data, mode='calc_loss', labels=labels, adjustments=adjustments, args=args, return_features=True)
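            # total objective: ID classification loss + Lambda * OOD (OE-style) loss
            # + Lambda2 * auxiliary (e.g., PASCL/SimCLR contrastive) loss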
loss: torch.Tensor = in_loss + args.Lambda * ood_loss + args.Lambda2 * aux_loss
if torch.isnan(loss):
print('Warning: Loss is NaN. Training stopped.')
stop_flag = True
break
            if require_feats_flag:
                # keep only the ID features (drop the trailing OOD features) so sizes match the 2*N duplicated labels;
                # this assumes the model returns features in input order (ID views first, then OOD)
                model.id_feat_pool.update(id_feats[:-num_ood].detach().clone(), torch.cat((labels, labels)))
# backward:
optimizer.zero_grad()
loss.backward()
optimizer.step()
# append:
training_loss_meter.append(loss.item())
if rank == 0 and batch_idx % 100 == 0:
train_str = '%s epoch %d batch %d (train): loss %.4f (%.4f, %.4f, %.4f) | lr %s' % (
datetime.now().strftime("%D %H:%M:%S"),
epoch, batch_idx, loss.item(), in_loss.item(), ood_loss.item(), aux_loss.item(), current_lr)
print(train_str)
fp.write(train_str + '\n')
fp.flush()
if stop_flag:
print('Use the model at epoch', epoch - 1)
break
# lr update:
scheduler.step()
if rank == 0:
# eval on clean set:
model.eval()
test_acc_meter, test_loss_meter = AverageMeter(), AverageMeter()
preds_list, labels_list = [], []
with torch.no_grad():
for data, labels in test_loader:
data, labels = data.to(device), labels.to(device)
logits, features = model(data, return_features=True)
in_logits = de_parallel(model).parse_logits(logits, features, args.ood_metric, logits.shape[0])[0]
pred = in_logits.argmax(dim=1, keepdim=True) # get the index of the max log-probability
loss = F.cross_entropy(in_logits, labels)
test_acc_meter.append((in_logits.argmax(1) == labels).float().mean().item())
test_loss_meter.append(loss.item())
preds_list.append(pred)
labels_list.append(labels)
preds = torch.cat(preds_list, dim=0).detach().cpu().numpy().squeeze()
labels = torch.cat(labels_list, dim=0).detach().cpu().numpy()
overall_acc = (preds == labels).sum().item() / len(labels)
f1 = f1_score(labels, preds, average='macro')
many_acc, median_acc, low_acc, _ = shot_acc(preds, labels, img_num_per_cls, acc_per_cls=True)
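            # many/median/low = accuracy over head / middle / tail class groups, split by the
            # per-class training sample counts inside shot_acc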
test_clean_losses.append(test_loss_meter.avg)
f1s.append(f1)
overall_accs.append(overall_acc)
many_accs.append(many_acc)
median_accs.append(median_acc)
low_accs.append(low_acc)
val_str = '%s epoch %d (test): ACC %.4f (%.4f, %.4f, %.4f) | F1 %.4f | time %s' % \
(datetime.now().strftime("%D %H:%M:%S"), epoch, overall_acc, many_acc, median_acc, low_acc, f1, time.time()-start_time)
print(val_str)
fp_val.write(val_str + '\n')
fp_val.flush()
# save curves:
training_losses.append(training_loss_meter.avg)
save_curve(args, save_dir, training_losses, test_clean_losses,
overall_accs, many_accs, median_accs, low_accs, f1s)
# save best model:
model_state_dict = de_parallel(model).state_dict()
if overall_accs[-1] > best_overall_acc and epoch >= args.epochs * 0.75:
best_overall_acc = overall_accs[-1]
torch.save(model_state_dict, osp.join(save_dir, 'best_clean_acc.pth'))
# save feature pool
if synthesis_ood_flag:
ood_loader.save(osp.join(save_dir, 'id_feats.pth'))
elif require_feats_flag:
model.id_feat_pool.save(osp.join(save_dir, 'id_feats.pth')) # exactly the same
# save pth:
torch.save({
'model': model_state_dict,
'optimizer': optimizer.state_dict(),
'scheduler': scheduler.state_dict(),
'epoch': epoch,
'best_overall_acc': best_overall_acc,
'training_losses': training_losses,
'test_clean_losses': test_clean_losses,
'f1s': f1s,
'overall_accs': overall_accs,
'many_accs': many_accs,
'median_accs': median_accs,
'low_accs': low_accs,
},
osp.join(save_dir, 'latest.pth'))
if args.save_epochs > 0 and epoch % args.save_epochs == 0:
torch.save({
'model': model_state_dict,
'optimizer': optimizer.state_dict(),
}, osp.join(save_dir, f'epoch{epoch}.pth'))
if synthesis_ood_flag:
ood_loader.save(osp.join(save_dir, f'id_feats_epoch{epoch}.pth'))
elif require_feats_flag:
model.id_feat_pool.save(osp.join(save_dir, f'id_feats_epoch{epoch}.pth')) # exactly the same
# Clean up ddp:
if args.ddp:
cleanup()
if __name__ == '__main__':
# get args:
args = get_args_parser()
# mkdirs:
save_dir = create_save_path(args)
args.save_dir = save_dir
with open(f'{save_dir}/args.yaml', 'w+') as f:
yaml.safe_dump(vars(args), f, sort_keys=False)
# set CUDA:
if args.num_nodes == 1: # When using multiple nodes, we assume all gpus on each node are available.
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
# set random seed, default None
set_random_seed(args.seed)
if args.ddp:
ngpus_per_node = torch.cuda.device_count()
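        # torch.multiprocessing.spawn calls train(i, ngpus_per_node, args) with the local
        # process index i as gpu_id, one process per visible GPU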
torch.multiprocessing.spawn(train, args=(ngpus_per_node,args), nprocs=ngpus_per_node, join=True)
else:
train(0, 0, args)