experiments/arena/pages/settings.py (97 lines of code) (raw):

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Settings page: lists the available studies and lets the user switch the
active study or reset its ELO leaderboard."""

import asyncio
import html
from typing import Any

import mesop as me
from google.cloud.firestore import AsyncClient, FieldFilter

from components.header import header
from components.page_scaffold import (
    page_frame,
    page_scaffold,
)
from config.default import Default
from config.firebase_config import FirebaseClient

cnfg = Default()
# Synchronous Firestore client (iterated with a plain ``for`` in _get_studies).
db = FirebaseClient(cnfg.IMAGE_FIREBASE_DB).get_client()

# Model names listed for a study that defines no "models" entry.
# NOTE(review): presumably mirrors the app's default model set — confirm.
_DEFAULT_MODELS = (
    "imagegeneration@006",
    "imagen-3.0-generate-002",
    "imagen-3.0-fast-generate-001",
    "black-forest-labs/flux1-schnell",
    "stability-ai/stable-diffusion-2-1",
)

_BOX_STYLE = me.Style(
    flex_basis="max(480px, calc(50% - 48px))",
    background=me.theme_var("background"),
    border_radius=12,
    box_shadow=("0 3px 1px -2px #0003, 0 2px 2px #00000024, 0 1px 5px #0000001f"),
    padding=me.Padding(top=16, left=16, right=16, bottom=16),
    display="flex",
    flex_direction="column",
)


def settings_page_content(app_state: me.state):
    """Render the Settings page.

    Shows the welcome message, one card per available study, and the
    configured vote pause time.

    Args:
        app_state: Mesop app state; ``welcome_message`` is displayed and the
            study fields are read/updated by ``_render_study_info``.
    """
    with page_scaffold():  # pylint: disable=not-context-manager
        with page_frame():  # pylint: disable=not-context-manager
            header("Settings", "settings")
            me.text(app_state.welcome_message, style=me.Style(font_style="italic"))
            me.box(style=me.Style(height=16))
            _render_study_info(_get_studies(), app_state)
            me.box(style=me.Style(height=16))
            # Read from the shared config instance, like every other setting here.
            me.text(f"Vote pause time: {cnfg.SHOW_RESULTS_PAUSE_TIME} seconds")


async def _purge_elo_ratings(study: str, batch_size: int = 1000) -> bool:
    """Delete every ``arena_elo`` document whose ``study`` field equals ``study``.

    Deletes are grouped into write batches of at most ``batch_size``
    operations, and all batches are committed concurrently via
    ``asyncio.gather``.

    Args:
        study: Study label whose ELO rating documents should be removed.
        batch_size: Maximum number of delete operations per write batch.
            NOTE(review): Firestore has historically capped a batched write at
            500 operations; confirm the default of 1000 is accepted.

    Returns:
        True if every commit result is truthy, and also True when no documents
        matched (nothing was committed).

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be positive")

    # A dedicated async client is needed for ``async for`` streaming; it is
    # named distinctly so it does not shadow the module-level ``db``.
    async_db = AsyncClient(project=cnfg.PROJECT_ID, database=cnfg.IMAGE_FIREBASE_DB)
    query = async_db.collection("arena_elo").where(
        filter=FieldFilter("study", "==", study)
    )

    commits = []
    batch = async_db.batch()
    pending = 0
    async for doc in query.stream():
        batch.delete(doc.reference)
        pending += 1
        if pending == batch_size:
            # Schedule the full batch and start a fresh one.
            commits.append(batch.commit(timeout=60))
            batch = async_db.batch()
            pending = 0

    # Flush the partially filled final batch. An explicit count avoids
    # committing an empty batch.
    if pending:
        commits.append(batch.commit(timeout=60))

    results = await asyncio.gather(*commits)
    return all(results)


def _get_studies() -> dict[str, dict[str, Any]]:
    """Load all study configurations from Firestore, keyed by ``label``.

    A built-in ``"live"`` study (prompts from ``imagen_prompts.json``) is
    added last, so it replaces any Firestore study labelled ``"live"``.

    Returns:
        Mapping of study label to the study's document contents.
    """
    studies: dict[str, dict[str, Any]] = {}
    for doc in db.collection(cnfg.STUDY_COLLECTION_NAME).stream():
        doc_content = doc.to_dict()
        studies[doc_content["label"]] = doc_content
    studies["live"] = {"label": "live", "gcsuri": "imagen_prompts.json"}
    return studies


def _render_study_info(studies: dict[str, dict[str, Any]], app_state: me.state):
    """Render one card per study plus a button that resets the current study's leaderboard.

    Args:
        studies: Study configurations keyed by label (see ``_get_studies``).
        app_state: Mesop app state; ``study``, ``study_prompts_location`` and
            ``study_models`` are updated when a study is activated.
    """

    def _handle_select(event: me.ClickEvent):
        # Each "Activate" button is keyed by its study label.
        selected = studies[event.key]
        app_state.study = event.key
        app_state.study_prompts_location = selected["gcsuri"]
        app_state.study_models = selected.get("models", [])

    def _handle_purge(event: me.ClickEvent):
        # The reset button is keyed by the current study label.
        asyncio.run(_purge_elo_ratings(study=event.key))

    if not studies:
        me.markdown("No Studies found")
        return

    me.text("Available Studies", type="headline-5")
    for label, study_info in studies.items():
        with me.box(style=_BOX_STYLE):
            status = "Current" if app_state.study == label else "Available"
            me.text(
                f"{status} Study: {study_info['label']}",
                style=me.Style(font_weight="bold"),
            )
            me.box(style=me.Style(height=8))

            # Model names come from Firestore, so they are escaped before
            # being embedded in HTML.
            models = study_info.get("models") or _DEFAULT_MODELS
            model_list_items = "".join(
                f"<li>{html.escape(str(model))}</li>" for model in models
            )
            me.html(f"Models <ul>{model_list_items}</ul>")
            me.text(f"Prompt list: {study_info['gcsuri']}")

            if app_state.study != label:
                # Mesop passes the ClickEvent to the handler, which reads the
                # study from ``event.key``; no wrapper lambda is needed.
                me.button(label="Activate", on_click=_handle_select, key=label)
        me.box(style=me.Style(height=16))

    me.box(style=me.Style(height=16))
    me.button(
        label="Reset Leaderboard for current study",
        on_click=_handle_purge,
        key=f"{app_state.study}",
    )