in imagenet_c_bar/make_cifar10_c_bar.py [0:0]
def main():
    """Build the CIFAR-10-C-Bar corrupted test sets and save them as .npy files.

    Reads the corruption/severity table from ``cifar10_c_bar.csv`` located next
    to this script. For each corruption, every listed severity is applied to
    the CIFAR-10 test split (loaded from ``args.cifar_dir`` without
    downloading). The resulting images for all severities are stacked along the
    first axis and written to ``<args.out_dir>/CIFAR-10-C-Bar/<name>.npy``. The
    label array (same stacking) is written once to ``labels.npy`` in that
    directory.

    Raises:
        RuntimeError: if the loader yields a number of images different from
            the expected test-split size for some severity.
    """
    args = parser.parse_args()
    dataset_path = args.cifar_dir
    out_dir = os.path.join(args.out_dir, 'CIFAR-10-C-Bar')
    bs = args.batch_size
    # makedirs also creates a missing args.out_dir and tolerates re-runs.
    os.makedirs(out_dir, exist_ok=True)

    file_dir = os.path.dirname(os.path.realpath(__file__))
    corruption_csv = os.path.join(file_dir, 'cifar10_c_bar.csv')
    corruptions = read_corruption_csv(corruption_csv)

    # Number of images in the CIFAR-10 test split (train=False below).
    num_images = 10000

    labels = None
    for name, severities in corruptions.items():
        total = len(severities) * num_images
        # Allocate directly as uint8 instead of float64-then-cast.
        data = np.zeros((total, 32, 32, 3), dtype=np.uint8)
        # np.int was removed in NumPy 1.24; use an explicit 64-bit integer.
        labels = np.zeros(total, dtype=np.int64)
        for i, severity in enumerate(severities):
            print("Starting {}-{:.2f}...".format(name, severity))
            transform = tv.transforms.Compose([
                PilToNumpy(),
                build_transform(name=name, severity=severity, dataset_type='cifar'),
            ])
            dataset = tv.datasets.CIFAR10(
                dataset_path, train=False, download=False, transform=transform)
            loader = torch.utils.data.DataLoader(
                dataset,
                shuffle=False,
                sampler=None,
                drop_last=False,
                pin_memory=False,
                num_workers=args.num_workers,
                batch_size=bs,
            )
            # Each severity owns rows [i*num_images, (i+1)*num_images). A
            # running offset handles a short final batch correctly regardless
            # of whether bs divides num_images.
            block_start = i * num_images
            offset = 0
            for im, label in loader:
                n = im.size(0)
                lo = block_start + offset
                data[lo:lo + n] = im.numpy().astype(np.uint8)
                labels[lo:lo + n] = label.numpy()
                offset += n
            if offset != num_images:
                raise RuntimeError(
                    "Expected {} images for {}-{:.2f}, got {}.".format(
                        num_images, name, severity, offset))
        out_file = os.path.join(out_dir, name + ".npy")
        print("Saving {} to {}.".format(name, out_file))
        np.save(out_file, data)

    # Labels come from the last corruption processed (the loader is unshuffled,
    # so per-severity label order is the dataset order). Skip if the CSV listed
    # no corruptions, instead of failing on an unbound name.
    if labels is not None:
        labels_file = os.path.join(out_dir, "labels.npy")
        np.save(labels_file, labels)