def main()

in data_preparation/generate_dataset.py


import hashlib
import os

from PIL import Image
from tqdm import tqdm

# `load_dataset` and `model_utils` are assumed to be project-local helpers in
# data_preparation/: `load_dataset` returns a tf.data.Dataset of image samples
# and `model_utils` loads and runs the Cartoonizer model.
import model_utils


def main(args):
    print("Loading initial dataset and the Cartoonizer model...")
    # Load up to `args.max_num_samples` samples and wrap the loaded model's
    # concrete function in a convenience inference callable.
    dataset = load_dataset(args.dataset_id, args.max_num_samples)
    concrete_fn = model_utils.load_model(args.model_id)
    inference_fn = model_utils.perform_inference(concrete_fn)

    print("Preparing the image pairs...")
    os.makedirs(args.data_root, exist_ok=True)
    for sample in tqdm(dataset.as_numpy_iterator()):
        original_image = sample["image"]
        cartoonized_image = inference_fn(original_image)

        hash_image = hashlib.sha1(original_image.tobytes()).hexdigest()
        sample_dir = os.path.join(args.data_root, hash_image)
        os.makedirs(sample_dir)

        original_image = Image.fromarray(original_image).convert("RGB")
        original_image.save(os.path.join(sample_dir, "original_image.png"))
        cartoonized_image.save(os.path.join(sample_dir, "cartoonized_image.png"))

    print(f"Total generated image-pairs: {len(os.listdir(args.data_root))}.")