in distilvit/gpt4.py [0:0]
def __call__(self, batch):
    """Attach generated alt text and detected objects to every row of *batch*.

    Results are looked up in ``self.db`` by stringified image id; only ids
    missing from that lookup are passed to ``self.generate``, and the new
    results are written back with ``self.db.set``.

    Args:
        batch: Mapping of column name to a list of per-row values. It must
            contain the ``self.args.image_id_column`` column and either the
            image column or an ``"image_path"`` column to load images from.

    Returns:
        The same ``batch`` object, mutated in place, with
        ``self.args.generated_alt_text_column`` and ``"objects"`` filled in
        (one entry per row, in row order).
    """
    if "image" not in batch:
        # Image.open keeps the underlying file open; convert() returns an
        # independent in-memory image, so the source can be closed at once.
        loaded_images = []
        for path in batch["image_path"]:
            with Image.open(path) as source_image:
                loaded_images.append(source_image.convert("RGB"))
        # NOTE(review): images are stored under the literal "image" key but
        # read back via self.args.image_column below -- confirm they match.
        batch["image"] = loaded_images

    img_ids = [str(img_id) for img_id in batch[self.args.image_id_column]]

    # Check cache
    cached_alt_texts = self.db.get(img_ids)

    # Collect uncached images in one pass so ids and images stay paired.
    # Only the first occurrence of a repeated id is kept, so an image that
    # appears several times in the batch is generated just once.
    pending = {}
    for img_id, image in zip(img_ids, batch[self.args.image_column]):
        if img_id not in cached_alt_texts and img_id not in pending:
            pending[img_id] = image

    if pending:
        new_img_ids = list(pending)
        new_images = list(pending.values())
        generation = self.generate(new_images, new_img_ids)
        # A result without an "images" key is treated as a single entry.
        if "images" not in generation:
            generation = {"images": [generation]}
        # Assumes generation["images"] is ordered like new_img_ids -- the
        # zip below pairs them positionally.
        self.db.set(new_img_ids, generation["images"])
        cached_alt_texts.update(dict(zip(new_img_ids, generation["images"])))

    batch[self.args.generated_alt_text_column] = [
        cached_alt_texts[img_id]["alt_text"] for img_id in img_ids
    ]
    batch["objects"] = [
        cached_alt_texts[img_id]["detected_objects"] for img_id in img_ids
    ]
    return batch