in community-content/vertex_model_garden/model_oss/keras/serve.py [0:0]
def predict_model() -> Response:
    """Serves text-to-image predictions (POST) or a liveness check (otherwise).

    For POST, the body is parsed as JSON regardless of Content-Type
    (``force=True``). The fields read by this handler are::

        {
          "instances": [{"prompt": ...}, ...],
          "parameters": {"batch_size": ..., "num_steps": ..., "seed": ...}
        }

    ``parameters`` and each of its keys are optional; the defaults are
    batch_size=1, num_steps=25, seed=1234. Every value is converted with
    ``int()``. The same batch_size/num_steps/seed are used for every instance.

    Returns:
      For POST: a JSON document ``{"success": true, "predictions": [...]}``.
      Each prediction holds the instance's ``prompt`` and ``predicted_image``,
      a list of ``batch_size`` base64-encoded JPEG strings. Invalid input
      produces the ``error(...)`` payload instead. As in the rest of this
      handler, that payload is served with HTTP 200 and ``text/plain``.
      For any other method: a JSON liveness document
      ``{"success": true, "isalive": <model is loaded>}``.
    """
    if request.method != 'POST':
        # Liveness check: report whether the global model object is set.
        return Response(
            response=json.dumps({
                'success': True,
                'isalive': stable_diffusion_model is not None,
            }),
            status=200,
            mimetype='text/plain',
        )

    contents = request.get_json(force=True)
    print('The input contents are:', contents)
    if not isinstance(contents, dict):
        return _predict_error_response('The request body must be a JSON object.')

    # `"parameters": null` is treated the same as an absent "parameters" key.
    parameters = contents.get('parameters') or {}
    if not isinstance(parameters, dict):
        return _predict_error_response('The parameters must be a JSON object.')
    try:
        batch_size = int(parameters.get('batch_size', 1))
        num_steps = int(parameters.get('num_steps', 25))
        seed = int(parameters.get('seed', 1234))
    except (TypeError, ValueError):
        return _predict_error_response(
            'The batch_size, num_steps and seed parameters must be integers.'
        )
    print('batch_size=', batch_size, 'num_steps=', num_steps, 'seed=', seed)
    if batch_size < 1:
        return _predict_error_response(
            'The batch size must be a positive integer.'
        )
    if num_steps < 1:
        return _predict_error_response('The num steps must be a positive integer.')

    # Validate every instance up front. A malformed instance then fails fast,
    # instead of after expensive diffusion runs on the instances before it.
    instances = contents.get('instances')
    if not isinstance(instances, list):
        return _predict_error_response(
            'The request must contain a list of instances.'
        )
    for content in instances:
        if not isinstance(content, dict) or 'prompt' not in content:
            return _predict_error_response('Each instance must contain a prompt.')

    predictions = []
    for content in instances:
        print('Processing:', content)
        prompt = content['prompt']
        generated_image_array = stable_diffusion_model.text_to_image(
            prompt=prompt,
            batch_size=batch_size,
            num_steps=num_steps,
            seed=seed,
        )
        predictions.append({
            'prompt': prompt,
            'predicted_image': [
                _encode_image_as_base64_jpeg(generated_image_array[i])
                for i in range(batch_size)
            ],
        })
    return Response(
        response=json.dumps({
            'success': True,
            'predictions': predictions,
        }),
        status=200,
        mimetype='text/plain',
    )


def _predict_error_response(message: str) -> Response:
    """Wraps `message` in the file's `error(...)` payload as a text/plain response."""
    # HTTP 200 (not 4xx) keeps the status code the original handler used for
    # validation errors, so existing clients see the same contract.
    # NOTE(review): consider a 4xx status if clients can tolerate it.
    return Response(
        response=error(message),
        status=200,
        mimetype='text/plain',
    )


def _encode_image_as_base64_jpeg(image_array) -> str:
    """Encodes one image array as a JPEG and returns it as base64 text."""
    # Assumes `image_array` is something PIL's Image.fromarray accepts as an
    # RGB image (e.g. HxWx3 uint8) as returned by text_to_image. TODO confirm.
    with io.BytesIO() as buffered_image:
        Image.fromarray(image_array).save(buffered_image, format='JPEG')
        return base64.b64encode(buffered_image.getvalue()).decode('utf-8')