src/transformers/models/emu3/convert_emu3_weights_to_hf.py
def convert_model(vq_model_id, llm_model_id, output_dir, hub_model_id=None, test_inference=False):
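    """Convert the original Emu3 VQ-GAN and LLM checkpoints to a single HF checkpoint.

    Builds and saves an Emu3Processor (image processor + converted tokenizer), merges both
    state dicts into an Emu3ForConditionalGeneration model, saves everything to `output_dir`,
    optionally pushes to the Hub, and optionally runs a short generation smoke test.
    """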
os.makedirs(output_dir, exist_ok=True)
# Convert and save processor
tokenizer_tiktoken = AutoTokenizer.from_pretrained(llm_model_id, trust_remote_code=True)
convert_tiktoken(tokenizer_tiktoken, output_dir)
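    # Register Emu3's image-markup tokens under named attributes so the processor can refer to them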
    extra_special_tokens = {
"image_token": "<image>",
"boi_token": "<|image start|>",
"eoi_token": "<|image end|>",
"image_wrapper_token": "<|image token|>",
"eof_token": "<|extra_201|>",
}
tokenizer_converted = AutoTokenizer.from_pretrained(output_dir, extra_special_tokens=extra_special_tokens)
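    # Left-pad so that, in batched generation, every prompt ends right where new tokens are appended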
tokenizer_converted.padding_side = "left"
image_processor = Emu3ImageProcessor.from_pretrained(vq_model_id)
processor = Emu3Processor(image_processor, tokenizer_converted, chat_template=CHAT_TEMPLATE)
processor.save_pretrained(output_dir)
    # Load the original LLM and VQ-GAN models
model_llm = AutoModelForCausalLM.from_pretrained(
llm_model_id,
trust_remote_code=True,
)
model_vqgan = AutoModel.from_pretrained(vq_model_id, trust_remote_code=True)
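    # Store the converted tokenizer's full vocab in the config; the model uses it to map
    # visual token ids to VQ-GAN codebook indices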
with open(f"{output_dir}/tokenizer.json", "r") as file:
tokenizer_config = json.load(file)
vocabulary_map = tokenizer_config["model"]["vocab"]
text_config = Emu3TextConfig(
max_position_embeddings=model_llm.config.max_position_embeddings,
rope_scaling={"rope_type": "default"},
)
config = Emu3Config(text_config=text_config, vocabulary_map=vocabulary_map)
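    # Instantiate on the meta device so no memory is allocated until real weights are assigned below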
with init_empty_weights():
model = Emu3ForConditionalGeneration(config=config)
model.generation_config = GenerationConfig(
do_sample=True,
top_k=2048,
max_new_tokens=50_000,
pad_token_id=processor.tokenizer.pad_token_id,
eos_token_id=processor.tokenizer.eos_token_id,
)
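    # Merge the LLM and VQ-GAN weights into a single state dict with HF-style key names,
    # then materialize the meta-initialized model with `assign=True`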
state_dict = {}
state_dict = convert_state_dict_to_hf(model_llm.state_dict(), state_dict)
state_dict = convert_state_dict_to_hf(model_vqgan.state_dict(), state_dict)
model.load_state_dict(state_dict, assign=True, strict=True)
model.save_pretrained(output_dir, safe_serialization=True)
if hub_model_id is not None:
model.push_to_hub(hub_model_id)
processor.push_to_hub(hub_model_id)
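    # Smoke tests: the *Chat* checkpoint is exercised on image understanding,
    # the *Gen* checkpoint on text-to-image generation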
if test_inference and llm_model_id.endswith("Chat"):
# Short inference on a few examples to check if generation makes sense
print("Loading the checkpoint in a Emu3 model...")
print("*" * 100)
model = Emu3ForConditionalGeneration.from_pretrained(output_dir, torch_dtype=torch.bfloat16, device_map="auto")
processor = Emu3Processor.from_pretrained(output_dir)
conversation = [
{
"role": "system",
"content": [
{"type": "text", "text": "You are a helpful assistant."},
],
},
{
"role": "user",
"content": [
{"type": "text", "text": "Please tell me about this art work and its artist."},
{"type": "image"},
],
},
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
image = Image.open(
requests.get(
"https://uploads4.wikiart.org/images/paul-klee/death-for-the-idea-1915.jpg!Large.jpg", stream=True
).raw
)
inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device, torch.bfloat16)
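        # Remember the prompt length so that only newly generated tokens are decoded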
length = inputs.input_ids.shape[1]
out = model.generate(**inputs, max_new_tokens=40, do_sample=False)
generated_text = processor.batch_decode(out[:, length:], skip_special_tokens=True)[0]
print(f"Generation for single-image: {generated_text}")
print("*" * 100)
elif test_inference and llm_model_id.endswith("Gen"):
processor = Emu3Processor.from_pretrained(output_dir)
model = Emu3ForConditionalGeneration.from_pretrained(output_dir, torch_dtype=torch.bfloat16, device_map="auto")
inputs = processor(
text=[
"a portrait of young girl. masterpiece, film grained, best quality.",
"a dog running under the rain",
],
padding=True,
return_tensors="pt",
return_for_image_generation=True,
)
inputs = inputs.to(device="cuda:0", dtype=torch.bfloat16)
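        # Negative prompt passed to generate() for classifier-free guidance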
neg_prompt = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry."
neg_inputs = processor(text=[neg_prompt] * 2, return_tensors="pt").to(device="cuda:0")
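        # The processor returns the target latent grid size (height/width in visual tokens)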
image_sizes = inputs.pop("image_sizes")
HEIGHT, WIDTH = image_sizes[0]
VISUAL_TOKENS = model.vocabulary_mapping.image_tokens
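        # Constrain decoding to a valid image token stream: `width` visual tokens followed by an
        # end-of-line token per row, then end-of-frame / end-of-image / EOS markers, then padding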
def prefix_allowed_tokens_fn(batch_id, input_ids):
height, width = HEIGHT, WIDTH
visual_tokens = VISUAL_TOKENS
image_token_id = processor.tokenizer.encode("<|image token|>", return_tensors="pt")[0].to(model.device)
eoi_token_id = processor.tokenizer.encode("<|image end|>", return_tensors="pt")[0]
eos_token_id = processor.tokenizer.encode("<|extra_204|>", return_tensors="pt")[0]
pad_token_id = processor.tokenizer.encode("<|endoftext|>", return_tensors="pt")[0]
eol_token_id = processor.tokenizer.encode("<|extra_200|>", return_tensors="pt")[0]
eof_token_id = processor.tokenizer.encode("<|extra_201|>", return_tensors="pt")[0]
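            # The offset from the image placeholder token determines which tokens are allowed next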
position = torch.nonzero(input_ids == image_token_id, as_tuple=True)[0][0]
offset = input_ids.shape[0] - position
if offset % (width + 1) == 0:
return (eol_token_id,)
elif offset == (width + 1) * height + 1:
return (eof_token_id,)
elif offset == (width + 1) * height + 2:
return (eoi_token_id,)
elif offset == (width + 1) * height + 3:
return (eos_token_id,)
elif offset > (width + 1) * height + 3:
return (pad_token_id,)
else:
return visual_tokens
out = model.generate(
**inputs,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
negative_prompt_ids=neg_inputs.input_ids,
negative_prompt_attention_mask=neg_inputs.attention_mask,
)
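        # Strip the prompt and decode the generated visual tokens back to pixels with the VQ-GAN decoder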
image = model.decode_image_tokens(out[:, inputs.input_ids.shape[1] :], height=HEIGHT, width=WIDTH)
        images = processor.postprocess(
            list(image.float()), return_tensors="PIL.Image.Image"
        )  # postprocess converts to numpy internally, which does not support bf16, hence the upcast to float
for i, image in enumerate(images["pixel_values"]):
image.save(f"result_{i}.png")
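
# Example usage (the model ids below are assumptions, not part of this script; adjust to your checkpoints):
# convert_model(
#     vq_model_id="BAAI/Emu3-VisionTokenizer",
#     llm_model_id="BAAI/Emu3-Chat",
#     output_dir="Emu3-Chat-hf",
#     test_inference=True,
# )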