in ai-ml/gke-ray/rayserve/stable-diffusion/stable_diffusion_tpu.py [0:0]
def generate_tpu(self, prompts: List[str]):
    """Generate a batch of images from a list of text prompts.

    The prompts are tokenized with the pipeline's ``prepare_inputs``, sharded
    with ``flax.training.common_utils.shard``, run through the parallel
    generate function stored on ``self._p_generate``, and converted to PIL
    images.

    Args:
        prompts: A non-empty list of prompt strings. The original note says
            the count "should be a factor of 4"; presumably the batch must
            split evenly across the devices used by ``shard`` -- TODO confirm
            against the deployment's device count.

    Returns:
        A list of PIL Images, as produced by the pipeline's ``numpy_to_pil``.

    Raises:
        AssertionError: If ``prompts`` is empty. Raised explicitly rather
            than through an ``assert`` statement so the check still runs
            under ``python -O``; the exception type is unchanged for callers.
    """
    # Fail fast, before importing heavy dependencies or touching devices.
    if not prompts:
        raise AssertionError("prompt parameter cannot be empty")

    from flax.training.common_utils import shard
    import jax
    import numpy as np

    # NOTE(review): the seed is fixed at 0, so every call starts from the
    # same per-device keys -- confirm this determinism is intended.
    rng = jax.random.PRNGKey(0)
    rng = jax.random.split(rng, jax.device_count())

    logger.info("Prompts: %s", prompts)
    prompt_ids = self._pipeline.prepare_inputs(prompts)
    prompt_ids = shard(prompt_ids)
    logger.info("Sharded prompt ids has shape: %s", prompt_ids.shape)

    # Read the flag once so stop_trace() is only called if start_trace() was.
    profiling = self._run_with_profiler
    if profiling:
        jax.profiler.start_trace(self._profiler_dir)
    try:
        # perf_counter() is monotonic, so the measured duration cannot be
        # skewed by wall-clock adjustments.
        time_start = time.perf_counter()
        images = self._p_generate(prompt_ids, self._p_params, rng)
        # Wait for the computation to finish so the timing covers the actual
        # device work, not just the dispatch.
        images = images.block_until_ready()
        elapsed = time.perf_counter() - time_start
    finally:
        # Always close the trace, even when generation raises; otherwise the
        # profiler session is left open for the next call.
        if profiling:
            jax.profiler.stop_trace()

    logger.info("Inference time (in seconds): %f", elapsed)
    logger.info("Shape of the predictions: %s", images.shape)
    # Merge the two leading axes into a single batch axis, keeping the last
    # three axes of each image unchanged.
    images = images.reshape(
        (images.shape[0] * images.shape[1],) + images.shape[-3:])
    logger.info("Shape of images afterwards: %s", images.shape)
    return self._pipeline.numpy_to_pil(np.array(images))