def __init__()

in ai-ml/gke-ray/rayserve/stable-diffusion/stable_diffusion_tpu.py [0:0]


  def __init__(
      self, run_with_profiler: bool = False, warmup: bool = False,
      warmup_batch_size: int = _MAX_BATCH_SIZE):
    """Load the Stable Diffusion Flax pipeline and prepare it for TPU serving.

    Loads the ``bf16`` revision of ``CompVis/stable-diffusion-v1-4`` with
    ``bfloat16`` weights. The pipeline parameters are replicated with
    ``flax.jax_utils.replicate``, and the pipeline's ``_generate`` function is
    wrapped in ``jax.pmap`` so it can run across devices.

    Args:
      run_with_profiler: Stored as ``self._run_with_profiler``. It is
        presumably read by ``generate_tpu`` to decide whether to profile.
      warmup: If True, issue one ``generate_tpu`` call before returning.
      warmup_batch_size: Number of copies of the warmup prompt sent in that
        call. NOTE(review): a value < 1 produces an empty prompt list; confirm
        ``generate_tpu`` tolerates that before relying on it.
    """
    # Heavy ML dependencies are imported lazily, presumably so they are only
    # loaded inside the serving replica rather than wherever this module is
    # imported.
    from diffusers import FlaxStableDiffusionPipeline
    from flax.jax_utils import replicate
    import jax.numpy as jnp
    from jax import pmap

    model_id = "CompVis/stable-diffusion-v1-4"

    # from_pretrained on the Flax pipeline returns (pipeline, params); the
    # params are kept separately so they can be replicated below.
    self._pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
        model_id,
        revision="bf16",
        dtype=jnp.bfloat16)

    # Replicate params for pmap and compile a pmapped generate function.
    self._p_params = replicate(params)
    self._p_generate = pmap(self._pipeline._generate)
    self._run_with_profiler = run_with_profiler
    # Output directory for profiler traces; presumably consumed by
    # generate_tpu when run_with_profiler is set.
    self._profiler_dir = "/tmp/tensorboard"

    if warmup:
      # Presumably triggers JIT compilation up front so the first real
      # request does not pay the compile cost.
      logger.info("Sending warmup requests.")
      warmup_prompts = ["A warmup request"] * warmup_batch_size
      self.generate_tpu(warmup_prompts)