experiments/arena/common/metadata.py (210 lines of code) (raw):

# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import datetime import json import os from typing import Optional, Dict, Any, List import pandas as pd from google.cloud import firestore from config.default import Default from config.firebase_config import FirebaseClient from config.spanner_config import ArenaStudyTracker, ArenaModelEvaluation from models.set_up import ModelSetup from common.storage import check_gcs_blob_exists from alive_progress import alive_bar from utils.logger import LogLevel, log # Initialize configuration client, model_id = ModelSetup.init() MODEL_ID = model_id config = Default() db = FirebaseClient(database_id=config.IMAGE_FIREBASE_DB).get_client() def add_image_metadata(gcsuri: str, prompt: str, model: str, study: Optional[str] = "live", collection_name: Optional[str] = None): """Add Image metadata to Firestore persistence""" if collection_name is None: collection_name = config.IMAGE_COLLECTION_NAME print(f"Using Firestore collection: {collection_name}") current_datetime = datetime.datetime.now() # Store the image metadata in Firestore doc_ref = db.collection(collection_name).document() try: doc_ref.set( { "gcsuri": gcsuri, "study": study, "prompt": prompt, "model": model, "timestamp": current_datetime, # alt: firestore.SERVER_TIMESTAMP } ) except Exception as e: print(f"Error storing image metadata: {e}") return print(f"Image data stored in Firestore with document ID: {doc_ref.id}") def load_metadata_from_json( collection_name: str, json_file_path: str, top_level_key: str, gcs_sub_folder: str, model_name: str, key_mapping: Optional[Dict[Any, str]] = None, ) -> None: """ Loads metadata from a JSON file and adds it to Firestore using add_image_metadata, with a progress bar. Args: collection_name: The name of the Firestore collection to store metadata in. json_file_path: Path to the JSON file containing the metadata. top_level_key: The key in the JSON that contains the list of metadata entries. gcs_sub_folder: The sub-folder within the GCS bucket where the images are located. model_name: The model to associate with the metadata entries. key_mapping: An optional dictionary to map keys (or indices) from the JSON structure to the expected arguments of `add_image_metadata`. For the given example, it would be `{0: 'prompt', 1: 'images'}`. The value associated with 'images' is expected to be a list of image identifiers. """ if key_mapping is None: key_mapping = {0: "prompt", 1: "images"} # Validate: Ensure the file exists if not os.path.exists(json_file_path): raise FileNotFoundError(f"Metadata file not found: {json_file_path}") with open(json_file_path, "r", encoding="utf-8") as f: metadata = json.load(f) data_list = metadata.get(top_level_key, []) if not data_list: raise ValueError(f"No data found under the key '{top_level_key}' in the provided JSON file.") total_items = len(data_list) with alive_bar(total_items, title="Processing Metadata") as bar: for item in data_list: if not isinstance(item, (list, tuple)) or len(item) < 2: print(f"Skipping invalid item format: {item}. Expected a list or tuple with at least two elements.") bar() # Increment the progress bar continue prompt_key = key_mapping.get(0) images_key = key_mapping.get(1) if prompt_key is None or images_key is None: raise ValueError("Key mapping must include keys for both 'prompt' (typically index 0) and 'images' (typically index 1).") prompt = item[0] images = item[1] if prompt is None: print(f"Skipping item with missing prompt: {item}") bar() # Increment the progress bar continue if not isinstance(images, list) or not images: print(f"No images found for prompt: '{prompt}'. Skipping...") bar() # Increment the progress bar continue print(f"Processing prompt: 'Found {len(images)} potential {'image' if len(images) == 1 else 'images'}...") print(f"Images ID(s): {images}") print(f"Sub-folder: {gcs_sub_folder}") print(f"Model: {model_name}") selected_image = None for image_id in images: gcs_uri = f"gs://{Default.GENMEDIA_BUCKET}/{gcs_sub_folder}/{image_id}" if check_gcs_blob_exists(gcs_uri): print(f"Selected image: {image_id} exists in GCS.") selected_image = image_id selected_image_gcsuri = gcs_uri break else: print(f"No valid images found in GCS for prompt: '{prompt}'. Skipping...") bar() # Increment the progress bar continue print(f"Adding metadata for prompt: '{prompt}' with image URI: {selected_image_gcsuri}...") add_image_metadata(collection_name=collection_name, gcsuri=selected_image_gcsuri, prompt=prompt, model=model_name) bar() # Increment the progress bar def get_elo_ratings(study: str): """ Retrieve ELO ratings for models from Firestore """ # Fetch current ELO ratings from Firestore doc_ref = ( db.collection(config.IMAGE_RATINGS_COLLECTION_NAME) .where(filter=firestore.FieldFilter("study", "==", study)) .where(filter=firestore.FieldFilter("type", "==", "elo_rating")) .get() ) updated_ratings = {} if doc_ref: for doc in doc_ref: ratings = doc.to_dict().get("ratings", {}) updated_ratings.update(ratings) # Convert to DataFrame df = pd.DataFrame(list(updated_ratings.items()), columns=['Model', 'ELO Rating']) df = df.sort_values(by='ELO Rating', ascending=False) # Sort by rating df.reset_index(drop=True, inplace=True) # Reset index return df def update_elo_ratings(model1: str, model2: str, winner: str, images: list[str], prompt: str, study: str): """Update ELO ratings for models""" current_datetime = datetime.datetime.now() # Fetch current ELO ratings from Firestore doc_ref = ( db.collection(config.IMAGE_RATINGS_COLLECTION_NAME) .where(filter=firestore.FieldFilter("study", "==", study)) .where(filter=firestore.FieldFilter("type", "==", "elo_rating")) .get() ) updated_ratings = {} elo_rating_doc_id = None # Store the document ID if doc_ref: for doc in doc_ref: elo_rating_doc_id = doc.id # Get the document ID ratings = doc.to_dict().get("ratings", {}) updated_ratings.update(ratings) elo_model1 = updated_ratings.get(model1, 1000) # Default to 1000 if not found elo_model2 = updated_ratings.get(model2, 1000) # Calculate expected scores expected_model1 = 1 / (1 + 10 ** ((elo_model2 - elo_model1) / 400)) expected_model2 = 1 / (1 + 10 ** ((elo_model1 - elo_model2) / 400)) # Update ELO ratings based on the winner k_factor = config.ELO_K_FACTOR if winner == model1: elo_model1 = elo_model1 + k_factor * (1 - expected_model1) elo_model2 = elo_model2 + k_factor * (0 - expected_model2) elif winner == model2: elo_model1 = elo_model1 + k_factor * (0 - expected_model1) elo_model2 = elo_model2 + k_factor * (1 - expected_model2) updated_ratings[model1] = round(elo_model1, 2) updated_ratings[model2] = round(elo_model2, 2) print(f"Ratings: {updated_ratings}") # Store updated ELO ratings in Firestore if elo_rating_doc_id: # Check if the document ID was found doc_ref = db.collection(config.IMAGE_RATINGS_COLLECTION_NAME).document(elo_rating_doc_id) doc_ref.update( { "ratings": updated_ratings, "timestamp": current_datetime, } ) print(f"ELO ratings updated in Firestore with document ID: {doc_ref.id}") else: # Document doesn't exist, create it doc_ref = db.collection(config.IMAGE_RATINGS_COLLECTION_NAME).document() doc_ref.set( { "study": study, "type": "elo_rating", "ratings": updated_ratings, "timestamp": current_datetime, } ) print(f"ELO ratings created in Firestore with document ID: {doc_ref.id}") doc_ref = db.collection(config.IMAGE_RATINGS_COLLECTION_NAME).document() doc_ref.set( { "timestamp": current_datetime, "type": "vote", "model1": model1, "image1": images[0], "model2": model2, "image2": images[1], "winner": winner, "prompt": prompt, "study": study } ) print(f"Vote updated in Firestore with document ID: {doc_ref.id}") # Update the latest ELO ratings in Spanner study_tracker = ArenaStudyTracker( project_id=config.PROJECT_ID, spanner_instance_id=config.SPANNER_INSTANCE_ID, spanner_database_id=config.SPANNER_DATABASE_ID, ) if not study_tracker: log("Failed to initialize Spanner study tracker.", LogLevel.ERROR) raise RuntimeError("Spanner study tracker initialization failed.") elo_ratings_by_model = [] for model, elo in updated_ratings.items(): elo_study_entry = ArenaModelEvaluation(model_name=model, rating=elo, study=study) elo_ratings_by_model.append(elo_study_entry) try: study_tracker.upsert_study_runs(study_runs=elo_ratings_by_model) log(f"ELO ratings updated in Spanner for study '{study}'.", LogLevel.ON) except Exception as e: log(f"Failed to update ELO ratings in Spanner: {e}", LogLevel.ERROR) raise RuntimeError(f"Failed to update ELO ratings in Spanner: {e}") def get_latest_votes(study: str, limit: int = 10): """Retrieve the latest votes from Firestore, ordered by timestamp in descending order.""" try: votes_ref = ( db.collection(config.IMAGE_RATINGS_COLLECTION_NAME) .where(filter=firestore.FieldFilter("study", "==", study)) .where(filter=firestore.FieldFilter("type", "==", "vote")) .order_by("timestamp", direction=firestore.Query.DESCENDING) .limit(limit) ) votes = [] for doc in votes_ref.stream(): votes.append(doc.to_dict()) return votes except Exception as e: print(f"Error fetching votes: {e}") return []