experiments/arena/config/default.py (57 lines of code) (raw):
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Default Configuration for GenMedia Arena """
import json
import os
from dataclasses import asdict, dataclass, field
from dotenv import load_dotenv
load_dotenv(override=True)
@dataclass
class GeminiModelConfig:
"""Configuration specific to Gemini models"""
@dataclass
class Default:
"""All configuration variables for this application are managed here."""
# pylint: disable=invalid-name
PROJECT_ID: str = os.environ.get("PROJECT_ID")
LOCATION: str = os.environ.get("LOCATION", "us-central1")
MODEL_ID: str = os.environ.get("MODEL_ID", "gemini-2.0-flash")
INIT_VERTEX: bool = os.environ.get("INIT_VERTEX", "True").lower() in ("true", "1")
GENMEDIA_BUCKET: str = os.environ.get("GENMEDIA_BUCKET")
PUBLIC_BUCKET: bool = os.environ.get("PUBLIC_BUCKET", "False").lower() in ("true", "1")
SHOW_RESULTS_PAUSE_TIME: int = int(os.environ.get("SHOW_RESULTS_PAUSE_TIME", "1"))
IMAGE_FIREBASE_DB: str = os.environ.get("IMAGE_FIREBASE_DB")
IMAGE_COLLECTION_NAME = os.environ.get("IMAGE_COLLECTION_NAME")
STUDY_COLLECTION_NAME: str = os.environ.get("STUDY_COLLECTION_NAME", "arena_study")
IMAGE_RATINGS_COLLECTION_NAME: str = os.environ.get("IMAGE_RATINGS_COLLECTION_NAME", "arena_elo")
STABLE_DIFFUSION_DB_PROMPTS: str = os.environ.get("STABLE_DIFFUSION_DB_PROMPTS", "prompts/stable_diffusion_prompts.json")
DEFAULT_PROMPTS: str = os.environ.get("DEFAULT_PROMPTS", "prompts/imagen_prompts.json")
DEFAULT_STUDY_NAME: str = os.environ.get("DEFAULT_STUDY_NAME", "live")
ELO_K_FACTOR: int = int(os.environ.get("ELO_K_FACTOR", 32))
# image models
MODEL_IMAGEN2: str = "imagegeneration@006"
MODEL_IMAGEN3_FAST: str = "imagen-3.0-fast-generate-001"
MODEL_IMAGEN3: str = "imagen-3.0-generate-001"
MODEL_IMAGEN32: str = "imagen-3.0-generate-002"
MODEL_GEMINI2: str = "gemini-2.0-flash"
# model garden image models
MODEL_FLUX1: str = "black-forest-labs/flux1-schnell"
MODEL_FLUX1_ENDPOINT_ID: str = os.environ.get("MODEL_FLUX1_ENDPOINT_ID")
MODEL_STABLE_DIFFUSION: str = "stability-ai/stable-diffusion-2-1"
MODEL_STABLE_DIFFUSION_ENDPOINT_ID: str = os.environ.get("MODEL_STABLE_DIFFUSION_ENDPOINT_ID")
# Spanner related variables
SPANNER_INSTANCE_ID: str = os.environ.get("SPANNER_INSTANCE_ID", "arena")
SPANNER_DATABASE_ID: str = os.environ.get("SPANNER_DATABASE_ID", "study")
SPANNER_TIMEOUT: int = int(os.environ.get("SPANNER_TIMEOUT", 300)) # seconds
def __post_init__(self):
"""Validates the configuration variables after initialization."""
if not self.PROJECT_ID:
raise ValueError("PROJECT_ID environment variable is not set.")
if not self.GENMEDIA_BUCKET:
raise ValueError("GENMEDIA_BUCKET environment variable is not set.")
if not self.MODEL_FLUX1_ENDPOINT_ID:
print("MODEL_FLUX1_ENDPOINT_ID environment variable is not set. List of models will exclude flux1") # Optional: List of models will exclude flux1
if not self.MODEL_STABLE_DIFFUSION_ENDPOINT_ID:
print("MODEL_STABLE_DIFFUSION_ENDPOINT_ID environment variable is not set. List of models will exclude stable diffusion")
if self.ELO_K_FACTOR <= 0:
raise ValueError("ELO_K_FACTOR must be a positive integer.")
if not self.IMAGE_FIREBASE_DB:
raise ValueError("IMAGE_FIREBASE_DB environment variable is not set. Default will be used")
if not self.IMAGE_COLLECTION_NAME:
raise ValueError("IMAGE_COLLECTION_NAME environment variable is not set.")
valid_locations = ["us-central1", "us-east4", "europe-west4", "asia-east1"] # example locations
if self.LOCATION not in valid_locations:
print(f"Warning: LOCATION {self.LOCATION} may not be valid.")
print("Configuration validated successfully.")
def __repr__(self):
return f"Default({json.dumps(asdict(self), indent=4)})"
# pylint: disable=invalid-name