in distilvit/curate.py [0:0]
def transform_one(self, caption):
    """Refine *caption* by chaining it through every prompt in ``PROMPTS``.

    Each prompt's model output becomes the input caption for the next
    step, so the prompts act as a sequential rewrite pipeline. If any
    step fails, the error is reported and the caption as transformed so
    far is returned (best effort — never raises to the caller).
    """
    # Lazy initialization: only load weights on the first call.
    if self.model is None:
        self.load_model_and_tokenizer()

    for step, prompt in enumerate(PROMPTS):
        try:
            chat = [{"role": "user", "content": prompt + caption}]
            token_ids = self.tokenizer.apply_chat_template(
                chat, add_generation_prompt=True, return_tensors="pt"
            ).to(self.device)

            # Inference only — no gradients needed.
            with torch.no_grad():
                generated = self.model.generate(
                    token_ids,
                    max_new_tokens=120,
                    no_repeat_ngram_size=2,
                    repetition_penalty=1.2,
                    num_beams=3,
                    early_stopping=True,
                )

            # Drop the echoed prompt tokens; decode only the new ones.
            prompt_len = token_ids[0].size().numel()
            decoded = self.tokenizer.decode(
                generated[0][prompt_len:], skip_special_tokens=True
            )

            # Unwrap any ```...``` fencing, then keep the first line only.
            decoded = self.extract_text_with_backticks(decoded)
            decoded = decoded.split("\n")[0].strip()

            if self.args.debug:
                print(f"step {step}: {caption} -> {decoded}")
            caption = decoded
        except Exception as e:
            # Best effort: report and return whatever we have so far.
            print(f"Failed to process {caption}: {e}")
            return caption

    return caption