experiments/arena/models/generate.py (188 lines of code) (raw):

# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Generate Images from models in Model Garden or Gemini """ import base64 import io import logging import time from typing import Any import uuid import random import os from PIL import Image from google.cloud import aiplatform from google.cloud.firestore import Client, FieldFilter import vertexai from vertexai.preview.vision_models import ImageGenerationModel from config.default import Default from config.firebase_config import FirebaseClient from common.storage import store_to_gcs from common.metadata import add_image_metadata config = Default() logging.basicConfig(level=logging.DEBUG) def base64_to_image(image_str: str) -> Any: """Convert base64 encoded string to an image. Args: image_str: A string of base64 encoded image. Returns: A PIL.Image instance. """ image = Image.open(io.BytesIO(base64.b64decode(image_str))) return image def generate_images_from_model_garden( prompt: str, endpoint_id: str, model_name: str, output_gcs_folder: str, parameters: dict[str, Any], project_id: str = config.PROJECT_ID, location: str = config.LOCATION, ) -> list[str]: """ Generates images using a specified endpoints of Model Garden deployed models. Args: prompt: The text prompt for image generation. endpoint_id: The Vertex AI Endpoint ID to use. model_name: A descriptive name for the model (for logging/metadata). output_gcs_folder: The subfolder within config.GENMEDIA_BUCKET to store results. parameters: A dictionary of parameters required by the specific model endpoint (e.g., height, width, num_inference_steps). project_id: Google Cloud Project ID. Defaults to config.PROJECT_ID. location: Google Cloud Location. Defaults to config.LOCATION. Returns: A list of GCS URIs for the generated images (e.g., ["gs://bucket/folder/uuid.png"]). Raises: ValueError: If required arguments are missing or invalid. # Re-raises exceptions from aiplatform.Endpoint.predict """ if not all([prompt, endpoint_id, model_name, output_gcs_folder, parameters]): raise ValueError("Missing one or more required arguments: prompt, endpoint_id, model_name, output_gcs_folder, parameters") if not isinstance(parameters, dict): raise ValueError("parameters must be a dictionary") logging.info(f"Generating image with endpoint model: {model_name}") logging.info(f"Prompt: '{prompt}'") logging.info(f"Endpoint ID: {endpoint_id}") logging.info(f"Parameters: {parameters}") logging.info(f"Target GCS Folder: gs://{config.GENMEDIA_BUCKET}/{output_gcs_folder}/") aiplatform.init(project=project_id, location=location) instances = [{"text": prompt}] endpoint_path = f"projects/{project_id}/locations/{location}/endpoints/{endpoint_id}" endpoint = aiplatform.Endpoint(endpoint_path) arena_output: list[str] = [] start_time = time.time() try: logging.info(f"Calling endpoint: {endpoint_path}") response = endpoint.predict( instances=instances, parameters=parameters, ) # logging.info(f"Received response from endpoint: {response}") if not response or not hasattr(response, 'predictions') or not response.predictions: logging.error("Received empty or invalid response from endpoint.") return [] image_outputs = [] for prediction in response.predictions: # Check common keys for base64 image data img_data = prediction.get("output") or prediction.get("bytesBase64Encoded") if img_data: image_outputs.append(img_data) else: logging.warning(f"Prediction missing expected image data key ('output' or 'bytesBase64Encoded'): {prediction}") if not image_outputs: logging.error("No valid image data found in any endpoint predictions.") return [] # Or raise an error except Exception as e: logging.error(f"Error calling Vertex AI endpoint {endpoint_path}: {e}", exc_info=True) raise end_time = time.time() elapsed_time = end_time - start_time logging.info(f"Endpoint call finished in {elapsed_time:.2f} seconds. Processing {len(image_outputs)} images.") for idx, img_base64 in enumerate(image_outputs): try: image_filename = f"{uuid.uuid4()}.png" gcs_path_suffix = store_to_gcs( folder=output_gcs_folder, file_name=image_filename, mime_type="image/png", contents=img_base64, decode=True ) # Construct full GCS URI gcs_uri = f"gs://{gcs_path_suffix}" logging.info( f"Generated image {idx+1}/{len(image_outputs)} with model {model_name}. " f"Stored at: {gcs_uri}" ) arena_output.append(gcs_uri) try: add_image_metadata(gcs_uri, prompt, model_name) logging.debug(f"Successfully added metadata for {gcs_uri}") except Exception as ex: if "DeadlineExceeded" in str(ex): logging.error(f"Firestore timeout adding metadata for {gcs_uri}: {ex}") else: logging.error(f"Error adding image metadata for {gcs_uri}: {ex}", exc_info=True) except Exception as ex: logging.error(f"Error processing or uploading image {idx+1} from {model_name}: {ex}", exc_info=True) # Continue with the next image logging.info(f"Finished endpoint processing for model {model_name}. Returning {len(arena_output)} GCS URIs.") return arena_output def images_from_flux(model_name: str, prompt: str, aspect_ratio: str) -> list[str]: """ Generates images using the configured Flux.1 model endpoint. Args: prompt: The text prompt. params_override: Optional dictionary to override default parameters. Returns: A list of GCS URIs for the generated images. """ _ = aspect_ratio # aspect ratio is not used in this function if not config.MODEL_FLUX1_ENDPOINT_ID: raise ValueError("config.MODEL_FLUX1_ENDPOINT_ID is not set.") default_params = { "height": 1024, "width": 1024, "num_inference_steps": 4, # Default for Flux } return generate_images_from_model_garden( prompt=prompt, endpoint_id=config.MODEL_FLUX1_ENDPOINT_ID, model_name=model_name, output_gcs_folder="flux1", parameters=default_params, ) def images_from_stable_diffusion(model_name: str, prompt: str, aspect_ratio: str) -> list[str]: """ Generates images using the configured Stable Diffusion model endpoint. *Adjust default_params based on your specific Stable Diffusion deployment.* Args: prompt: The text prompt. params_override: Optional dictionary to override default parameters. Returns: A list of GCS URIs for the generated images. """ _ = aspect_ratio # aspect ratio is not used in this function if not config.MODEL_STABLE_DIFFUSION_ENDPOINT_ID: raise ValueError("config.MODEL_STABLE_DIFFUSION_ENDPOINT_ID is not set.") default_params = { "height": 1024, "width": 1024, "num_inference_steps": 25, # Typically higher for SD "guidance_scale": 7.5, # Common SD parameter } return generate_images_from_model_garden( prompt=prompt, endpoint_id=config.MODEL_STABLE_DIFFUSION_ENDPOINT_ID, model_name=model_name, output_gcs_folder="stablediffusion", parameters=default_params, ) def images_from_imagen(model_name: str, prompt: str, aspect_ratio: str): """creates images from Imagen and returns a list of gcs uris Args: model_name (str): imagen model name prompt (str): prompt for t2i model aspect_ratio (str): aspect ratio string Returns: _type_: a list of strings (gcs uris of image output) """ start_time = time.time() arena_output = [] logging.info(f"model: {model_name}") logging.info(f"prompt: {prompt}") logging.info(f"target output: {config.GENMEDIA_BUCKET}") vertexai.init(project=config.PROJECT_ID, location=config.LOCATION) image_model = ImageGenerationModel.from_pretrained(model_name) response = image_model.generate_images( prompt=prompt, add_watermark=True, # aspect_ratio=getattr(state, "image_aspect_ratio"), aspect_ratio=aspect_ratio, number_of_images=1, output_gcs_uri=f"gs://{config.GENMEDIA_BUCKET}/imagen_live", language="auto", # negative_prompt=state.image_negative_prompt_input, safety_filter_level="block_few", # include_rai_reason=True, ) end_time = time.time() elapsed_time = end_time - start_time for idx, img in enumerate(response.images): logging.info(f"Generated image {idx} with model {model_name} in {elapsed_time:.2f} seconds") logging.info( f"Generated image: #{idx}, len {len(img._as_base64_string())} at {img._gcs_uri}" ) # output = img._as_base64_string() # state.image_output.append(output) arena_output.append(img._gcs_uri) logging.info(f"Image created: {img._gcs_uri}") try: add_image_metadata(img._gcs_uri, prompt, model_name) except Exception as e: if "DeadlineExceeded" in str(e): # Check for timeout error logging.error(f"Firestore timeout: {e}") else: logging.error(f"Error adding image metadata: {e}") return arena_output def study_fetch(model_name: str, prompt: str) -> list[str]: db: Client = FirebaseClient(database_id=config.IMAGE_FIREBASE_DB).get_client() collection_ref = db.collection(config.IMAGE_COLLECTION_NAME) print(f"Using: {model_name}") query = collection_ref.where(filter=FieldFilter("prompt", "==", prompt)).where(filter=FieldFilter("model", "==", model_name)).stream() docs = [] for doc in query: gs_uri = doc.to_dict()['gcsuri'] if "stablediffusion" not in gs_uri: docs.append(os.path.splitext(gs_uri)[0]) else: if gs_uri.startswith("20250328_"): docs.append(os.path.splitext(gs_uri)[0]) else: docs.append(gs_uri) return random.sample(docs, 1) if __name__ == "__main__": # Example usage prompt = "A futuristic city skyline at sunset" aspect_ratio = "16:9" model_name = config.MODEL_FLUX1 images = images_from_flux(model_name, prompt, aspect_ratio) print("Generated images:", images[0])