in generate.py [0:0]
def main():
    """Generate text for one image/prompt pair and print the results.

    Reads the command-line arguments via ``parse_args()``, picks a compute
    device (CUDA first, then MPS, then CPU), loads a ``VisionLanguageModel``
    from ``args.checkpoint`` (falling back to ``args.hf_model`` when no
    checkpoint is given), and prints ``args.generations`` decoded outputs
    for ``args.prompt`` conditioned on the image at ``args.image``.
    """
    args = parse_args()

    # Device preference: CUDA > MPS > CPU. The hasattr guard keeps this
    # working on torch builds that do not expose ``torch.backends.mps``.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")

    # An explicit checkpoint wins over the hub model identifier.
    source = args.checkpoint or args.hf_model
    print(f"Loading weights from: {source}")
    model = VisionLanguageModel.from_pretrained(source).to(device)
    model.eval()

    tokenizer = get_tokenizer(model.cfg.lm_tokenizer, model.cfg.vlm_extra_tokens)
    image_processor = get_image_processor(model.cfg.vit_img_size)

    # The user turn starts with ``mp_image_token_length`` copies of the image
    # placeholder token, followed by the textual prompt.
    image_placeholders = tokenizer.image_token * model.cfg.mp_image_token_length
    messages = [{"role": "user", "content": image_placeholders + args.prompt}]
    # The conversation is wrapped in a list, i.e. passed as a batch of one.
    # NOTE(review): presumably this yields a (1, seq_len) id tensor -- confirm.
    encoded_prompt = tokenizer.apply_chat_template(
        [messages], tokenize=True, add_generation_prompt=True
    )
    tokens = torch.tensor(encoded_prompt).to(device)

    # Pillow opens files lazily; the context manager guarantees the handle is
    # released. ``convert`` returns an independent copy, so ``img`` remains
    # usable after the file is closed.
    with Image.open(args.image) as image_file:
        img = image_file.convert("RGB")
    # unsqueeze(0) adds a leading dimension of size 1 before moving to device.
    img_t = image_processor(img).unsqueeze(0).to(device)

    print("\nInput:\n ", args.prompt, "\n\nOutputs:")
    # Gradients are never needed here; disabling autograd avoids building
    # graphs in case ``generate`` does not already do so itself.
    with torch.no_grad():
        for index in range(1, args.generations + 1):
            gen = model.generate(tokens, img_t, max_new_tokens=args.max_new_tokens)
            out = tokenizer.batch_decode(gen, skip_special_tokens=True)[0]
            print(f" >> Generation {index}: {out}")