community-content/vertex_model_garden/model_oss/keras/serve.py (133 lines of code) (raw):

r"""Servers Keras Stable Diffusion models. python serve.py --model_path=<model path in gcs> curl -d \ '{"prompt":"Hello Kitty"}' \ -H "Content-Type: application/json" \ -X POST http://localhost:8501/predict """ import base64 import io import json import os from typing import List, Tuple from absl import app # The docker builds could not find flask and waitress. # pylint: disable=import-error from flask import Flask from flask import request from flask import Response import keras_cv from PIL import Image from waitress import serve from util import constants from util import fileutils flask_app = Flask(__name__) stable_diffusion_model = None model_path = os.environ.get('MODEL_PATH', '') if model_path.startswith(constants.GCS_URI_PREFIX): print('Downloading models from gcs to local.') os.makedirs(constants.LOCAL_MODEL_DIR, exist_ok=True) fileutils.download_gcs_dir_to_local( os.path.dirname(model_path), constants.LOCAL_MODEL_DIR ) model_path = os.path.join( constants.LOCAL_MODEL_DIR, os.path.basename(model_path) ) image_width = int(os.environ.get('IMAGE_WIDTH', 512)) image_height = int(os.environ.get('IMAGE_HEIGHT', 512)) print('image_width=', image_width, 'image_height=', image_height) print('Create Keras stable diffusion models.') stable_diffusion_model = keras_cv.models.StableDiffusion( img_width=image_width, img_height=image_height, jit_compile=True, ) if model_path: # We just reload the weights of the fine-tuned diffusion model. print('Initialize finetuned models from: ', model_path) stable_diffusion_model.diffusion_model.load_weights(model_path) def error(message: str) -> str: """Returns a JSON representing an error response.""" return json.dumps({ 'success': False, 'error': message, }) def check_key_in_json(content: str, keys: List[str]) -> str: for key in keys: if key not in content: return error('No {} in request {}.'.format(key, content)) return None def validate_json_key(json_key_string: str) -> Tuple[str, bool]: try: json_key = json.loads(json_key_string) except (ValueError, TypeError): return (error('Invalid key found in request'), False) return (json_key, True) # The health check route is required for docker deployment in google cloud. @flask_app.route('/ping') def ping() -> Response: """Health checks.""" return Response(status=200) # The return should be `Response` for docker deployment in google cloud. @flask_app.route('/predict', methods=['GET', 'POST']) def predict_model() -> Response: """Predictions.""" if request.method == 'POST': contents = request.get_json(force=True) print('The input contents are:', contents) batch_size = 1 num_steps = 25 seed = 1234 if 'parameters' in contents: parameters = contents['parameters'] if 'batch_size' in parameters: batch_size = int(parameters['batch_size']) if 'num_steps' in parameters: num_steps = int(parameters['num_steps']) if 'seed' in parameters: seed = int(parameters['seed']) print('batch_size=', batch_size, 'num_steps=', num_steps, 'seed=', seed) if batch_size < 1: return Response( response=error('The batch size must be a positive integar.'), status=200, mimetype='text/plain', ) if num_steps < 1: return Response( response=error('The num steps must be a positive integar.'), status=200, mimetype='text/plain', ) predictions = [] for content in contents['instances']: print('Processing:', content) prompt = content['prompt'] generated_image_array = stable_diffusion_model.text_to_image( prompt=prompt, batch_size=batch_size, num_steps=num_steps, seed=seed, ) generated_image_bytes_array = [] for i in range(batch_size): generated_image = Image.fromarray(generated_image_array[i]) # Converts the image to a base64-encoded string. buffered_image = io.BytesIO() generated_image.save(buffered_image, format='JPEG') generated_image_bytes = base64.b64encode( buffered_image.getvalue() ).decode('utf-8') generated_image_bytes_array.append(generated_image_bytes) prediction = { 'prompt': prompt, 'predicted_image': generated_image_bytes_array, } predictions.append(prediction) return Response( response=json.dumps({ 'success': True, 'predictions': predictions, }), status=200, mimetype='text/plain', ) else: return Response( response=json.dumps({ 'success': True, 'isalive': stable_diffusion_model is not None, }), status=200, mimetype='text/plain', ) def serve_main(unused_argv): """The main function to serve Keras models.""" del unused_argv # This is used when running locally only. When deploying to Google App # Engine, a webserver process such as Gunicorn will serve the app. # # Debug deployment. # flask_app.run(host='0.0.0.0', port=8501, debug=True) # Prod deployment. serve(flask_app, host='0.0.0.0', port=8501) if __name__ == '__main__': app.run(serve_main)