docker_images/diffusers/app/pipelines/text_to_image.py
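"""Text-to-image pipeline for the diffusers inference Docker image.

Loads a diffusers checkpoint (or a LoRA adapter on top of its declared base
model) from the Hugging Face Hub and exposes a callable that turns a text
prompt into a single PIL image.
"""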
import importlib
import json
import logging
import os
from typing import TYPE_CHECKING

import torch
from app import idle, lora, offline, timing, validation
from app.pipelines import Pipeline
from diffusers import (
    AutoencoderKL,
    AutoPipelineForText2Image,
    DiffusionPipeline,
    EulerAncestralDiscreteScheduler,
)
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from PIL import Image


class TextToImagePipeline(
    Pipeline, lora.LoRAPipelineMixin, offline.OfflineBestEffortMixin
):
    def __init__(self, model_id: str):
        self.current_lora_adapter = None
        self.model_id = None
        self.current_tokens_loaded = 0
        self.use_auth_token = os.getenv("HF_API_TOKEN")
        # This should allow the image to work with private models when no token is
        # provided, if the model is already in the local cache
        self.offline_preferred = validation.str_to_bool(os.getenv("OFFLINE_PREFERRED"))
        model_data = self._hub_model_info(model_id)
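
        # Disable the safety checker for hf-internal-testing test checkpoints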
        kwargs = (
            {"safety_checker": None}
            if model_id.startswith("hf-internal-testing/")
            else {}
        )
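        # Dtype selection: an explicit TORCH_DTYPE env var wins, otherwise fp16 on GPU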
        env_dtype = os.getenv("TORCH_DTYPE")
        if env_dtype:
            kwargs["torch_dtype"] = getattr(torch, env_dtype)
        elif torch.cuda.is_available():
            kwargs["torch_dtype"] = torch.float16

        has_model_index = any(
            file.rfilename == "model_index.json" for file in model_data.siblings
        )
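
        # Decide what we are loading: a LoRA adapter, or a full diffusers pipeline
        # repo identified by its model_index.json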
        if self._is_lora(model_data):
            model_type = "LoraModel"
        elif has_model_index:
            config_file = self._hub_repo_file(model_id, "model_index.json")
            with open(config_file, "r") as f:
                config_dict = json.load(f)
            model_type = config_dict.get("_class_name", None)
        else:
            raise ValueError("Model type not found")

        if model_type == "LoraModel":
            model_to_load = model_data.cardData["base_model"]
            self.model_id = model_to_load
            if not model_to_load:
                raise ValueError(
                    "No `base_model` found. Please include a `base_model` on your README.md tags"
                )
            self._load_sd_with_sdxl_fix(model_to_load, **kwargs)
            # The lora will actually be lazily loaded on the fly per request
            self.current_lora_adapter = None
        else:
            if model_id == "stabilityai/stable-diffusion-xl-base-1.0":
                self._load_sd_with_sdxl_fix(model_id, **kwargs)
            else:
                self.ldm = AutoPipelineForText2Image.from_pretrained(
                    model_id, use_auth_token=self.use_auth_token, **kwargs
                )
            self.model_id = model_id
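
        # Pipelines that accept Karras-style schedulers default to
        # EulerAncestralDiscreteScheduler, which matches the 20-step default
        # applied in _process_req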
        self.is_karras_compatible = (
            self.ldm.__class__.__init__.__annotations__.get("scheduler", None)
            == KarrasDiffusionSchedulers
        )
        if self.is_karras_compatible:
            self.ldm.scheduler = EulerAncestralDiscreteScheduler.from_config(
                self.ldm.scheduler.config
            )
        self.default_scheduler = self.ldm.scheduler

        if not idle.UNLOAD_IDLE:
            self._model_to_gpu()

    def _load_sd_with_sdxl_fix(self, model_id, **kwargs):
        if model_id == "stabilityai/stable-diffusion-xl-base-1.0":
            # The stock SDXL VAE is numerically unstable in fp16, so swap in the
            # community fp16-fix VAE and load the fp16 variant of the pipeline
            vae = AutoencoderKL.from_pretrained(
                "madebyollin/sdxl-vae-fp16-fix",
                torch_dtype=torch.float16,  # load fp16 fix VAE
            )
            kwargs["vae"] = vae
            kwargs["variant"] = "fp16"

        self.ldm = DiffusionPipeline.from_pretrained(
            model_id, use_auth_token=self.use_auth_token, **kwargs
        )

    @timing.timing
    def _model_to_gpu(self):
        if torch.cuda.is_available():
            self.ldm.to("cuda")

    def __call__(self, inputs: str, **kwargs) -> "Image.Image":
        """
        Args:
            inputs (:obj:`str`):
                the text prompt to generate an image from

        Return:
            A :obj:`PIL.Image.Image`: the generated image.
        """
        # Check if the user set a custom scheduler and pop it from the kwargs if so
        custom_scheduler = None
        if "scheduler" in kwargs:
            custom_scheduler = kwargs["scheduler"]
            kwargs.pop("scheduler")

        if custom_scheduler:
            compatibles = self.ldm.scheduler.compatibles
            # Check if the scheduler is compatible
            is_compatible_scheduler = [
                cls for cls in compatibles if cls.__name__ == custom_scheduler
            ]
            # In case of a compatible scheduler, swap to that for inference
            if is_compatible_scheduler:
                # Import the scheduler dynamically
                SchedulerClass = getattr(
                    importlib.import_module("diffusers.schedulers"), custom_scheduler
                )
                self.ldm.scheduler = SchedulerClass.from_config(
                    self.ldm.scheduler.config
                )
            else:
                logger.info("%s scheduler not loaded: incompatible", custom_scheduler)
                self.ldm.scheduler = self.default_scheduler
        else:
            self.ldm.scheduler = self.default_scheduler

        self._load_lora_adapter(kwargs)
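        # When idle-unloading is enabled, make sure the model is (back) on the GPU
        # for the duration of this request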
        if idle.UNLOAD_IDLE:
            with idle.request_witnesses():
                self._model_to_gpu()
                resp = self._process_req(inputs, **kwargs)
        else:
            resp = self._process_req(inputs, **kwargs)

        return resp

    def _process_req(self, inputs, **kwargs):
        # only one image per prompt is supported
        kwargs["num_images_per_prompt"] = 1

        if "num_inference_steps" not in kwargs:
            default_num_steps = os.getenv("DEFAULT_NUM_INFERENCE_STEPS")
            if default_num_steps:
                kwargs["num_inference_steps"] = int(default_num_steps)
            elif self.is_karras_compatible:
                kwargs["num_inference_steps"] = 20
            # Else, don't specify anything, leave the default behaviour

        if "guidance_scale" not in kwargs:
            default_guidance_scale = os.getenv("DEFAULT_GUIDANCE_SCALE")
            if default_guidance_scale is not None:
                kwargs["guidance_scale"] = float(default_guidance_scale)
            # Else, don't specify anything, leave the default behaviour
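
        # Map a user-provided seed onto a torch.Generator so results are reproducible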
if "seed" in kwargs:
seed = int(kwargs["seed"])
generator = torch.Generator().manual_seed(seed)
kwargs["generator"] = generator
kwargs.pop("seed")
images = self.ldm(inputs, **kwargs)["images"]
return images[0]
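

# Usage sketch (illustrative only, not executed here): the API server normally
# instantiates this class, but it is driven roughly like this. The model id and
# generation parameters below are arbitrary examples.
#
#   pipeline = TextToImagePipeline("stabilityai/stable-diffusion-xl-base-1.0")
#   image = pipeline(
#       "an astronaut riding a horse",
#       num_inference_steps=25,
#       guidance_scale=7.5,
#       seed=42,
#   )
#   image.save("astronaut.png")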