# def main()
#
# in imagenet_c_bar/make_cifar10_c_bar.py [0:0]


def main():
    """Generate the CIFAR-10-C-Bar corrupted test arrays and save them to disk.

    Command-line arguments are read from the module-level ``parser``; the
    attributes used here are ``cifar_dir``, ``out_dir``, ``batch_size`` and
    ``num_workers``. The corruption table is read from ``cifar10_c_bar.csv``
    located next to this file.

    For every corruption ``name`` in the table, the CIFAR-10 test split is
    passed through the corresponding transform once per severity, and the
    results are stacked severity after severity into a single uint8 array
    that is saved as ``<out_dir>/CIFAR-10-C-Bar/<name>.npy``. After all
    corruptions are processed, the labels collected for the last corruption
    are saved as ``labels.npy`` in the same directory.
    """
    args = parser.parse_args()
    dataset_path = args.cifar_dir
    out_dir = os.path.join(args.out_dir, 'CIFAR-10-C-Bar')
    bs = args.batch_size
    # makedirs also creates missing parents and tolerates an existing
    # directory, avoiding the check-then-create race of exists() + mkdir().
    os.makedirs(out_dir, exist_ok=True)

    file_dir = os.path.dirname(os.path.realpath(__file__))
    corruption_csv = os.path.join(file_dir, 'cifar10_c_bar.csv')
    corruptions = read_corruption_csv(corruption_csv)

    # Number of images per severity; the output arrays are laid out in
    # consecutive chunks of this size (the CIFAR-10 test split).
    num_images = 10000

    # Stays None when the corruption table is empty so the final save can be
    # skipped instead of failing on an unbound name.
    labels = None

    for name, severities in corruptions.items():
        total = len(severities) * num_images
        # Allocate directly as uint8 rather than float64-then-cast, which
        # would temporarily need 8x the memory.
        data = np.zeros((total, 32, 32, 3), dtype=np.uint8)
        # The `np.int` alias was removed from NumPy; the builtin `int` is the
        # exact type it used to refer to.
        labels = np.zeros(total, dtype=int)
        for i, severity in enumerate(severities):
            print("Starting {}-{:.2f}...".format(name, severity))
            transform = tv.transforms.Compose([
                PilToNumpy(),
                build_transform(name=name, severity=severity, dataset_type='cifar'),
                ])
            dataset = tv.datasets.CIFAR10(dataset_path, train=False, download=False, transform=transform)
            loader = torch.utils.data.DataLoader(
                    dataset,
                    shuffle=False,
                    sampler=None,
                    drop_last=False,
                    pin_memory=False,
                    num_workers=args.num_workers,
                    batch_size=bs
                    )
            # Track the write position by the actual batch length so that a
            # final short batch (batch size not dividing the dataset size)
            # lands exactly after the previous one.
            start = i * num_images
            for im, label in loader:
                end = start + im.size(0)
                # NOTE(review): assumes each batch is (N, 32, 32, 3) to match
                # the output array layout -- confirm against build_transform.
                data[start:end, :, :, :] = im.numpy().astype(np.uint8)
                labels[start:end] = label.numpy()
                start = end

        out_file = os.path.join(out_dir, name + ".npy")
        print("Saving {} to {}.".format(name, out_file))
        np.save(out_file, data)

    if labels is not None:
        labels_file = os.path.join(out_dir, "labels.npy")
        np.save(labels_file, labels)