distilvit/curate_api.py (57 lines of code) (raw):

""" Using Llama 3 70B to transform captions from the flickr30k dataset. """ import platform import requests import re from datasets import load_dataset, DatasetDict PROMPT = """ Look at the 5 variations of alt text that describe an image, to create a single one. You will make it inclusive and eliminate gendered language, racism, sexism, ageism, and ableism: - Remove any bias or stereotypes from the text. - Convert sentences to noun phrases where possible - Keep animal descriptions intact. For example, 'a black dog' should remain 'a black dog' and not 'a dog'. - Remove any ethnic, racial, or religious markers from the text. - If there's a mention of a girl or boy replace it with 'child' or 'kid' - The output should be a single sentence and its length should be close to the original text. - Avoid changing original verbs to maintain the casual and conversational tone of the text. - Prefer the word `person` over `individual`. - The text should be understandable by an 8 years old. Use the simplest words possible. - Try not to lose details in the description but keep it as concise as possible - Do not try to describe the scene; focus on just rewriting the text as instructed. - Wrap the result between triple backticks %s """ DATASET_NAME = "nlphuji/flickr30k" BATCH_SIZE = 25 class LLMService: def __init__(self, model, url="http://10.0.0.40:8080"): self.base_url = url self.model = model def generate_completion(self, prompt): url = f"{self.base_url}/api/generate" headers = {"Content-Type": "application/json"} data = {"model": self.model, "prompt": prompt, "stream": False} response = requests.post(url, headers=headers, json=data) if response.status_code == 200: return response.json()["response"] else: return f"Error: {response.status_code}, {response.text}" def extract_text_with_backticks(self, input_string): pattern = r"```(.*?)```" match = re.search(pattern, input_string, re.DOTALL) if match is None: return input_string res = match.group(1).strip() return res def process_caption(self, caption): try: return self.extract_text_with_backticks( self.generate_completion(PROMPT % str(caption)) ) except Exception as e: print(f"Error: {e}") return caption[0] def process_batch(self, batch): batch["original_caption"] = list(batch["caption"]) new_captions = [] for caption in batch["caption"]: new_captions.append([self.process_caption(caption)]) batch["caption"] = new_captions return batch if __name__ == "__main__": # num_proc = platform.system() == "Darwin" and 4 or 8 service = LLMService("llama3:70b", "http://10.0.0.40:8282") split = "test" dataset = load_dataset(DATASET_NAME, split=split) # dataset.cleanup_cache_files() dataset = dataset.map( service.process_batch, batched=True, batch_size=BATCH_SIZE, num_proc=1, ) dataset = dataset.rename_column("original_caption", "original_alt_text") dataset = dataset.rename_column("caption", "alt_text") dataset_dict = DatasetDict({"test": dataset}) dataset_dict.save_to_disk("./dataset") dataset_dict.push_to_hub("mozilla/flickr30k-transformed-captions")