# docker_images/latent-to-image/app/pipelines/latent_to_image.py
import logging
import os
from typing import TYPE_CHECKING

import torch
from app import idle, offline, timing, validation
from app.pipelines import Pipeline
from diffusers import AutoencoderKL
from diffusers.image_processor import VaeImageProcessor

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from PIL import Image


class LatentToImagePipeline(Pipeline, offline.OfflineBestEffortMixin):
    def __init__(self, model_id: str):
        self.model_id = None
        self.current_tokens_loaded = 0
        self.use_auth_token = os.getenv("HF_API_TOKEN")
        # Lets the image serve a private model even when no token is provided,
        # as long as that model is already present in the local cache.
        self.offline_preferred = validation.str_to_bool(os.getenv("OFFLINE_PREFERRED"))
        model_data = self._hub_model_info(model_id)

        kwargs = {}
        # TORCH_DTYPE, when set, takes precedence; otherwise default to
        # float16 on CUDA machines and float32 on CPU.
        env_dtype = os.getenv("TORCH_DTYPE")
        if env_dtype:
            kwargs["torch_dtype"] = getattr(torch, env_dtype)
        elif torch.cuda.is_available():
            kwargs["torch_dtype"] = torch.float16
        else:
            kwargs["torch_dtype"] = torch.float32

        # Full pipeline repos ship a model_index.json; in that case the VAE
        # weights live in the "vae" subfolder rather than at the repo root.
        has_model_index = any(
            file.rfilename == "model_index.json" for file in model_data.siblings
        )
        if has_model_index:
            kwargs["subfolder"] = "vae"

        self.vae = AutoencoderKL.from_pretrained(model_id, **kwargs).eval()
        self.dtype = kwargs["torch_dtype"]
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # The VAE halves spatial resolution once per block transition, so the
        # pixel-to-latent size ratio is 2 ** (len(block_out_channels) - 1).
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
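        # For instance, the Stable Diffusion VAE declares four output channel
        # blocks, giving a scale factor of 2 ** 3 = 8: a 64x64 latent decodes
        # to a 512x512 image.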

        # When idle unloading is disabled, pin the VAE to the GPU up front.
        if not idle.UNLOAD_IDLE:
            self._model_to_gpu()

    @timing.timing
    def _model_to_gpu(self):
        if torch.cuda.is_available():
            self.vae.to("cuda")

    def __call__(self, inputs: torch.Tensor, **kwargs) -> "Image.Image":
        """
        Args:
            inputs (:obj:`torch.Tensor`):
                a latent tensor to decode into an image
        Return:
            A :obj:`PIL.Image.Image` with the raw image representation as PIL.
        """
        if idle.UNLOAD_IDLE:
            # Reload the model onto the GPU for the duration of the request.
            with idle.request_witnesses():
                self._model_to_gpu()
                resp = self._process_req(inputs, **kwargs)
        else:
            resp = self._process_req(inputs, **kwargs)
        return resp

    def _process_req(self, inputs, **kwargs):
        # fp16 VAEs that are prone to numerical overflow advertise
        # force_upcast in their config; decode those in fp32 and drop back to
        # fp16 afterwards.
        needs_upcasting = (
            self.vae.dtype == torch.float16 and self.vae.config.force_upcast
        )
        if needs_upcasting:
            self.vae = self.vae.to(torch.float32)
            inputs = inputs.to(self.device, torch.float32)
        else:
            inputs = inputs.to(self.device, self.dtype)
        # Unscale/denormalize the latents: invert the normalization applied at
        # encode time, using the config's latents_mean/latents_std when both
        # are present and non-None, and a plain scaling_factor division
        # otherwise.
        has_latents_mean = (
            hasattr(self.vae.config, "latents_mean")
            and self.vae.config.latents_mean is not None
        )
        has_latents_std = (
            hasattr(self.vae.config, "latents_std")
            and self.vae.config.latents_std is not None
        )
        if has_latents_mean and has_latents_std:
            latents_mean = (
                torch.tensor(self.vae.config.latents_mean)
                .view(1, 4, 1, 1)
                .to(inputs.device, inputs.dtype)
            )
            latents_std = (
                torch.tensor(self.vae.config.latents_std)
                .view(1, 4, 1, 1)
                .to(inputs.device, inputs.dtype)
            )
            inputs = (
                inputs * latents_std / self.vae.config.scaling_factor + latents_mean
            )
        else:
            inputs = inputs / self.vae.config.scaling_factor
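        # For example, Stable Diffusion 1.5's VAE has scaling_factor 0.18215
        # and no latents_mean/latents_std, so it takes the plain-division path
        # above, while a model like Playground v2.5 publishes both statistics
        # and takes the denormalizing branch.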

        with torch.no_grad():
            image = self.vae.decode(inputs, return_dict=False)[0]

        # Return the VAE to fp16 now that decoding is done.
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)

        image = self.image_processor.postprocess(image, output_type="pil")
        return image[0]
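

# Illustrative usage sketch, not part of the served app; the model id and
# latent shape below are assumptions for a Stable Diffusion 1.5-class VAE:
#
#     pipeline = LatentToImagePipeline("runwayml/stable-diffusion-v1-5")
#     latents = torch.randn(1, 4, 64, 64)  # 64x64 latents -> 512x512 pixels
#     image = pipeline(latents)
#     image.save("decoded.png")
#
# Keeping the example commented out avoids downloading model weights as a
# side effect of importing this module inside the container.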