server/main.py (110 lines of code) (raw):

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Flask service that generates images with Vertex AI and captions them.

The ``/imagen`` endpoint generates images from a text prompt using the
configured image generation model, then asks the caption model to produce a
title and a piece of information for each image.
"""

import base64
import html
import json
import os

import vertexai
import vertexai.preview.generative_models as generative_models
from flask import Flask, jsonify, request
from vertexai.preview.vision_models import ImageGenerationModel
from vertexai.generative_models import GenerativeModel, Part, FinishReason

app = Flask(__name__)

# Upper bound on images per request accepted by this service.
MAX_IMAGE_COUNT = 5
# Upper bound on images requested from the image model in a single call;
# larger requests are split into several calls.
VERTEX_MAX_IMAGE_COUNT = 4
PROJECT_ID = "imagenio"
LOCATION = "us-central1"

# Header attached to every /imagen response so browser clients on any origin
# can read it.
_CORS_HEADERS = {"Access-Control-Allow-Origin": "*"}

vertexai.init(project=PROJECT_ID, location=LOCATION)
image_model = ImageGenerationModel.from_pretrained("imagegeneration@006")
caption_model = GenerativeModel("gemini-1.5-pro-preview-0409")

caption_generation_config = {
    "max_output_tokens": 8192,
    "temperature": 1,
    "top_p": 0.95,
}

safety_settings = {
    generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH:
        generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
    generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT:
        generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
    generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT:
        generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
    generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT:
        generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
}


@app.route("/imagen", methods=["GET", "POST", "OPTIONS"])
def get_image():
    """Generate images for a prompt and return them with titles and captions.

    Parameters are read from the JSON body when it is a non-empty object,
    otherwise from the query string:

    * ``image_prompt``: prompt passed to the image model.
    * ``desc_prompt``: instruction applied to each generated image.
    * ``image_count``: number of images, between 1 and ``MAX_IMAGE_COUNT``.

    Returns:
        A JSON list of ``{"image", "caption", "title"}`` objects where
        ``image`` is the base64-encoded image. An invalid ``image_count``
        yields a plain-text 406 response; any failure while generating or
        captioning yields a JSON ``{"error": ...}`` body with status 500.
    """
    if request.method == "OPTIONS":
        # CORS preflight: allow GET and POST requests from any origin with
        # the Content-Type header, and let the client cache the preflight
        # response for 3600s.
        headers = {
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Methods": "GET, POST",
            "Access-Control-Allow-Headers": "Content-Type",
            "Access-Control-Max-Age": "3600",
        }
        return ("", 204, headers)

    request_json = request.get_json(silent=True)
    # Only a non-empty JSON object can carry parameters; anything else
    # (no body, empty object, list, scalar) falls back to the query string.
    if isinstance(request_json, dict) and request_json:
        params = request_json
    else:
        params = request.args

    default_image_prompt = 'a picture of a cute cat jumping'
    default_description_prompt = 'describe the image'
    default_image_count = 1

    image_prompt = params.get('image_prompt', default_image_prompt)
    # A JSON body may carry a non-string value; html.escape needs a str.
    input_prompt = str(params.get('desc_prompt', default_description_prompt))
    text_prompt = (
        f'Do this for each image separately: "{html.escape(input_prompt)}". '
        "We will call the result of it as the information about an image. "
        "Give each image a title. "
        "Return the result as a list of objects in json format; "
        "each object will correspond one image and the fields for the object "
        'will be "title" for the title and "info" for the information.'
    )

    # The 406 status is kept for all image_count errors so existing clients
    # that check for it keep working.
    try:
        image_count = int(params.get('image_count', default_image_count))
    except (TypeError, ValueError):
        return (
            "Invalid image_count. It must be an integer.",
            406,
            _CORS_HEADERS,
        )
    if image_count > MAX_IMAGE_COUNT:
        return (
            f"Invalid image_count. Maximum image count is {MAX_IMAGE_COUNT}.",
            406,
            _CORS_HEADERS,
        )
    if image_count < 1:
        return (
            "Invalid image_count. Minimum image count is 1.",
            406,
            _CORS_HEADERS,
        )

    try:
        images = get_images_with_count(image_prompt, image_count)
        if not images:
            # Nothing came back from the image model; captioning an empty
            # set would only produce a meaningless answer.
            return (
                jsonify({"error": "No images were generated for the prompt."}),
                500,
                _CORS_HEADERS,
            )

        image_strings = []
        caption_input = []
        for img in images:
            image_bytes = img._image_bytes
            image_strings.append(base64.b64encode(image_bytes).decode("ascii"))
            caption_input.append(
                Part.from_data(mime_type="image/png", data=image_bytes)
            )

        captions = caption_model.generate_content(
            caption_input + [text_prompt],
            generation_config=caption_generation_config,
            safety_settings=safety_settings,
        )
        captions_list = make_captions(captions)
    except Exception as error:
        # Top-level boundary: record the traceback server-side and report
        # the failure to the client instead of crashing the request.
        app.logger.exception("Image generation or captioning failed")
        return (jsonify({"error": str(error)}), 500, _CORS_HEADERS)

    # zip() stops at the shorter sequence, so images without a matching
    # caption entry are left out of the response.
    resp_images_dict = [
        {"image": img, "caption": cap["description"], "title": cap["title"]}
        for img, cap in zip(image_strings, captions_list)
    ]
    resp = jsonify(resp_images_dict)
    resp.headers.set("Access-Control-Allow-Origin", "*")
    return resp


def get_images_with_count(image_prompt, image_count):
    """Generate up to ``image_count`` images for ``image_prompt``.

    The image model is called repeatedly, asking for at most
    ``VERTEX_MAX_IMAGE_COUNT`` images per call, until enough images have been
    collected. If a call yields no images the loop stops, so the returned
    list may be shorter than requested (possibly empty).

    Args:
        image_prompt: Prompt passed to the image model.
        image_count: Total number of images wanted.

    Returns:
        A list of the generated image objects.
    """
    images = []
    while len(images) < image_count:
        remaining_image_count = image_count - len(images)
        allowed_image_count = min(VERTEX_MAX_IMAGE_COUNT, remaining_image_count)
        batch = list(
            image_model.generate_images(
                prompt=image_prompt,
                # Optional parameters
                number_of_images=allowed_image_count,
                language="en",
                # You can't use a seed value and watermark at the same time.
                # add_watermark=False,
                # seed=100,
                aspect_ratio="1:1",
                safety_filter_level="block_some",
                person_generation="allow_adult",
            )
        )
        if not batch:
            # An empty batch makes no progress; retrying forever would spin
            # (and keep calling the API) without ever reaching image_count.
            print('Image model returned no images; stopping early')
            break
        images.extend(batch)
        print(f'Images generated so far: {len(images)}')
    return images


def _strip_code_fence(text):
    """Return ``text`` without a surrounding Markdown code fence, if any.

    Handles an opening fence with or without a ``json`` language tag and
    tolerates leading/trailing whitespace around the fenced payload.
    """
    text = text.strip()
    if text.startswith("```"):
        text = text[3:]
        if text.endswith("```"):
            text = text[:-3]
        # A JSON document cannot start with the letters "json", so a leading
        # "json" here can only be the fence's language tag.
        if text[:4].lower() == "json":
            text = text[4:]
    return text.strip()


def make_captions(captions):
    """Parse the caption model response into title/description dicts.

    Args:
        captions: Model response whose ``text`` attribute holds a JSON list
            of objects with ``"title"`` and ``"info"`` fields, optionally
            wrapped in a Markdown code fence.

    Returns:
        A list of ``{"title": ..., "description": ...}`` dicts, one per
        object in the response.

    Raises:
        json.JSONDecodeError: If the response text is not valid JSON.
        KeyError: If an object lacks ``"title"`` or ``"info"``.
    """
    captions_text = captions.text
    # Sometimes the result is returned with a json field specifier
    print(captions_text)
    captions_list = json.loads(_strip_code_fence(captions_text))
    return [
        {"title": caption["title"], "description": caption["info"]}
        for caption in captions_list
    ]


@app.route("/")
def hello_world():
    """Example Hello World route."""
    name = os.environ.get("NAME", "World")
    return f"Hello {name}!"


if __name__ == "__main__":
    # The server listens on all interfaces, so the interactive debugger must
    # not be on by default; enable it explicitly with FLASK_DEBUG=1.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").lower() in ("1", "true")
    app.run(
        debug=debug_enabled,
        host="0.0.0.0",
        port=int(os.environ.get("PORT", 8080)),
    )