def main()

in imagenet_c_bar/make_imagenet_c_bar.py [0:0]


def main():
    """Generate ImageNet-C-Bar: corrupted copies of the ImageNet validation set.

    Reads command-line options from the module-level ``parser``. For every
    (corruption name, severity) pair returned by ``read_corruption_csv``, it
    runs the ``val`` split under ``imagenet_dir`` through
    Resize(256) -> CenterCrop(224) -> PilToNumpy -> corruption transform.
    ``SavingDataset`` is pointed at
    ``<out_dir>/ImageNet-C-Bar/<name>/<severity formatted as {:.2f}>/``.

    Command-line options read: ``imagenet_dir``, ``corruption_file``,
    ``out_dir``, ``seed``, ``batch_size``, ``num_workers``.
    """
    args = parser.parse_args()
    dataset_path = args.imagenet_dir
    corruption_file = args.corruption_file
    out_dir = os.path.join(args.out_dir, 'ImageNet-C-Bar')
    # NOTE(review): if build_transform draws from np.random inside DataLoader
    # worker processes (num_workers > 0), each worker inherits this parent RNG
    # state, so the random streams may repeat across workers and across
    # severities. Confirm whether the corruptions are expected to be
    # reproducible per image.
    np.random.seed(args.seed)
    bs = args.batch_size
    # makedirs also creates missing parents of args.out_dir (os.mkdir would
    # raise FileNotFoundError). exist_ok replaces the racy exists()/mkdir pair.
    os.makedirs(out_dir, exist_ok=True)

    # The CSV path is resolved relative to this script's directory
    # (os.path.join returns corruption_file unchanged if it is absolute).
    file_dir = os.path.dirname(os.path.realpath(__file__))
    corruption_csv = os.path.join(file_dir, corruption_file)
    corruptions = read_corruption_csv(corruption_csv)

    # Same source split for every corruption/severity, so compute it once.
    path = os.path.join(dataset_path, 'val')

    for name, severities in corruptions.items():
        corruption_dir = os.path.join(out_dir, name)
        os.makedirs(corruption_dir, exist_ok=True)
        for severity in severities:
            severity_dir = os.path.join(corruption_dir, "{:.2f}".format(severity))
            os.makedirs(severity_dir, exist_ok=True)
            print("Starting {}-{:.2f}...".format(name, severity))
            transform = tv.transforms.Compose([
                tv.transforms.Resize(256),
                tv.transforms.CenterCrop(224),
                PilToNumpy(),
                build_transform(name=name, severity=severity, dataset_type='imagenet'),
                ])
            dataset = SavingDataset(path, severity_dir, transform=transform)
            loader = torch.utils.data.DataLoader(
                    dataset,
                    shuffle=False,
                    sampler=None,
                    drop_last=False,
                    pin_memory=False,
                    num_workers=args.num_workers,
                    batch_size=bs
                    )
            num_batches = len(loader)
            # The batches themselves are discarded. Iterating the loader is
            # what drives SavingDataset, which presumably writes each
            # corrupted image into severity_dir as a side effect (verify in
            # its definition).
            for j, _batch in enumerate(loader):
                done = j + 1  # batches finished so far (j is 0-based)
                if done % 10 == 0:
                    print("Completed {}/{}".format(done, num_batches))