in data_preparation/generate_dataset.py [0:0]
def main(args):
    """Generate (original, cartoonized) image pairs and write them to disk.

    Each sample from the dataset is passed through the Cartoonizer inference
    function, and both images are saved as PNG files in a per-sample
    directory under ``args.data_root``. The directory is named after the
    SHA-1 hex digest of the original image's raw bytes.

    Args:
        args: Namespace-like object providing ``dataset_id``,
            ``max_num_samples``, ``model_id`` and ``data_root``.
    """
    print("Loading initial dataset and the Cartoonizer model...")
    dataset = load_dataset(args.dataset_id, args.max_num_samples)
    concrete_fn = model_utils.load_model(args.model_id)
    inference_fn = model_utils.perform_inference(concrete_fn)

    print("Preparing the image pairs...")
    os.makedirs(args.data_root, exist_ok=True)
    for sample in tqdm(dataset.as_numpy_iterator()):
        original_array = sample["image"]
        # NOTE(review): the result is only used via `.save(path)`, so it is
        # presumably a PIL image -- confirm against model_utils.
        cartoonized_image = inference_fn(original_array)

        # The directory name is derived from the image content, so identical
        # images always map to the same directory.
        hash_image = hashlib.sha1(original_array.tobytes()).hexdigest()
        sample_dir = os.path.join(args.data_root, hash_image)
        # exist_ok=True: a duplicate image in the dataset, or a re-run over
        # an already populated data_root, must not abort the whole job with
        # FileExistsError; the pair is simply written again in place.
        os.makedirs(sample_dir, exist_ok=True)

        original_image = Image.fromarray(original_array).convert("RGB")
        original_image.save(os.path.join(sample_dir, "original_image.png"))
        cartoonized_image.save(os.path.join(sample_dir, "cartoonized_image.png"))

    print(f"Total generated image-pairs: {len(os.listdir(args.data_root))}.")