in ai-ml/gke-ray/rayserve/stable-diffusion/stable_diffusion_tpu.py [0:0]
def __init__(
        self,
        run_with_profiler: bool = False,
        warmup: bool = False,
        warmup_batch_size: int = _MAX_BATCH_SIZE):
    """Load the Flax Stable Diffusion pipeline and prepare it for serving.

    The pretrained "CompVis/stable-diffusion-v1-4" checkpoint (``bf16``
    revision, bfloat16 dtype) is loaded, its parameters are passed through
    ``flax.jax_utils.replicate`` and the pipeline's ``_generate`` function
    is wrapped with ``jax.pmap``. The results are kept on the instance as
    ``_pipeline``, ``_p_params`` and ``_p_generate``.

    Args:
        run_with_profiler: Stored as ``_run_with_profiler``. Presumably
            consulted by the generation path to turn profiling on; that
            code is not visible here -- confirm against ``generate_tpu``.
        warmup: When true, one batch of placeholder prompts is sent
            through ``generate_tpu`` before the constructor returns.
        warmup_batch_size: Number of placeholder prompts in that batch.
    """
    # Third-party dependencies are imported inside the constructor rather
    # than at module level.
    # NOTE(review): the bare `jax` name is not referenced in this method.
    from diffusers import FlaxStableDiffusionPipeline
    from flax.jax_utils import replicate
    import jax
    import jax.numpy as jnp
    from jax import pmap

    # Plain configuration first; none of it depends on the model load.
    self._run_with_profiler = run_with_profiler
    self._profiler_dir = "/tmp/tensorboard"

    # from_pretrained hands back the pipeline object and its parameters
    # as a pair; only the pipeline is stored unmodified.
    self._pipeline, weights = FlaxStableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        revision="bf16",
        dtype=jnp.bfloat16,
    )
    self._p_params = replicate(weights)
    self._p_generate = pmap(self._pipeline._generate)

    if not warmup:
        return

    # Push one full batch of identical prompts through the regular
    # generation entry point before returning.
    logger.info("Sending warmup requests.")
    self.generate_tpu(["A warmup request"] * warmup_batch_size)