# build_dataset() — originally defined in utils/common.py

def build_dataset(args, ngpus_per_node, is_training=True):
    """Build ID train/test loaders and (when training) an auxiliary OOD loader.

    Args:
        args: parsed argument namespace. Fields read: dataset, ddp, batch_size,
            num_nodes, num_workers, data_root_path, imbalance_ratio,
            id_class_number, test_batch_size, ood_aux_dataset, num_ood_samples.
        ngpus_per_node: number of GPUs per node; used to split the global batch
            size and worker count across processes under DDP.
        is_training: when True, also build the OOD loader and return the full
            training tuple; when False, return only what evaluation needs.

    Returns:
        is_training=True:
            (num_classes, train_loader, test_loader, ood_loader, train_sampler,
             img_num_per_cls_and_ood)
        is_training=False:
            (num_classes, test_loader, img_num_per_cls_and_ood)

    Raises:
        NotImplementedError: for unsupported dataset / OOD-auxiliary combinations.
    """
    # Per-process batch size and worker count: under DDP the global batch is
    # split evenly across every GPU of every node.
    train_batch_size = args.batch_size if not args.ddp else int(args.batch_size/ngpus_per_node/args.num_nodes)
    num_workers = args.num_workers if not args.ddp else int((args.num_workers+ngpus_per_node)/ngpus_per_node)

    # Augmentation pipelines (SimCLR-style: jitter + random grayscale on train).
    if args.dataset in ['cifar10', 'cifar100']:
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
            ], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
        ])
    elif args.dataset == 'imagenet' or args.dataset == 'waterbird':
        # Standard ImageNet normalization statistics.
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
            ], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        test_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
    else:
        # Fail fast; otherwise train_transform would raise NameError below.
        raise NotImplementedError(args.dataset)

    # Dataset construction. Train sets are wrapped in TwoCropTransform so each
    # sample yields two augmented views (contrastive training).
    if args.dataset == 'cifar10':
        num_classes = 10
        train_set = IMBALANCECIFAR10(train=True, transform=TwoCropTransform(train_transform), imbalance_ratio=args.imbalance_ratio, root=args.data_root_path)
        test_set = IMBALANCECIFAR10(train=False, transform=test_transform, imbalance_ratio=args.imbalance_ratio, root=args.data_root_path)
    elif args.dataset == 'cifar100':
        num_classes = 100
        train_set = IMBALANCECIFAR100(train=True, transform=TwoCropTransform(train_transform), imbalance_ratio=args.imbalance_ratio, root=args.data_root_path)
        test_set = IMBALANCECIFAR100(train=False, transform=test_transform, imbalance_ratio=args.imbalance_ratio, root=args.data_root_path)
    elif args.dataset == 'imagenet':
        num_classes = args.id_class_number
        train_set = LT_Dataset(
            osp.join(args.data_root_path, 'imagenet'), './datasets/ImageNet_LT/ImageNet_LT_train.txt', transform=TwoCropTransform(train_transform), 
            subset_class_idx=np.arange(0,args.id_class_number))
        if is_training:
            # Long-tailed validation split during training; full val set otherwise.
            test_set = LT_Dataset(
                osp.join(args.data_root_path, 'imagenet'), './datasets/ImageNet_LT/ImageNet_LT_val.txt', transform=test_transform,
                subset_class_idx=np.arange(0,args.id_class_number))
        else:
            test_set = ImageFolder(osp.join(args.data_root_path, 'imagenet', 'val'), transform=test_transform)
    elif args.dataset == 'waterbird':
        num_classes = 2
        train_set = ImageFolder(osp.join(args.data_root_path, 'waterbird_LT', 'train'), transform=TwoCropTransform(train_transform))
        # ImageFolder has no per-class counts; attach the known long-tailed
        # counts so downstream code can read img_num_per_cls uniformly.
        setattr(train_set, 'img_num_per_cls', [363, 3699])
        if is_training:
            test_set = ImageFolder(osp.join(args.data_root_path, 'waterbird_LT', 'val'), transform=test_transform)
        else:
            test_set = ImageFolder(osp.join(args.data_root_path, 'waterbird_LT', 'test'), transform=test_transform)
    else:
        # Unreachable for datasets handled above, but guards against future
        # transform-only additions that forget a dataset branch here.
        raise NotImplementedError(args.dataset)

    if args.ddp:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_set)
    else:
        train_sampler = None
    # shuffle must be False when a DistributedSampler is supplied (it shuffles).
    train_loader = DataLoader(train_set, batch_size=train_batch_size, shuffle=not args.ddp, num_workers=num_workers,
                                drop_last=True, pin_memory=True, sampler=train_sampler)
    test_loader = DataLoader(test_set, batch_size=args.test_batch_size, shuffle=False, num_workers=num_workers, 
                                drop_last=False, pin_memory=True)
    
    if is_training:
        # Auxiliary OOD data source for outlier-exposure-style training.
        if args.ood_aux_dataset in ['TinyImages', 'ImExtra', 'Texture']:
            if args.dataset in ['cifar10', 'cifar100']:
                ood_set = Subset(TinyImages(args.data_root_path, transform=train_transform), list(range(args.num_ood_samples)))
            elif args.dataset == 'imagenet':
                ood_set = ImageFolder(osp.join(args.data_root_path, 'imagenet/extra_1k'), transform=train_transform)
            elif args.dataset == 'waterbird':
                ood_set = ImageFolder(osp.join(args.data_root_path, 'texture'), transform=train_transform)
            else:
                raise NotImplementedError(args.dataset, args.ood_aux_dataset)
            
            if args.ddp:
                ood_sampler = torch.utils.data.distributed.DistributedSampler(ood_set)
            else:
                ood_sampler = None
            ood_loader = DataLoader(ood_set, batch_size=train_batch_size, shuffle=not args.ddp, num_workers=num_workers,
                                        drop_last=True, pin_memory=True, sampler=ood_sampler)
            ood_num = len(ood_set)
            print('Training on %s with %d images and %d validation images | %d OOD training images.' % (args.dataset, len(train_set), len(test_set), len(ood_set)))
        elif args.ood_aux_dataset in ['VOS', 'NPOS']:
            # Synthetic outliers sampled in feature space: the "loader" is a
            # feature pool, not a DataLoader.
            # sample_num = min(train_set.img_num_per_cls) * 10
            sample_num = max(train_set.img_num_per_cls)
            feat_dim = 512 if 'cifar' in args.dataset else 1024  # TODO: check
            mode = args.ood_aux_dataset
            device = 'cuda:0'
            ood_loader = IDFeatPool(num_classes, sample_num=sample_num, feat_dim=feat_dim, mode=mode, device=device)
            ood_num = num_classes * sample_num
        elif args.ood_aux_dataset in ['CIFAR']:
            # Use the *other* CIFAR variant as OOD data.
            if args.dataset == 'cifar10':
                dout = 'cifar100'
            elif args.dataset == 'cifar100':
                dout = 'cifar10'
            else:
                # Previously fell through with dout undefined -> NameError.
                raise NotImplementedError(args.dataset, args.ood_aux_dataset)
            ood_set = ImageFolder(osp.join(args.data_root_path, f'SCOOD/data/images/{dout}/test'), transform=train_transform)
            if args.ddp:
                ood_sampler = torch.utils.data.distributed.DistributedSampler(ood_set)
            else:
                ood_sampler = None
            ood_loader = DataLoader(ood_set, batch_size=train_batch_size, shuffle=not args.ddp, num_workers=num_workers,
                                        drop_last=True, pin_memory=True, sampler=ood_sampler)
            ood_num = len(ood_set)
            print('Training on %s with %d images and %d validation images | %d OOD training images.' % (args.dataset, len(train_set), len(test_set), len(ood_set)))
        else:
            raise NotImplementedError(f'{args.dataset} v.s. {args.ood_aux_dataset}')

        # Per-class ID image counts with the OOD count appended as a pseudo-class.
        img_num_per_cls_and_ood = np.array(train_set.img_num_per_cls + [ood_num])

        return num_classes, train_loader, test_loader, ood_loader, train_sampler, img_num_per_cls_and_ood
    else:
        # Evaluation path: no OOD loader is built; report the configured count.
        img_num_per_cls_and_ood = np.array(train_set.img_num_per_cls + [args.num_ood_samples])

        return num_classes, test_loader, img_num_per_cls_and_ood