in scripts/convert_imagenet_to_wds.py [0:0]
def _write_split(pattern, hf_split, max_samples_per_shard):
    """Stream one ``imagenet-1k`` split from the Hugging Face Hub into tar shards.

    Args:
        pattern: ``%``-style shard path pattern passed to ``wds.ShardWriter``,
            e.g. ``".../imagenet-train-%06d.tar"``.
        hf_split: Hugging Face split name to stream (``"train"`` / ``"validation"``).
        max_samples_per_shard: Maximum samples per shard; also used as the
            progress-reporting interval.

    Returns:
        ``(num_written, seconds)``: number of examples written and wall-clock
        seconds spent writing them (the ``load_dataset`` call is not timed).
    """
    # Open the dataset before the writer so that a failing load (e.g. missing
    # Hub auth) does not leave a writer/shard file open.
    # NOTE(review): assumes ShardWriter may create its first shard on
    # construction — confirm against the webdataset version in use.
    dataset = load_dataset("imagenet-1k", streaming=True, split=hf_split, use_auth_token=True)
    start = time.monotonic()
    num_written = 0  # stays 0 for an empty stream instead of leaving a name unbound
    writer = wds.ShardWriter(pattern, maxcount=max_samples_per_shard)
    try:
        for idx, example in enumerate(dataset):
            if idx % max_samples_per_shard == 0:
                # Progress to stderr, roughly once per shard.
                print(idx, file=sys.stderr)
            writer.write(
                {
                    "__key__": "%08d" % idx,
                    # Force 3-channel RGB so every sample encodes uniformly
                    # (presumably the "image" field is a PIL image).
                    "jpg": example["image"].convert("RGB"),
                    "cls": example["label"],
                }
            )
            num_written = idx + 1
    finally:
        # Flush and close the current shard even if streaming fails midway.
        writer.close()
    return num_written, time.monotonic() - start


def convert_imagenet_to_wds(output_dir, max_train_samples_per_shard, max_val_samples_per_shard):
    """Convert ImageNet-1k (train and validation) into WebDataset tar shards.

    Writes ``imagenet-train-%06d.tar`` and ``imagenet-val-%06d.tar`` into
    *output_dir*. Each sample has a zero-padded ``__key__``, an RGB ``jpg``
    image and its ``cls`` label. Progress is printed to stderr; a summary line
    per split is printed to stdout.

    Args:
        output_dir: Directory that receives the shard files.
        max_train_samples_per_shard: Maximum samples per train shard.
        max_val_samples_per_shard: Maximum samples per validation shard.

    Raises:
        ValueError: If either shard size is not a positive number.
        FileExistsError: If shard ``000000`` of either split already exists,
            to avoid overwriting a previous conversion.
    """
    for arg_name, value in (
        ("max_train_samples_per_shard", max_train_samples_per_shard),
        ("max_val_samples_per_shard", max_val_samples_per_shard),
    ):
        if value <= 0:
            raise ValueError(f"{arg_name} must be positive, got {value!r}")

    # Check both splits up front so a conflict is reported before the
    # (potentially hours-long) train conversion starts. Explicit exceptions
    # instead of `assert`, which is stripped when Python runs with -O.
    for first_shard in (
        os.path.join(output_dir, "imagenet-train-000000.tar"),
        os.path.join(output_dir, "imagenet-val-000000.tar"),
    ):
        if os.path.exists(first_shard):
            raise FileExistsError(f"Refusing to overwrite existing shard: {first_shard}")

    num_train, train_secs = _write_split(
        os.path.join(output_dir, "imagenet-train-%06d.tar"), "train", max_train_samples_per_shard
    )
    # True division: floor-dividing by 3600 printed "0.0 hours" for any sub-hour run.
    print(f"Wrote {num_train} train examples in {train_secs / 3600:.2f} hours.")

    num_val, val_secs = _write_split(
        os.path.join(output_dir, "imagenet-val-%06d.tar"), "validation", max_val_samples_per_shard
    )
    print(f"Wrote {num_val} val examples in {val_secs / 60:.1f} min.")