in imagenet_c_bar/make_imagenet_c_bar.py [0:0]
def main():
    """Generate ImageNet-C-Bar by corrupting the ImageNet validation images.

    Command-line arguments are read from the module-level ``parser``. The CSV
    named by ``args.corruption_file`` is resolved relative to this script's
    directory and read with ``read_corruption_csv``. The result is a mapping
    from corruption name to an iterable of severities.

    For every (corruption, severity) pair:

    * the images under ``<imagenet_dir>/val`` are resized to 256;
    * they are center-cropped to 224;
    * they are converted with ``PilToNumpy``;
    * they are passed through the transform returned by ``build_transform``.

    The dataset is wrapped in ``SavingDataset``, which is given
    ``<out_dir>/ImageNet-C-Bar/<corruption>/<severity formatted as %.2f>`` as
    its output directory. Iterating the DataLoader drives the dataset; the
    loaded batches themselves are discarded here.
    """
    args = parser.parse_args()
    dataset_path = args.imagenet_dir
    corruption_file = args.corruption_file
    out_dir = os.path.join(args.out_dir, 'ImageNet-C-Bar')
    # NOTE(review): this seeds NumPy in the main process only. DataLoader
    # workers (num_workers > 0) start from a copy of this state, so if the
    # corruption transforms draw from np.random, different workers / passes
    # may reuse the same random stream. Confirm this is intended.
    np.random.seed(args.seed)
    bs = args.batch_size
    # makedirs also creates a missing args.out_dir, where os.mkdir would raise,
    # and exist_ok avoids the separate exists-then-create check.
    os.makedirs(out_dir, exist_ok=True)
    file_dir = os.path.dirname(os.path.realpath(__file__))
    corruption_csv = os.path.join(file_dir, corruption_file)
    corruptions = read_corruption_csv(corruption_csv)
    # The source images are the same for every corruption and severity.
    path = os.path.join(dataset_path, 'val')
    for name, severities in corruptions.items():
        corruption_dir = os.path.join(out_dir, name)
        os.makedirs(corruption_dir, exist_ok=True)
        for severity in severities:
            severity_dir = os.path.join(corruption_dir, "{:.2f}".format(severity))
            os.makedirs(severity_dir, exist_ok=True)
            print("Starting {}-{:.2f}...".format(name, severity))
            transform = tv.transforms.Compose([
                tv.transforms.Resize(256),
                tv.transforms.CenterCrop(224),
                PilToNumpy(),
                build_transform(name=name, severity=severity, dataset_type='imagenet'),
            ])
            dataset = SavingDataset(path, severity_dir, transform=transform)
            loader = torch.utils.data.DataLoader(
                dataset,
                shuffle=False,
                sampler=None,
                drop_last=False,
                pin_memory=False,
                num_workers=args.num_workers,
                batch_size=bs
            )
            num_batches = len(loader)
            # Count from 1 so the message reports batches actually completed;
            # also report the final batch so the run visibly reaches N/N.
            for j, _batch in enumerate(loader, start=1):
                if j % 10 == 0 or j == num_batches:
                    print("Completed {}/{}".format(j, num_batches))