def generate_tpu()

in ai-ml/gke-ray/rayserve/stable-diffusion/stable_diffusion_tpu.py [0:0]


  def generate_tpu(self, prompts: List[str]):
    """Generates a batch of images from the diffusion pipeline for prompts.

    The prompts are tokenized, sharded across the local devices and passed to
    the parallel generation function in a single call. When profiling is
    enabled, a JAX trace covering the generation call is written to
    `self._profiler_dir`.

    Args:
      prompts: a non-empty list of strings. Its length should be a multiple
        of the local device count (e.g. 4), because the batch is split evenly
        across devices.

    Returns:
      A list of PIL Images.

    Raises:
      ValueError: if `prompts` is empty.
    """
    from flax.training.common_utils import shard
    import jax
    import numpy as np

    # Validate before any device work. Unlike `assert`, this check is not
    # stripped when Python runs with -O.
    if not prompts:
      raise ValueError("prompt parameter cannot be empty")

    # Fixed seed, split into one key per device.
    rng = jax.random.PRNGKey(0)
    rng = jax.random.split(rng, jax.device_count())

    logger.info("Prompts: %s", prompts)
    prompt_ids = self._pipeline.prepare_inputs(prompts)
    # `shard` adds a leading device axis:
    # (batch, ...) -> (local_devices, batch // local_devices, ...).
    prompt_ids = shard(prompt_ids)
    logger.info("Sharded prompt ids has shape: %s", prompt_ids.shape)

    if self._run_with_profiler:
      jax.profiler.start_trace(self._profiler_dir)
    try:
      time_start = time.perf_counter()
      images = self._p_generate(prompt_ids, self._p_params, rng)
      # JAX dispatch is asynchronous; wait for the result so the measured
      # time covers the actual computation.
      images = images.block_until_ready()
      elapsed = time.perf_counter() - time_start
    finally:
      # Always close the trace, even if generation fails. JAX allows only
      # one active trace, so a leaked one would break later profiled calls.
      if self._run_with_profiler:
        jax.profiler.stop_trace()

    logger.info("Inference time (in seconds): %f", elapsed)
    logger.info("Shape of the predictions: %s", images.shape)
    # Merge the leading (device) axis with the per-device batch axis, keeping
    # the trailing three image dimensions (presumably height, width, channels).
    images = images.reshape(
        (images.shape[0] * images.shape[1],) + images.shape[-3:])
    logger.info("Shape of images afterwards: %s", images.shape)
    # Copy to a host NumPy array before converting to PIL images.
    return self._pipeline.numpy_to_pil(np.array(images))