distilvit/curate.py

""" Using Llama 3 8B Instructto transform captions from the flickr30k dataset. """ import re import platform import torch import argparse from transformers.utils import logging import readability logging.set_verbosity_error() DATASET_NAME = "nlphuji/flickr30k" MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct" BATCH_SIZE = platform.system() == "Darwin" and 1 or 10 PROMPT_1 = """ Rewrite the text to be inclusive and free of bias: - Remove gendered pronouns and names, but not for animals. - Remove ethnic, racial, and religious markers. - Maintain the order and relationship of descriptive elements without changing verbs. - Keep the sentence structure as close to the original as possible. - Wrap the result in triple backticks. """ PROMPT_2 = """ Rewrite the text to: - Maintain original verbs for a casual tone. - Use singular forms when the original text describes one person. - Keep the sentence structure as close to the original as possible. - Wrap the result in triple backticks. """ PROMPT_3 = """ Rewrite the text to use noun phrases for brevity and simplicity: - Convert sentences to noun phrases where possible: 'a person is walking' becomes 'a person walking'. - Maintain the order and relationship of descriptive elements without changing verbs. - Avoid adding new verbs or altering the original ones. - Match the original sentence length. - Wrap the result in triple backticks. """ PROMPT_4 = """ Rewrite the text to: - Avoid adding new verbs or altering the original ones. - Wrap the result in triple backticks. """ PROMPTS = [PROMPT_1, PROMPT_3] class TextConverter: def __init__(self, args, model_name=MODEL_NAME): if torch.cuda.is_available(): device = torch.device("cuda") print("Using CUDA (Nvidia GPU).") elif torch.backends.mps.is_available(): device = torch.device("mps") print("Using MPS (Apple Silicon GPU).") else: device = torch.device("cpu") print("Using CPU.") self.device = device self.model = None self.model_name = model_name self.args = args def load_model_and_tokenizer(self): if platform.system() == "Darwin": kw = { "torch_dtype": torch.bfloat16, "low_cpu_mem_usage": True, "trust_remote_code": True, } bnb_config = None else: from transformers import BitsAndBytesConfig kw = { "trust_remote_code": True, } bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16, ) from transformers import AutoModelForCausalLM, AutoTokenizer self.model = AutoModelForCausalLM.from_pretrained( self.model_name, device_map="auto", quantization_config=bnb_config, **kw ) self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) def process_batch(self, batch): if self.model is None: self.load_model_and_tokenizer() # need to re-triage the original captions with the new order batch["original_caption"] = list(batch["caption"]) batch["original_sentids"] = list(batch["sentids"]) new_captions = [] grades = [] sentids = [] for captions, nsentids in zip(batch["caption"], batch["sentids"]): converted, grade, nsentids = self.transform(captions, nsentids) new_captions.append(converted) grades.append(grade) sentids.append(nsentids) batch["caption"] = new_captions batch["grade"] = grades batch["sentids"] = sentids return batch def by_grade(self, item): return item[1]["DaleChallIndex"] def extract_text_with_backticks(self, input_string): pattern = r"```(.*?)```" match = re.search(pattern, input_string, re.DOTALL) if match is None: return input_string res = match.group(1).strip() if self.args.debug: 
print(f"original:\n{input_string}\nbacktick extracted:\n{res}\n") return res def transform(self, captions, sentids): transformed_captions = [] for caption, sentid in zip(captions, sentids): result = self.transform_one(caption) try: grade = dict( readability.getmeasures(result, lang="en")["readability grades"] ) except Exception as e: grade = {"DaleChallIndex": 10.0} print(f"{caption} -> {result} with {grade['DaleChallIndex']:.2f}") transformed_captions.append((result, grade, sentid)) transformed_captions.sort(key=self.by_grade) return list(zip(*transformed_captions)) def transform_one(self, caption): if self.model is None: self.load_model_and_tokenizer() for i, prompt in enumerate(PROMPTS): try: messages = [ {"role": "user", "content": prompt + caption}, ] inputs = self.tokenizer.apply_chat_template( messages, add_generation_prompt=True, return_tensors="pt" ).to(self.device) with torch.no_grad(): outputs = self.model.generate( inputs, max_new_tokens=120, no_repeat_ngram_size=2, repetition_penalty=1.2, num_beams=3, early_stopping=True, ) result = self.tokenizer.decode( outputs[0][inputs[0].size().numel() :], skip_special_tokens=True ) result = self.extract_text_with_backticks(result) result = result.split("\n")[0].strip() if self.args.debug: print(f"step {i}: {caption} -> {result}") caption = result except Exception as e: print(f"Failed to process {caption}: {e}") return caption return caption def main(args): llm_converter = TextConverter(args) if args.text: result = llm_converter.transform_one(args.text) print(f"Transformed Text: {result}") else: from datasets import load_dataset, DatasetDict split = "test[:100]" if args.test_sample else "test" dataset = load_dataset(DATASET_NAME, split=split) num_proc = platform.system() == "Darwin" and 1 or 4 dataset = dataset.map( llm_converter.process_batch, batched=True, batch_size=BATCH_SIZE, num_proc=num_proc, ) dataset = dataset.rename_column("original_caption", "original_alt_text") dataset = dataset.rename_column("caption", "alt_text") dataset_dict = DatasetDict({"test": dataset}) dataset_dict.save_to_disk("./dataset") if not args.test_sample: dataset_dict.push_to_hub("mozilla/flickr30k-transformed-captions") if __name__ == "__main__": parser = argparse.ArgumentParser(description="Process some text.") parser.add_argument("--debug", action="store_true", help="Enable debug mode") parser.add_argument("--text", type=str, help="Text to transform") parser.add_argument("--test_sample", action="store_true", help="Run a test sample") args = parser.parse_args() main(args)