experiments/arena/models/set_up.py (44 lines of code) (raw):

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import threading
from typing import Optional

from dotenv import load_dotenv
from google import genai

from config.default import Default

load_dotenv(override=True)
config = Default()


def load_default_models() -> list[str]:
    """Return the image generation models enabled by the configuration.

    The four Imagen entries are always included. The Flux.1 and Stable
    Diffusion entries are appended only when their endpoint ID setting is
    truthy.

    Returns:
        A new list of model identifiers taken from ``config``.
    """
    image_gen_models = [
        config.MODEL_IMAGEN2,
        config.MODEL_IMAGEN3_FAST,
        config.MODEL_IMAGEN3,
        config.MODEL_IMAGEN32,
    ]
    if config.MODEL_FLUX1_ENDPOINT_ID:
        image_gen_models.append(config.MODEL_FLUX1)
    if config.MODEL_STABLE_DIFFUSION_ENDPOINT_ID:
        image_gen_models.append(config.MODEL_STABLE_DIFFUSION)
    return image_gen_models


class ModelSetup:
    """Model set up class with caching and thread safety."""

    # Clients keyed by (project_id, location): the client constructor takes
    # no model argument, so one client can be shared by every model.
    _client_cache: dict[tuple[str, str], genai.Client] = {}
    # Guards every read and write of _client_cache.
    _lock = threading.Lock()

    @staticmethod
    def init(
        project_id: Optional[str] = None,
        location: Optional[str] = None,
        model_id: Optional[str] = None,
    ) -> tuple[genai.Client, str]:
        """Return a cached genai client together with the resolved model ID.

        Any argument that is falsy (``None`` or empty) is replaced by the
        matching value from ``config``.

        Args:
            project_id: Project to use; defaults to ``config.PROJECT_ID``.
            location: Location to use; defaults to ``config.LOCATION``.
            model_id: Model to use; defaults to ``config.MODEL_ID``.

        Returns:
            A ``(client, model_id)`` tuple. The same client object is
            returned for every call that resolves to the same project and
            location, regardless of ``model_id``.

        Raises:
            ValueError: If any value is still ``None`` after the defaults
                have been applied.
        """
        project_id = project_id or config.PROJECT_ID
        location = location or config.LOCATION
        model_id = model_id or config.MODEL_ID
        if any(value is None for value in (project_id, location, model_id)):
            raise ValueError("All parameters must be set.")

        # model_id is deliberately left out of the key: it is not passed to
        # genai.Client, so keying on it would only create duplicate clients.
        cache_key = (project_id, location)
        with ModelSetup._lock:
            client = ModelSetup._client_cache.get(cache_key)
            if client is None:
                print(f"Initiating genai client with {project_id} in {location} using model: {model_id}")
                client = genai.Client(
                    vertexai=config.INIT_VERTEX,
                    project=project_id,
                    location=location,
                )
                ModelSetup._client_cache[cache_key] = client
            else:
                print(f"Using cached genai client for {project_id} in {location} using model: {model_id}")
        return client, model_id