# utils/common.py
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from datasets.ImbalanceCIFAR import IMBALANCECIFAR10, IMBALANCECIFAR100
from datasets.ImbalanceImageNet import LT_Dataset
from datasets.tinyimages_300k import TinyImages
from models.base import BaseModel
from models.resnet import ResNet34, ResNet18
from models.resnet_imagenet import ResNet50# , ResNet18
from models.feat_pool import IDFeatPool
from utils.utils import TwoCropTransform, de_parallel
def build_dataset(args, ngpus_per_node, is_training=True):
    """Build the train/test loaders and, when training, an auxiliary OOD source.

    Args:
        args: parsed CLI namespace. Fields read here: dataset, ddp, batch_size,
            num_nodes, num_workers, imbalance_ratio, data_root_path,
            id_class_number, num_ood_samples, ood_aux_dataset, test_batch_size.
        ngpus_per_node: GPU count per node; used to split the global batch
            size and worker count across DDP processes.
        is_training: when True, also build the OOD loader and return the
            training loader/sampler; when False, return only test-side data.

    Returns:
        is_training=True:
            (num_classes, train_loader, test_loader, ood_loader,
             train_sampler, img_num_per_cls_and_ood)
            where ood_loader is either a DataLoader or an IDFeatPool
            (for 'VOS'/'NPOS'), and img_num_per_cls_and_ood is an np.array of
            per-class training image counts with the OOD count appended.
        is_training=False:
            (num_classes, test_loader, img_num_per_cls_and_ood)
    """
    # get batch size:
    # Under DDP, args.batch_size / args.num_workers are global values, so each
    # process takes its share (workers are rounded up via the +ngpus_per_node).
    train_batch_size = args.batch_size if not args.ddp else int(args.batch_size/ngpus_per_node/args.num_nodes)
    num_workers = args.num_workers if not args.ddp else int((args.num_workers+ngpus_per_node)/ngpus_per_node)
    # data:
    # NOTE(review): only these dataset families define the transforms; any
    # other args.dataset value would hit a NameError below — presumably argparse
    # restricts the choices. TODO confirm.
    if args.dataset in ['cifar10', 'cifar100']:
        # SimCLR-style augmentation (crop/flip/jitter/grayscale); no
        # normalization for CIFAR — the model is fed raw [0,1] tensors.
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
            ], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
        ])
    elif args.dataset == 'imagenet' or args.dataset == 'waterbird':
        # Standard ImageNet statistics.
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
            ], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        test_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
    # Training sets get TwoCropTransform: two augmented views per image
    # (contrastive-style training).
    if args.dataset == 'cifar10':
        num_classes = 10
        train_set = IMBALANCECIFAR10(train=True, transform=TwoCropTransform(train_transform), imbalance_ratio=args.imbalance_ratio, root=args.data_root_path)
        test_set = IMBALANCECIFAR10(train=False, transform=test_transform, imbalance_ratio=args.imbalance_ratio, root=args.data_root_path)
    elif args.dataset == 'cifar100':
        num_classes = 100
        train_set = IMBALANCECIFAR100(train=True, transform=TwoCropTransform(train_transform), imbalance_ratio=args.imbalance_ratio, root=args.data_root_path)
        test_set = IMBALANCECIFAR100(train=False, transform=test_transform, imbalance_ratio=args.imbalance_ratio, root=args.data_root_path)
    elif args.dataset == 'imagenet':
        # ImageNet-LT: long-tailed subset defined by the split txt files,
        # optionally restricted to the first id_class_number classes.
        num_classes = args.id_class_number
        train_set = LT_Dataset(
            osp.join(args.data_root_path, 'imagenet'), './datasets/ImageNet_LT/ImageNet_LT_train.txt', transform=TwoCropTransform(train_transform),
            subset_class_idx=np.arange(0,args.id_class_number))
        if is_training:
            # Validate on the LT validation split during training ...
            test_set = LT_Dataset(
                osp.join(args.data_root_path, 'imagenet'), './datasets/ImageNet_LT/ImageNet_LT_val.txt', transform=test_transform,
                subset_class_idx=np.arange(0,args.id_class_number))
        else:
            # ... but evaluate on the full (balanced) ImageNet val set at test time.
            test_set = ImageFolder(osp.join(args.data_root_path, 'imagenet', 'val'), transform=test_transform)
    elif args.dataset == 'waterbird':
        num_classes = 2
        train_set = ImageFolder(osp.join(args.data_root_path, 'waterbird_LT', 'train'), transform=TwoCropTransform(train_transform))
        # ImageFolder has no class-count attribute; hard-coded counts for the
        # waterbird_LT train split (landbirds/waterbirds order per folder sort).
        setattr(train_set, 'img_num_per_cls', [363, 3699])
        if is_training:
            test_set = ImageFolder(osp.join(args.data_root_path, 'waterbird_LT', 'val'), transform=test_transform)
        else:
            test_set = ImageFolder(osp.join(args.data_root_path, 'waterbird_LT', 'test'), transform=test_transform)
    if args.ddp:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_set)
    else:
        train_sampler = None
    # shuffle and sampler are mutually exclusive: DDP shuffles via the sampler.
    train_loader = DataLoader(train_set, batch_size=train_batch_size, shuffle=not args.ddp, num_workers=num_workers,
                              drop_last=True, pin_memory=True, sampler=train_sampler)
    # Test loader is not sharded across DDP ranks (each rank sees the full set).
    test_loader = DataLoader(test_set, batch_size=args.test_batch_size, shuffle=False, num_workers=num_workers,
                             drop_last=False, pin_memory=True)
    if is_training:
        if args.ood_aux_dataset in ['TinyImages', 'ImExtra', 'Texture']:
            # Real auxiliary OOD images, one source per ID dataset family.
            if args.dataset in ['cifar10', 'cifar100']:
                ood_set = Subset(TinyImages(args.data_root_path, transform=train_transform), list(range(args.num_ood_samples)))
            elif args.dataset == 'imagenet':
                ood_set = ImageFolder(osp.join(args.data_root_path, 'imagenet/extra_1k'), transform=train_transform)
            elif args.dataset == 'waterbird':
                ood_set = ImageFolder(osp.join(args.data_root_path, 'texture'), transform=train_transform)
            else:
                raise NotImplementedError(args.dataset, args.ood_aux_dataset)
            if args.ddp:
                ood_sampler = torch.utils.data.distributed.DistributedSampler(ood_set)
            else:
                ood_sampler = None
            ood_loader = DataLoader(ood_set, batch_size=train_batch_size, shuffle=not args.ddp, num_workers=num_workers,
                                    drop_last=True, pin_memory=True, sampler=ood_sampler)
            ood_num = len(ood_set)
            print('Training on %s with %d images and %d validation images | %d OOD training images.' % (args.dataset, len(train_set), len(test_set), len(ood_set)))
        elif args.ood_aux_dataset in ['VOS', 'NPOS']:
            # Synthetic OOD: "ood_loader" is a feature pool sampled in feature
            # space, not a DataLoader — callers must handle both types.
            # sample_num = min(train_set.img_num_per_cls) * 10
            sample_num = max(train_set.img_num_per_cls)
            feat_dim = 512 if 'cifar' in args.dataset else 1024 # TODO: check
            mode = args.ood_aux_dataset
            device = 'cuda:0'
            ood_loader = IDFeatPool(num_classes, sample_num=sample_num, feat_dim=feat_dim, mode=mode, device=device)
            ood_num = num_classes * sample_num
        elif args.ood_aux_dataset in ['CIFAR']:
            # Use the *other* CIFAR variant as OOD (via the SCOOD image dump).
            if args.dataset == 'cifar10':
                dout = 'cifar100'
            elif args.dataset == 'cifar100':
                dout = 'cifar10'
            ood_set = ImageFolder(osp.join(args.data_root_path, f'SCOOD/data/images/{dout}/test'), transform=train_transform)
            if args.ddp:
                ood_sampler = torch.utils.data.distributed.DistributedSampler(ood_set)
            else:
                ood_sampler = None
            ood_loader = DataLoader(ood_set, batch_size=train_batch_size, shuffle=not args.ddp, num_workers=num_workers,
                                    drop_last=True, pin_memory=True, sampler=ood_sampler)
            ood_num = len(ood_set)
            print('Training on %s with %d images and %d validation images | %d OOD training images.' % (args.dataset, len(train_set), len(test_set), len(ood_set)))
        else:
            raise NotImplementedError(f'{args.dataset} v.s. {args.ood_aux_dataset}')
        # Per-class ID counts with the OOD count appended as a pseudo-class.
        img_num_per_cls_and_ood = np.array(train_set.img_num_per_cls + [ood_num])
        return num_classes, train_loader, test_loader, ood_loader, train_sampler, img_num_per_cls_and_ood
    else:
        img_num_per_cls_and_ood = np.array(train_set.img_num_per_cls + [args.num_ood_samples])
        return num_classes, test_loader, img_num_per_cls_and_ood
def build_plain_train_loader(args): # for statistic during test
    """Build a deterministic loader over the *training* set for statistics.

    The loader intentionally applies the evaluation transform (no
    augmentation), iterates in dataset order (no shuffle), and uses
    batch_size=1 so per-sample statistics can be collected at test time.

    Args:
        args: parsed CLI namespace; fields read: dataset, imbalance_ratio,
            data_root_path, id_class_number.

    Returns:
        DataLoader over the training split with the eval transform applied.

    Raises:
        NotImplementedError: for a dataset other than cifar10/cifar100/imagenet.
    """
    # Only the eval transform is needed here; the augmented train transform
    # previously built in this function was never used.
    if args.dataset in ['cifar10', 'cifar100']:
        # Raw [0,1] tensors, consistent with build_dataset's CIFAR test path.
        eval_transform = transforms.Compose([
            transforms.ToTensor(),
        ])
    elif args.dataset == 'imagenet':
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        eval_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
    else:
        raise NotImplementedError(args.dataset)
    if args.dataset == 'cifar10':
        train_set = IMBALANCECIFAR10(train=True, transform=eval_transform, imbalance_ratio=args.imbalance_ratio, root=args.data_root_path)
    elif args.dataset == 'cifar100':
        train_set = IMBALANCECIFAR100(train=True, transform=eval_transform, imbalance_ratio=args.imbalance_ratio, root=args.data_root_path)
    elif args.dataset == 'imagenet':
        train_set = LT_Dataset(
            osp.join(args.data_root_path, 'imagenet'), './datasets/ImageNet_LT/ImageNet_LT_train.txt', transform=eval_transform,
            subset_class_idx=np.arange(0, args.id_class_number))
    # batch_size=1: drop_last has no effect but is kept for parity with the
    # original loader configuration.
    train_loader = DataLoader(train_set, batch_size=1, shuffle=False, num_workers=4,
                              drop_last=True, pin_memory=True)
    return train_loader
def build_model(args, num_classes, device, gpu_id, return_features=False, is_training=True):
    """Construct the classifier and, when training, its optimizer/scheduler.

    Args:
        args: parsed CLI namespace; fields read: model, ood_metric, ddp, opt,
            lr, wd, momentum, decay, epochs, decay_epochs.
        num_classes: number of in-distribution classes.
        device: device (or device string) the model is moved to.
        gpu_id: local GPU index used for DDP's device_ids.
        return_features: forwarded to the model constructor.
        is_training: when True, also build optimizer and LR scheduler.

    Returns:
        (model, optimizer, scheduler, num_outputs) when is_training,
        otherwise (model, num_outputs).

    Raises:
        ValueError: on an unrecognized args.model, args.opt, or args.decay.
    """
    # Output head width: some OOD metrics need extra logits beyond the ID
    # classes (one extra for a background class / binary discriminator, two
    # for a multi-class discriminator).
    num_outputs = num_classes
    if any(x in args.ood_metric for x in ['bkg_c', 'bin_disc']):
        num_outputs += 1
    elif any(x in args.ood_metric for x in ['mc_disc']):
        num_outputs += 2
    # Explicit registry instead of eval(args.model): same supported names
    # ('ResNet18', 'ResNet34', 'ResNet50') but no arbitrary-code execution and
    # a clear error on a typo instead of whatever eval() would produce.
    model_builders = {'ResNet18': ResNet18, 'ResNet34': ResNet34, 'ResNet50': ResNet50}
    if args.model not in model_builders:
        raise ValueError(f'Unknown model: {args.model}')
    model: BaseModel = model_builders[args.model](num_classes=num_classes, num_outputs=num_outputs, return_features=return_features).to(device)
    if args.ddp:
        # find_unused_parameters=True: not every head contributes to every step.
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[gpu_id], broadcast_buffers=False,
            find_unused_parameters=True
        )
    if not is_training:
        return model, num_outputs
    # optimizer:
    if args.opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    elif args.opt == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.wd, momentum=args.momentum, nesterov=True)
    else:
        # Previously an unknown optimizer surfaced as a NameError further down.
        raise ValueError(f'Unknown optimizer: {args.opt}')
    if args.decay == 'cos':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    elif args.decay == 'multisteps':
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.decay_epochs, gamma=0.1)
    else:
        raise ValueError(f'Unknown LR decay schedule: {args.decay}')
    return model, optimizer, scheduler, num_outputs
def build_prior(args, model, img_num_per_cls, num_classes, num_outputs, device):
    """Build the logit-adjustment prior from per-class training counts.

    Args:
        args: namespace with logit_adjust (float scale; <=0 disables) and
            ood_metric (string; 'bkg_c' means the head has extra outputs).
        model: unused; kept for interface compatibility with callers.
        img_num_per_cls: np.ndarray of per-class training image counts.
        num_classes: number of in-distribution classes.
        num_outputs: total model outputs (>= num_classes for some OOD metrics).
        device: target device for the returned tensor.

    Returns:
        (1, num_classes) or (1, num_outputs) tensor of additive logit
        adjustments; all zeros when logit_adjust <= 0.
    """
    counts = torch.from_numpy(img_num_per_cls).to(device)
    if args.logit_adjust > 0:
        # Scaled log class frequencies; epsilon guards empty classes.
        class_freq = counts / counts.sum()
        adjustments = (args.logit_adjust * torch.log(class_freq + 1e-12)).unsqueeze(0)
        if args.ood_metric in ['bkg_c'] and adjustments.shape[1] != num_outputs:
            # Extra head outputs (background class) get a zero adjustment.
            pad = torch.zeros_like(adjustments[:, :num_outputs - num_classes])
            adjustments = torch.cat((adjustments, pad), dim=1)
        return adjustments
    # Adjustment disabled: all-zero prior, sized to match the head width.
    width = num_outputs if args.ood_metric in ['bkg_c'] else num_classes
    return torch.zeros((1, width), device=device)