config/default.py (74 lines of code) (raw):

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Default Configuration for Creative Studio
"""

import os
from dataclasses import dataclass, field

from vertexai.generative_models import HarmBlockThreshold

from models.image_models import ImageModel


@dataclass
class GeminiModelConfig:
    """Configuration specific to Gemini models.

    Attributes:
        generation: Generation parameters keyed by parameter name
            (``Config.__post_init__`` fills temperature, top_p, top_k, ...).
        safety_settings: Harm-category name -> ``HarmBlockThreshold`` value.
        tools: Tool settings; keys are shown with a ``tools_`` prefix in the
            repr.
        grounding_source: Optional grounding source; when truthy, the repr
            reports ``grounding=ON``.
    """

    generation: dict = field(default_factory=dict)
    safety_settings: dict = field(default_factory=dict)
    tools: dict = field(default_factory=dict)
    grounding_source: object = None

    def __repr__(self):
        """Return a compact listing of every configured setting.

        Each entry is rendered as ``<group>_<key>=<value>`` (groups:
        ``generation``, ``safety``, ``tools``) so keys from different groups
        cannot be confused, followed by ``grounding=ON`` when a grounding
        source is set.
        """
        params = [f"generation_{k}={v}" for k, v in self.generation.items()]
        params += [f"safety_{k}={v}" for k, v in self.safety_settings.items()]
        params += [f"tools_{k}={v}" for k, v in self.tools.items()]
        if self.grounding_source:
            params.append("grounding=ON")
        # Derive the name from the actual class so the repr identifies
        # GeminiModelConfig (or a subclass) rather than a hard-coded label.
        return f"{type(self).__name__}({', '.join(params)})"


@dataclass
class Config:
    """All configuration variables for this solution should be managed here.

    The upper-case names below are intentionally left un-annotated: a
    dataclass only turns annotated names into fields, so these remain plain
    class attributes. They can be read as ``Config.NAME`` without an instance
    and are not ``__init__`` parameters. Environment-backed values are
    evaluated once, when this module is imported.
    """

    TITLE = "IMAGEN CREATIVE STUDIO"

    # Deployment settings from the environment. The bucket and project
    # default to an empty string when unset; LOCATION defaults to us-central1.
    IMAGE_CREATION_BUCKET = os.environ.get("IMAGE_CREATION_BUCKET", "")
    PROJECT_ID = os.environ.get("PROJECT_ID", "")
    LOCATION = os.environ.get("LOCATION", "us-central1")

    # Model identifiers.
    MODEL_GEMINI_MULTIMODAL = "gemini-1.5-flash"
    MODEL_IMAGEN2 = "imagegeneration@006"
    MODEL_IMAGEN_NANO = "imagegeneration@004"
    MODEL_IMAGEN3_FAST = "imagen-3.0-fast-generate-001"
    MODEL_IMAGEN3 = "imagen-3.0-generate-001"

    # Gemini generation defaults, copied into gemini_settings by
    # __post_init__.
    TEMPERATURE = 0.8
    TOP_P = 0.97
    TOP_K = 40
    MAX_OUTPUT_TOKENS = 2048

    IMAGEN_PROMPTS_JSON = "prompts/imagen_prompts.json"

    # --- Dataclass fields (default_factory gives each instance its own copy).
    image_modifiers: list[str] = field(
        default_factory=lambda: [
            "aspect_ratio",
            "content_type",
            "color_tone",
            "lighting",
            "composition",
        ]
    )
    # Populated by __post_init__, so it is not accepted by __init__.
    gemini_settings: GeminiModelConfig = field(
        default_factory=GeminiModelConfig, init=False
    )
    # NOTE(review): annotated as list[ImageModel] but populated with plain
    # dicts; presumably ImageModel is a dict-shaped type (e.g. a TypedDict).
    # Confirm against models.image_models.
    display_image_models: list[ImageModel] = field(
        default_factory=lambda: [
            {"display": "Imagen 3 Fast", "model_name": Config.MODEL_IMAGEN3_FAST},
            {"display": "Imagen 3", "model_name": Config.MODEL_IMAGEN3},
        ]
    )

    def __post_init__(self):
        """Fill gemini_settings from the class-level generation defaults.

        Also sets every harm category to ``HarmBlockThreshold.BLOCK_ONLY_HIGH``.
        """
        # Keyword order is preserved, so the dict (and its repr) keeps the
        # same key order as individual assignments would.
        self.gemini_settings.generation.update(
            temperature=self.TEMPERATURE,
            top_p=self.TOP_P,
            top_k=self.TOP_K,
            max_output_tokens=self.MAX_OUTPUT_TOKENS,
            candidate_count=1,
            stop_sequences=[],
        )
        # Same threshold for each category, applied in a fixed order (the
        # order is visible in GeminiModelConfig's repr).
        for category in (
            "HARASSMENT",
            "HATE_SPEECH",
            "SEXUALLY_EXPLICIT",
            "DANGEROUS_CONTENT",
        ):
            self.gemini_settings.safety_settings[category] = (
                HarmBlockThreshold.BLOCK_ONLY_HIGH
            )