def transform_one()

in distilvit/curate.py [0:0]


    def transform_one(self, caption):
        """Rewrite ``caption`` by running it through every prompt in ``PROMPTS`` in turn.

        Each step sends ``prompt + caption`` to the chat model as a single user
        message, decodes only the newly generated tokens, post-processes the
        text with ``self.extract_text_with_backticks`` and keeps its first
        non-blank line. The output of one step becomes the caption fed to the
        next step.

        The model and tokenizer are loaded lazily on first use.

        Args:
            caption: the caption text to transform.

        Returns:
            The caption produced by the last step that succeeded. A step whose
            post-processed output is blank leaves the caption unchanged. If a
            step raises, the error is printed and the caption as it stood
            before that step (the input caption if the first step fails) is
            returned immediately.
        """
        if self.model is None:
            self.load_model_and_tokenizer()

        for i, prompt in enumerate(PROMPTS):
            try:
                messages = [
                    {"role": "user", "content": prompt + caption},
                ]

                inputs = self.tokenizer.apply_chat_template(
                    messages, add_generation_prompt=True, return_tensors="pt"
                ).to(self.device)

                with torch.no_grad():
                    outputs = self.model.generate(
                        inputs,
                        max_new_tokens=120,
                        no_repeat_ngram_size=2,
                        repetition_penalty=1.2,
                        num_beams=3,
                        early_stopping=True,
                    )

                # Slice off the prompt tokens so only the continuation is decoded.
                # NOTE(review): assumes generate() output starts with the input ids
                # (decoder-only model) -- matches the original slicing intent.
                prompt_len = inputs.shape[-1]
                result = self.tokenizer.decode(
                    outputs[0][prompt_len:], skip_special_tokens=True
                )
                result = self.extract_text_with_backticks(result)
                # Strip *before* splitting: if the text begins with a newline
                # (e.g. content right after an opening ``` fence), taking the
                # first line of the unstripped text would yield "".
                result = result.strip().split("\n")[0].strip()
                if self.args.debug:
                    print(f"step {i}: {caption} -> {result}")
                # Never replace a usable caption with an empty one; an empty
                # caption would also poison every remaining prompt.
                if result:
                    caption = result
            except Exception as e:
                # Deliberate best-effort: report and return what we have so far.
                print(f"Failed to process {caption}: {e}")
                return caption

        return caption