# muse/pipeline_muse.py
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from PIL import Image
from transformers import (
AutoTokenizer,
CLIPConfig,
CLIPTextModel,
CLIPTextModelWithProjection,
PreTrainedTokenizer,
T5EncoderModel,
)
from .modeling_maskgit_vqgan import MaskGitVQGAN
from .modeling_movq import MOVQ
from .modeling_paella_vq import PaellaVQModel
from .modeling_taming_vqgan import VQGANModel
from .modeling_transformer import MaskGitTransformer, MaskGiTUViT
from .sampling import get_mask_chedule
class PipelineMuse:
def __init__(
self,
vae: Union[VQGANModel, MOVQ, MaskGitVQGAN],
transformer: Union[MaskGitTransformer, MaskGiTUViT],
is_class_conditioned: bool = False,
text_encoder: Optional[Union[T5EncoderModel, CLIPTextModel]] = None,
tokenizer: Optional[PreTrainedTokenizer] = None,
) -> None:
self.text_encoder = text_encoder
self.tokenizer = tokenizer
self.vae = vae
self.transformer = transformer
self.is_class_conditioned = is_class_conditioned
self.device = "cpu"
def to(self, device="cpu", dtype=torch.float32) -> None:
self.device = device
self.dtype = dtype
if not self.is_class_conditioned:
self.text_encoder.to(device, dtype=dtype)
self.transformer.to(device, dtype=dtype)
self.vae.to(device, dtype=torch.float32) # keep vae in fp32
return self
@torch.no_grad()
def __call__(
self,
text: Optional[Union[str, List[str]]] = None,
negative_text: Optional[Union[str, List[str]]] = "",
prompt_embeds: Optional[torch.Tensor] = None,
pooled_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_pooled_embeds: Optional[torch.Tensor] = None,
class_ids: Optional[Union[int, List[int]]] = None,
timesteps: int = 16,
noise_schedule: str = "cosine",
guidance_scale: float = 10.0,
guidance_schedule=None,
        temperature: Union[float, Tuple[float, float]] = (2, 0),
topk_filter_thres: float = 0.9,
num_images_per_prompt: int = 1,
use_maskgit_generate: bool = True,
generator: Optional[torch.Generator] = None,
use_fp16: bool = False,
noise_type="mask", # can be "mask" or "random_replace"
predict_all_tokens=False,
orig_size=(512, 512),
crop_coords=(0, 0),
aesthetic_score=6.0,
return_intermediate: bool = False,
use_tqdm=True,
transformer_seq_len=None,
        clip_skip: Optional[int] = None,
):
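        """
        Generate images from `text` prompts, or from `class_ids` for class-conditioned models.
        Exactly one of `text` or `class_ids` must be provided. Returns a list of PIL images; when
        `return_intermediate=True`, the per-step intermediate images are returned as well.
        """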
if text is None and class_ids is None:
raise ValueError("Either text or class_ids must be provided.")
if text is not None and class_ids is not None:
raise ValueError("Only one of text or class_ids may be provided.")
if class_ids is not None:
if isinstance(class_ids, int):
class_ids = [class_ids]
class_ids = torch.tensor(class_ids, device=self.device, dtype=torch.long)
# duplicate class ids for each generation per prompt
class_ids = class_ids.repeat_interleave(num_images_per_prompt, dim=0)
model_inputs = {"class_ids": class_ids}
else:
if isinstance(text, str):
text = [text]
if prompt_embeds is None:
input_ids = self.tokenizer(
text,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=self.tokenizer.model_max_length,
).input_ids # TODO: remove hardcode
input_ids = input_ids.to(self.device)
if self.transformer.config.add_cond_embeds:
                if prompt_embeds is not None and pooled_embeds is not None:
                    encoder_hidden_states = prompt_embeds.to(self.device, dtype=self.text_encoder.dtype)
                    pooled_embeds = pooled_embeds.to(self.device, dtype=self.text_encoder.dtype)
else:
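                    # clip_skip counts hidden layers back from the end of the text encoder; by default
                    # the penultimate hidden state (index -2) is used for conditioning.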
                    clip_layer_idx = -(clip_skip + 1) if clip_skip is not None else -2
outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
pooled_embeds, encoder_hidden_states = outputs.text_embeds, outputs.hidden_states[clip_layer_idx]
else:
encoder_hidden_states = self.text_encoder(input_ids).last_hidden_state
pooled_embeds = None
if negative_text is not None:
if isinstance(negative_text, str):
negative_text = [negative_text] * len(text)
negative_input_ids = self.tokenizer(
negative_text,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=self.tokenizer.model_max_length,
).input_ids
negative_input_ids = negative_input_ids.to(self.device)
if self.transformer.config.add_cond_embeds:
outputs = self.text_encoder(negative_input_ids, return_dict=True, output_hidden_states=True)
negative_pooled_embeds = outputs.text_embeds
negative_encoder_hidden_states = outputs.hidden_states[-2]
else:
negative_encoder_hidden_states = self.text_encoder(negative_input_ids).last_hidden_state
negative_pooled_embeds = None
            elif negative_prompt_embeds is not None:
                negative_encoder_hidden_states = negative_prompt_embeds.to(self.device, dtype=self.text_encoder.dtype)
                if negative_pooled_embeds is not None:
                    negative_pooled_embeds = negative_pooled_embeds.to(self.device, dtype=self.text_encoder.dtype)
else:
negative_encoder_hidden_states = None
negative_pooled_embeds = None
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = encoder_hidden_states.shape
encoder_hidden_states = encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
encoder_hidden_states = encoder_hidden_states.view(bs_embed * num_images_per_prompt, seq_len, -1)
if pooled_embeds is not None:
bs_embed, _ = pooled_embeds.shape
pooled_embeds = pooled_embeds.repeat(1, num_images_per_prompt)
pooled_embeds = pooled_embeds.view(bs_embed * num_images_per_prompt, -1)
if negative_pooled_embeds is not None:
bs_embed, _ = negative_pooled_embeds.shape
negative_pooled_embeds = negative_pooled_embeds.repeat(1, num_images_per_prompt)
negative_pooled_embeds = negative_pooled_embeds.view(bs_embed * num_images_per_prompt, -1)
if negative_encoder_hidden_states is not None:
bs_embed, seq_len, _ = negative_encoder_hidden_states.shape
negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
negative_encoder_hidden_states = negative_encoder_hidden_states.view(
bs_embed * num_images_per_prompt, seq_len, -1
)
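            # No negative prompt or embeds given: encode the empty string so the sampler has
            # unconditional embeddings to use for classifier-free guidance.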
if negative_encoder_hidden_states is None:
empty_input = self.tokenizer("", padding="max_length", return_tensors="pt").input_ids.to(
self.text_encoder.device
)
outputs = self.text_encoder(empty_input, output_hidden_states=True)
empty_embeds = outputs.hidden_states[-2]
empty_cond_embeds = outputs[0]
else:
empty_embeds, empty_cond_embeds = None, None
model_inputs = {
"encoder_hidden_states": encoder_hidden_states,
"negative_embeds": negative_encoder_hidden_states,
"cond_embeds": pooled_embeds,
"negative_cond_embeds": negative_pooled_embeds,
"empty_embeds": empty_embeds,
"empty_cond_embeds": empty_cond_embeds,
}
if self.transformer.config.add_micro_cond_embeds:
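            # Micro-conditioning vector: original image size, crop coordinates, and aesthetic score.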
micro_conds = list(orig_size) + list(crop_coords) + [aesthetic_score]
micro_conds = torch.tensor(micro_conds, device=self.device, dtype=encoder_hidden_states.dtype)
micro_conds = micro_conds.unsqueeze(0)
model_inputs["micro_conds"] = micro_conds
generate = self.transformer.generate
if use_maskgit_generate:
generate = self.transformer.generate2
with torch.autocast("cuda", enabled=use_fp16):
outputs = generate(
**model_inputs,
timesteps=timesteps,
guidance_scale=guidance_scale,
guidance_schedule=guidance_schedule,
temperature=temperature,
topk_filter_thres=topk_filter_thres,
generator=generator,
noise_type=noise_type,
noise_schedule=get_mask_chedule(noise_schedule),
predict_all_tokens=predict_all_tokens,
return_intermediate=return_intermediate,
use_tqdm=use_tqdm,
seq_len=transformer_seq_len,
)
if return_intermediate:
generated_tokens, intermediate = outputs
else:
generated_tokens = outputs
images = self.vae.decode_code(generated_tokens)
if return_intermediate:
intermediate_images = [self.vae.decode_code(tokens) for tokens in intermediate]
# Convert to PIL images
images = [self.to_pil_image(image) for image in images]
if return_intermediate:
            intermediate_images = [[self.to_pil_image(img) for img in step] for step in intermediate_images]
return images, intermediate_images
return images
def to_pil_image(self, image: torch.Tensor):
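        # CHW float tensor in [0, 1] -> HWC uint8 PIL image; the scale/clip round trip below is
        # effectively a clamp of the decoded values to [0, 1].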
image = image.permute(1, 2, 0).cpu().numpy()
image = 2.0 * image - 1.0
image = np.clip(image, -1.0, 1.0)
image = (image + 1.0) / 2.0
image = (255 * image).astype(np.uint8)
image = Image.fromarray(image).convert("RGB")
return image
@classmethod
    def from_pretrained(
        cls,
        model_name_or_path: Optional[str] = None,
        text_encoder_path: Optional[str] = None,
        vae_path: Optional[str] = None,
        transformer_path: Optional[str] = None,
        vae=None,
        text_encoder=None,
        transformer=None,
        is_class_conditioned: bool = False,
    ) -> "PipelineMuse":
"""
        Instantiate a PipelineMuse from pretrained weights. Provide either `model_name_or_path` (a checkpoint
        with `text_encoder`, `vae`, and `transformer` subfolders) or the individual `text_encoder_path`,
        `vae_path`, and `transformer_path` arguments. Already-instantiated `vae`, `text_encoder`, and
        `transformer` modules can also be passed in to skip loading them from disk.
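
        Example (a minimal sketch; the checkpoint path below is a placeholder, not a verified hub id):

            pipe = PipelineMuse.from_pretrained("path/to/muse-checkpoint").to("cuda")
            images = pipe("a photo of an astronaut riding a horse", timesteps=16)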
"""
if model_name_or_path is None:
if text_encoder_path is None or vae_path is None or transformer_path is None:
raise ValueError(
"If model_name_or_path is None, then text_encoder_path, vae_path, and transformer_path must be"
" provided."
)
text_encoder_args = None
tokenizer_args = None
if not is_class_conditioned:
text_encoder_args = {"pretrained_model_name_or_path": text_encoder_path}
tokenizer_args = {"pretrained_model_name_or_path": text_encoder_path}
vae_args = {"pretrained_model_name_or_path": vae_path}
transformer_args = {"pretrained_model_name_or_path": transformer_path}
else:
text_encoder_args = None
tokenizer_args = None
if not is_class_conditioned:
text_encoder_args = {"pretrained_model_name_or_path": model_name_or_path, "subfolder": "text_encoder"}
tokenizer_args = {"pretrained_model_name_or_path": model_name_or_path, "subfolder": "text_encoder"}
vae_args = {"pretrained_model_name_or_path": model_name_or_path, "subfolder": "vae"}
transformer_args = {"pretrained_model_name_or_path": model_name_or_path, "subfolder": "transformer"}
if not is_class_conditioned:
# Very hacky way to load different text encoders
# TODO: Add config for pipeline to specify text encoder
# is_clip = "clip" in text_encoder_args["pretrained_model_name_or_path"].lower()
# text_encoder_cls = CLIPTextModel if is_clip else T5EncoderModel
# if is_clip:
# config = CLIPConfig.from_pretrained(**text_encoder_args)
# if config.architectures[0] == "CLIPTextModel":
# text_encoder_cls = CLIPTextModel
# else:
# text_encoder_cls = CLIPTextModelWithProjection
# text_encoder_args["projection_dim"] = 768
# TODO: make this more robust
if text_encoder is None:
text_encoder = CLIPTextModelWithProjection.from_pretrained(**text_encoder_args)
tokenizer = AutoTokenizer.from_pretrained(**tokenizer_args)
transformer_config = MaskGitTransformer.load_config(**transformer_args)
        if transformer is not None:
            pass  # already instantiated by the caller
        elif transformer_config["_class_name"] == "MaskGitTransformer":
            transformer = MaskGitTransformer.from_pretrained(**transformer_args)
        elif transformer_config["_class_name"] in ("MaskGiTUViT", "MaskGiTUViT_v2"):
            transformer = MaskGiTUViT.from_pretrained(**transformer_args)
        else:
            raise ValueError(f"Unknown Transformer class: {transformer_config['_class_name']}")
# Hacky way to load different VQ models
vae_config = MaskGitVQGAN.load_config(**vae_args)
        if vae is not None:
            pass  # already instantiated by the caller
elif vae_config["_class_name"] == "VQGANModel":
vae = VQGANModel.from_pretrained(**vae_args)
elif vae_config["_class_name"] == "MaskGitVQGAN":
vae = MaskGitVQGAN.from_pretrained(**vae_args)
elif vae_config["_class_name"] == "MOVQ":
vae = MOVQ.from_pretrained(**vae_args)
elif vae_config["_class_name"] == "PaellaVQModel":
vae = PaellaVQModel.from_pretrained(**vae_args)
else:
raise ValueError(f"Unknown VAE class: {vae_config['_class_name']}")
if is_class_conditioned:
return cls(
vae=vae,
transformer=transformer,
is_class_conditioned=is_class_conditioned,
)
return cls(
vae=vae,
transformer=transformer,
text_encoder=text_encoder,
tokenizer=tokenizer,
is_class_conditioned=is_class_conditioned,
)
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
) -> None:
"""
        Save the pipeline's components (text encoder, tokenizer, VQ model, and transformer) to `save_directory`.
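
        The components are written to subfolders (the text encoder and tokenizer share one folder and
        are skipped for class-conditioned pipelines):

            save_directory/
                text_encoder/
                vae/
                transformer/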
"""
if not self.is_class_conditioned:
self.text_encoder.save_pretrained(os.path.join(save_directory, "text_encoder"))
self.tokenizer.save_pretrained(os.path.join(save_directory, "text_encoder"))
self.vae.save_pretrained(os.path.join(save_directory, "vae"))
self.transformer.save_pretrained(os.path.join(save_directory, "transformer"))
class PipelineMuseInpainting(PipelineMuse):
@torch.no_grad()
def __call__(
self,
        image: Image.Image,
mask: torch.BoolTensor,
text: Optional[Union[str, List[str]]] = None,
negative_text: Optional[Union[str, List[str]]] = None,
        class_ids: Optional[Union[int, List[int]]] = None,
timesteps: int = 8,
guidance_scale: float = 8.0,
guidance_schedule=None,
temperature: float = 1.0,
topk_filter_thres: float = 0.9,
num_images_per_prompt: int = 1,
use_maskgit_generate: bool = True,
generator: Optional[torch.Generator] = None,
use_fp16: bool = False,
image_size: int = 256,
orig_size=(256, 256),
crop_coords=(0, 0),
aesthetic_score=6.0,
):
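        """
        Inpaint `image` at the positions selected by `mask`: the corresponding VQ tokens are replaced
        with the mask token and re-sampled by the transformer, conditioned on `text` or `class_ids`.
        """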
from torchvision import transforms
assert use_maskgit_generate
if text is None and class_ids is None:
raise ValueError("Either text or class_ids must be provided.")
if text is not None and class_ids is not None:
raise ValueError("Only one of text or class_ids may be provided.")
encode_transform = transforms.Compose(
[
transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
]
)
pixel_values = encode_transform(image).unsqueeze(0).to(self.device)
_, image_tokens = self.vae.encode(pixel_values)
mask_token_id = self.transformer.config.mask_token_id
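        # Replace the VQ tokens under the mask with the mask token so only those positions are re-sampled.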
image_tokens[mask[None]] = mask_token_id
image_tokens = image_tokens.repeat(num_images_per_prompt, 1)
if class_ids is not None:
if isinstance(class_ids, int):
class_ids = [class_ids]
class_ids = torch.tensor(class_ids, device=self.device, dtype=torch.long)
# duplicate class ids for each generation per prompt
class_ids = class_ids.repeat_interleave(num_images_per_prompt, dim=0)
model_inputs = {"class_ids": class_ids}
else:
if isinstance(text, str):
text = [text]
input_ids = self.tokenizer(
text,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=self.tokenizer.model_max_length,
).input_ids # TODO: remove hardcode
input_ids = input_ids.to(self.device)
if self.transformer.config.add_cond_embeds:
outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
pooled_embeds, encoder_hidden_states = outputs.text_embeds, outputs.hidden_states[-2]
else:
encoder_hidden_states = self.text_encoder(input_ids).last_hidden_state
pooled_embeds = None
if negative_text is not None:
if isinstance(negative_text, str):
negative_text = [negative_text]
negative_input_ids = self.tokenizer(
negative_text,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=self.tokenizer.model_max_length,
).input_ids
negative_input_ids = negative_input_ids.to(self.device)
negative_encoder_hidden_states = self.text_encoder(negative_input_ids).last_hidden_state
else:
negative_encoder_hidden_states = None
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = encoder_hidden_states.shape
encoder_hidden_states = encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
encoder_hidden_states = encoder_hidden_states.view(bs_embed * num_images_per_prompt, seq_len, -1)
if negative_encoder_hidden_states is not None:
bs_embed, seq_len, _ = negative_encoder_hidden_states.shape
negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
negative_encoder_hidden_states = negative_encoder_hidden_states.view(
bs_embed * num_images_per_prompt, seq_len, -1
)
empty_input = self.tokenizer("", padding="max_length", return_tensors="pt").input_ids.to(
self.text_encoder.device
)
outputs = self.text_encoder(empty_input, output_hidden_states=True)
empty_embeds = outputs.hidden_states[-2]
empty_cond_embeds = outputs[0]
model_inputs = {
"encoder_hidden_states": encoder_hidden_states,
"negative_embeds": negative_encoder_hidden_states,
"empty_embeds": empty_embeds,
"empty_cond_embeds": empty_cond_embeds,
"cond_embeds": pooled_embeds,
}
if self.transformer.config.add_micro_cond_embeds:
micro_conds = list(orig_size) + list(crop_coords) + [aesthetic_score]
micro_conds = torch.tensor(micro_conds, device=self.device, dtype=encoder_hidden_states.dtype)
micro_conds = micro_conds.unsqueeze(0)
model_inputs["micro_conds"] = micro_conds
generate = self.transformer.generate2
with torch.autocast("cuda", enabled=use_fp16):
generated_tokens = generate(
input_ids=image_tokens,
**model_inputs,
timesteps=timesteps,
guidance_scale=guidance_scale,
guidance_schedule=guidance_schedule,
temperature=temperature,
topk_filter_thres=topk_filter_thres,
generator=generator,
)
images = self.vae.decode_code(generated_tokens)
# Convert to PIL images
images = [self.to_pil_image(image) for image in images]
return images