def send_request()

in tutorials-and-examples/tpu-examples/single-host-inference/jax/stable-diffusion/stable_diffusion_request.py [0:0]


def send_request(server_ip, prompt="Painting of a squirrel skating in New York"):
  logging.info("Establish the gRPC connection with the model server.")
  _PREDICTION_SERVICE_HOST = str(server_ip)
  _GRPC_PORT = 8500
  options = [
      ("grpc.max_send_message_length", 512 * 1024 * 1024),
      ("grpc.max_receive_message_length", 512 * 1024 * 1024),
  ]
  channel = grpc.insecure_channel(
      f"{_PREDICTION_SERVICE_HOST}:{_GRPC_PORT}", options=options
  )
  stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

  tokenizer = AutoTokenizer.from_pretrained(
      "CompVis/stable-diffusion-v1-4", subfolder="tokenizer", revision="bf16"
  )
  logging.info(f'The prompt is "{prompt}".')
  logging.info("Tokenize the prompt.")
  inputs = dict()
  inputs["prompt_ids"] = tokenizer(
      prompt,
      padding="max_length",
      max_length=tokenizer.model_max_length,
      truncation=True,
      return_tensors="tf",
  ).input_ids

  request = predict_pb2.PredictRequest()
  request.model_spec.name = "stable_diffusion"
  request.model_spec.signature_name = (
      tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
  )
  for key, val in inputs.items():
    request.inputs[key].MergeFrom(tf.make_tensor_proto(val))
  logging.info("Send the request to the model server.")
  res = stub.Predict(request)
  logging.info("Predict completed.")
  outputs = {
      name: tf.io.parse_tensor(serialized.SerializeToString(), serialized.dtype)
      for name, serialized in res.outputs.items()
  }
  image = outputs["output_0"].numpy()
  image = image.reshape(image.shape[1:])
  image = (image * 255).round().astype("uint8")
  pil_image = Image.fromarray(image)
  image_file = "stable_diffusion_images.jpg"
  pil_image = pil_image.save(image_file)
  logging.info(f'The image was saved as "{image_file}"')