in utils/common.py [0:0]
def build_plain_train_loader(args):  # for statistic during test
    """Build a deterministic, non-shuffled loader over the training split.

    Intended for collecting statistics at test time. Training images are passed
    through the *test* transform (no random augmentation) and yielded one at a
    time, in dataset order.

    Args:
        args: Namespace-like config. Always reads ``dataset`` (one of
            'cifar10', 'cifar100', 'imagenet') and ``data_root_path``. For
            CIFAR it also reads ``imbalance_ratio``; for ImageNet it reads
            ``id_class_number`` (classes ``0 .. id_class_number - 1`` are kept).

    Returns:
        A ``DataLoader`` with ``batch_size=1``, ``shuffle=False``,
        ``num_workers=4`` and ``pin_memory=True``.

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset name.
    """
    dataset = args.dataset

    if dataset in ('cifar10', 'cifar100'):
        # CIFAR inputs are only converted to tensors here (no normalization).
        test_transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        dataset_cls = IMBALANCECIFAR10 if dataset == 'cifar10' else IMBALANCECIFAR100
        train_set = dataset_cls(
            train=True,
            transform=test_transform,
            imbalance_ratio=args.imbalance_ratio,
            root=args.data_root_path,
        )
    elif dataset == 'imagenet':
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        # Standard resize + center-crop evaluation pipeline.
        test_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        # NOTE(review): the annotation file path is relative to the current
        # working directory, not to data_root_path — confirm this is intended.
        train_set = LT_Dataset(
            osp.join(args.data_root_path, 'imagenet'),
            './datasets/ImageNet_LT/ImageNet_LT_train.txt',
            transform=test_transform,
            subset_class_idx=np.arange(0, args.id_class_number),
        )
    else:
        # Fail early with a clear message instead of an UnboundLocalError below.
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected one of "
            f"'cifar10', 'cifar100', 'imagenet'."
        )

    # With batch_size=1, drop_last never drops a sample; kept for parity.
    train_loader = DataLoader(train_set, batch_size=1, shuffle=False, num_workers=4,
                              drop_last=True, pin_memory=True)
    return train_loader