in utils/common.py [0:0]
def build_dataset(args, ngpus_per_node, is_training=True):
    """Construct ID (and, when training, auxiliary OOD) datasets and loaders.

    Args:
        args: parsed experiment namespace. Fields read here include
            batch_size, test_batch_size, num_workers, ddp, num_nodes,
            dataset, data_root_path, imbalance_ratio, id_class_number,
            ood_aux_dataset, num_ood_samples.
        ngpus_per_node: GPU count per node; used to split batch size and
            workers when DDP is enabled.
        is_training: when True, also build the OOD auxiliary loader and
            return training artifacts; when False, return only eval pieces.

    Returns:
        is_training=True:
            (num_classes, train_loader, test_loader, ood_loader,
             train_sampler, img_num_per_cls_and_ood)
        is_training=False:
            (num_classes, test_loader, img_num_per_cls_and_ood)

    Raises:
        NotImplementedError: for an unsupported args.dataset or
            args.ood_aux_dataset combination.
    """
    # Per-process batch size / workers: under DDP each rank gets a share.
    train_batch_size = args.batch_size if not args.ddp else int(args.batch_size/ngpus_per_node/args.num_nodes)
    num_workers = args.num_workers if not args.ddp else int((args.num_workers+ngpus_per_node)/ngpus_per_node)

    # Dataset-specific augmentation pipelines (SimCLR-style color jitter /
    # grayscale for contrastive two-crop training).
    if args.dataset in ['cifar10', 'cifar100']:
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
            ], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
        ])
    elif args.dataset == 'imagenet' or args.dataset == 'waterbird':
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
            ], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        test_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
    else:
        # Fail fast with a clear message instead of a later NameError on
        # train_transform / train_set.
        raise NotImplementedError(args.dataset)

    # ID train/test sets. TwoCropTransform yields two augmented views per
    # image for contrastive training.
    if args.dataset == 'cifar10':
        num_classes = 10
        train_set = IMBALANCECIFAR10(train=True, transform=TwoCropTransform(train_transform), imbalance_ratio=args.imbalance_ratio, root=args.data_root_path)
        test_set = IMBALANCECIFAR10(train=False, transform=test_transform, imbalance_ratio=args.imbalance_ratio, root=args.data_root_path)
    elif args.dataset == 'cifar100':
        num_classes = 100
        train_set = IMBALANCECIFAR100(train=True, transform=TwoCropTransform(train_transform), imbalance_ratio=args.imbalance_ratio, root=args.data_root_path)
        test_set = IMBALANCECIFAR100(train=False, transform=test_transform, imbalance_ratio=args.imbalance_ratio, root=args.data_root_path)
    elif args.dataset == 'imagenet':
        num_classes = args.id_class_number
        train_set = LT_Dataset(
            osp.join(args.data_root_path, 'imagenet'), './datasets/ImageNet_LT/ImageNet_LT_train.txt', transform=TwoCropTransform(train_transform),
            subset_class_idx=np.arange(0,args.id_class_number))
        if is_training:
            # Validate on the long-tailed split during training; use the
            # full ImageNet val folder for final evaluation.
            test_set = LT_Dataset(
                osp.join(args.data_root_path, 'imagenet'), './datasets/ImageNet_LT/ImageNet_LT_val.txt', transform=test_transform,
                subset_class_idx=np.arange(0,args.id_class_number))
        else:
            test_set = ImageFolder(osp.join(args.data_root_path, 'imagenet', 'val'), transform=test_transform)
    elif args.dataset == 'waterbird':
        num_classes = 2
        train_set = ImageFolder(osp.join(args.data_root_path, 'waterbird_LT', 'train'), transform=TwoCropTransform(train_transform))
        # ImageFolder has no per-class counts; attach the known long-tailed
        # class sizes so downstream reweighting code can read them.
        setattr(train_set, 'img_num_per_cls', [363, 3699])
        if is_training:
            test_set = ImageFolder(osp.join(args.data_root_path, 'waterbird_LT', 'val'), transform=test_transform)
        else:
            test_set = ImageFolder(osp.join(args.data_root_path, 'waterbird_LT', 'test'), transform=test_transform)
    else:
        raise NotImplementedError(args.dataset)

    # DDP requires a DistributedSampler; shuffle must then be off on the loader.
    if args.ddp:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_set)
    else:
        train_sampler = None
    train_loader = DataLoader(train_set, batch_size=train_batch_size, shuffle=not args.ddp, num_workers=num_workers,
                              drop_last=True, pin_memory=True, sampler=train_sampler)
    test_loader = DataLoader(test_set, batch_size=args.test_batch_size, shuffle=False, num_workers=num_workers,
                             drop_last=False, pin_memory=True)

    if is_training:
        # Auxiliary OOD data for outlier-exposure-style training.
        if args.ood_aux_dataset in ['TinyImages', 'ImExtra', 'Texture']:
            if args.dataset in ['cifar10', 'cifar100']:
                ood_set = Subset(TinyImages(args.data_root_path, transform=train_transform), list(range(args.num_ood_samples)))
            elif args.dataset == 'imagenet':
                ood_set = ImageFolder(osp.join(args.data_root_path, 'imagenet/extra_1k'), transform=train_transform)
            elif args.dataset == 'waterbird':
                ood_set = ImageFolder(osp.join(args.data_root_path, 'texture'), transform=train_transform)
            else:
                raise NotImplementedError(args.dataset, args.ood_aux_dataset)
            if args.ddp:
                ood_sampler = torch.utils.data.distributed.DistributedSampler(ood_set)
            else:
                ood_sampler = None
            ood_loader = DataLoader(ood_set, batch_size=train_batch_size, shuffle=not args.ddp, num_workers=num_workers,
                                    drop_last=True, pin_memory=True, sampler=ood_sampler)
            ood_num = len(ood_set)
            print('Training on %s with %d images and %d validation images | %d OOD training images.' % (args.dataset, len(train_set), len(test_set), len(ood_set)))
        elif args.ood_aux_dataset in ['VOS', 'NPOS']:
            # Synthetic outliers drawn from an ID-feature pool rather than a
            # real image dataset; "ood_loader" is the pool object here.
            # sample_num = min(train_set.img_num_per_cls) * 10
            sample_num = max(train_set.img_num_per_cls)
            feat_dim = 512 if 'cifar' in args.dataset else 1024  # TODO: check
            mode = args.ood_aux_dataset
            device = 'cuda:0'
            ood_loader = IDFeatPool(num_classes, sample_num=sample_num, feat_dim=feat_dim, mode=mode, device=device)
            ood_num = num_classes * sample_num
        elif args.ood_aux_dataset in ['CIFAR']:
            # Use the "other" CIFAR as OOD; any other ID dataset is invalid
            # here (previously fell through to an unbound-variable NameError).
            if args.dataset == 'cifar10':
                dout = 'cifar100'
            elif args.dataset == 'cifar100':
                dout = 'cifar10'
            else:
                raise NotImplementedError(args.dataset, args.ood_aux_dataset)
            ood_set = ImageFolder(osp.join(args.data_root_path, f'SCOOD/data/images/{dout}/test'), transform=train_transform)
            if args.ddp:
                ood_sampler = torch.utils.data.distributed.DistributedSampler(ood_set)
            else:
                ood_sampler = None
            ood_loader = DataLoader(ood_set, batch_size=train_batch_size, shuffle=not args.ddp, num_workers=num_workers,
                                    drop_last=True, pin_memory=True, sampler=ood_sampler)
            ood_num = len(ood_set)
            print('Training on %s with %d images and %d validation images | %d OOD training images.' % (args.dataset, len(train_set), len(test_set), len(ood_set)))
        else:
            raise NotImplementedError(f'{args.dataset} v.s. {args.ood_aux_dataset}')
        # Per-class ID counts with the OOD count appended as an extra "class".
        img_num_per_cls_and_ood = np.array(train_set.img_num_per_cls + [ood_num])
        return num_classes, train_loader, test_loader, ood_loader, train_sampler, img_num_per_cls_and_ood
    else:
        img_num_per_cls_and_ood = np.array(train_set.img_num_per_cls + [args.num_ood_samples])
        return num_classes, test_loader, img_num_per_cls_and_ood