experiments/arena/pages/arena.py (411 lines of code) (raw):

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Arena page: shows two generated images for one prompt and records votes."""

from dataclasses import field
import random
import logging
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

import mesop as me

from common.metadata import update_elo_ratings
from config.default import Default
from prompts.utils import PromptManager
from state.state import AppState
from components.header import header
from models.set_up import ModelSetup, load_default_models
from models.gemini_model import (
    generate_content,
    generate_images,
)
from models.generate import (
    images_from_flux,
    images_from_imagen,
    images_from_stable_diffusion,
    study_fetch,
)


# Initialize configuration
client, model_id = ModelSetup.init()
MODEL_ID = model_id
config = Default()
prompt_manager = PromptManager()

logging.basicConfig(level=logging.DEBUG)

IMAGEN_MODELS = [
    config.MODEL_IMAGEN2,
    config.MODEL_IMAGEN3_FAST,
    config.MODEL_IMAGEN3,
    config.MODEL_IMAGEN32,
]
GEMINI_MODELS = [config.MODEL_GEMINI2]

WELCOME_PROMPT = """
Welcome the user to the battle of the generative media images, and encourage participation by asserting their voting on the images presented.
This should be one or two sentences.
"""

# Card style wrapping the image comparison area.
_BOX_STYLE = me.Style(
    flex_basis="max(480px, calc(50% - 48px))",
    background=me.theme_var("background"),
    border_radius=12,
    box_shadow=("0 3px 1px -2px #0003, 0 2px 2px #00000024, 0 1px 5px #0000001f"),
    padding=me.Padding(top=16, left=16, right=16, bottom=16),
    display="flex",
    flex_direction="column",
    width="100%",
)


@me.stateclass
class PageState:
    """Local Page State"""

    temp_name: str = ""
    is_loading: bool = False

    # pylint: disable=invalid-field-call
    arena_prompt: str = ""
    image_negative_prompt_input: str = ""
    image_aspect_ratio: str = "1:1"
    arena_textarea_key: int = 0
    arena_model1: str = ""
    arena_model2: str = ""
    # Image URIs for the current battle; position i belongs to arena_model{i+1}.
    arena_output: list[str] = field(default_factory=list)
    chosen_model: str = ""
    study: str = "live"
    study_models: list[str] = field(default_factory=list)
    # pylint: disable=invalid-field-call


def _submit_live_generation(executor, slot, model_name, prompt, aspect_ratio):
    """Submit the generation task that matches ``model_name``.

    Args:
        executor: Thread pool the task is submitted to.
        slot: 1-based position of the model in the battle (used for logging).
        model_name: Name of the model to generate with.
        prompt: Text prompt passed to the generator.
        aspect_ratio: Aspect ratio passed to generators that accept one.

    Returns:
        The submitted future, or ``None`` when nothing was submitted (missing
        endpoint configuration or unrecognised model).
    """
    if model_name in IMAGEN_MODELS:
        task = (images_from_imagen, model_name, prompt, aspect_ratio)
    elif model_name.startswith(config.MODEL_GEMINI2):
        task = (generate_images, prompt)
    elif model_name.startswith(config.MODEL_FLUX1):
        if not config.MODEL_FLUX1_ENDPOINT_ID:
            logging.error("no endpoint defined for %s", model_name)
            return None
        task = (images_from_flux, model_name, prompt, aspect_ratio)
    elif model_name.startswith(config.MODEL_STABLE_DIFFUSION):
        if not config.MODEL_STABLE_DIFFUSION_ENDPOINT_ID:
            logging.error("no endpoint defined for %s", model_name)
            return None
        task = (images_from_stable_diffusion, model_name, prompt, aspect_ratio)
    else:
        logging.warning("no image generator available for %s", model_name)
        return None

    logging.info("model %d: %s", slot, model_name)
    return executor.submit(*task)


def arena_images(input: str, study: str):  # pylint: disable=redefined-builtin
    """Create images for arena comparison.

    Clears ``PageState.arena_output`` and refills it with the images produced
    for ``arena_model1`` and ``arena_model2``. Both tasks run concurrently, but
    results are appended in submission order (model 1 first) because the page
    maps output position ``i`` to ``arena_model{i}``.

    Args:
        input: Prompt to generate from. When empty, the prompt already stored
            in the page state is used instead.
        study: ``"live"`` to generate with the models; any other value fetches
            images through ``study_fetch``.
    """
    state = me.state(PageState)

    # Handle the case where someone hits "random" but doesn't modify the prompt.
    prompt = input or state.arena_prompt

    state.arena_output.clear()
    logging.info("BATTLE: %s vs. %s", state.arena_model1, state.arena_model2)
    logging.info("prompt: %s", prompt)
    if state.image_negative_prompt_input:
        logging.info("negative prompt: %s", state.image_negative_prompt_input)

    contenders = (state.arena_model1, state.arena_model2)
    with ThreadPoolExecutor() as executor:
        if study == "live":
            submitted = [
                _submit_live_generation(
                    executor, slot, model_name, prompt, state.image_aspect_ratio
                )
                for slot, model_name in enumerate(contenders, start=1)
            ]
            futures = [future for future in submitted if future is not None]
        else:
            # Fetch images from study
            futures = [
                executor.submit(study_fetch, model_name, prompt)
                for model_name in contenders
            ]

        # Iterate in submission order, NOT completion order: the tasks still
        # run in parallel, but model 1's images always precede model 2's.
        for future in futures:
            try:
                state.arena_output.extend(future.result())
            except Exception as err:  # pylint: disable=broad-except
                # Best effort: one failing generator must not break the page.
                logging.exception("Error during image generation: %s", err)


def on_click_reload_arena(e: me.ClickEvent):  # pylint: disable=unused-argument
    """Reload arena handler: pick a new prompt and model pair, then regenerate."""
    state = me.state(PageState)
    if state.study == "live":
        state.study_models = load_default_models()

    state.arena_prompt = prompt_manager.random_prompt()
    state.arena_output.clear()
    state.is_loading = True
    yield

    logging.debug("Use %s", state.study_models)
    # get random images
    state.arena_model1, state.arena_model2 = random.sample(state.study_models, 2)
    logging.info("%s vs. %s", state.arena_model1, state.arena_model2)
    arena_images(state.arena_prompt, state.study)

    state.is_loading = False
    yield


def on_click_arena_vote(e: me.ClickEvent):
    """Arena vote handler.

    ``e.key`` is the name of the state attribute holding the preferred model
    (``"arena_model1"`` or ``"arena_model2"``). The vote is recorded, the
    result is shown for ``Default.SHOW_RESULTS_PAUSE_TIME`` seconds, and a new
    battle is then generated.
    """
    state = me.state(PageState)
    model_name = getattr(state, e.key)
    logging.info("user preferred %s: %s", e.key, model_name)
    state.chosen_model = model_name
    yield

    # update the elo ratings
    update_elo_ratings(
        state.arena_model1,
        state.arena_model2,
        model_name,
        state.arena_output,
        state.arena_prompt,
        state.study,
    )
    yield

    # Keep the highlighted result on screen for a moment before moving on.
    time.sleep(int(Default.SHOW_RESULTS_PAUSE_TIME))
    yield

    # clear the output and reload
    state.arena_output.clear()
    state.chosen_model = ""
    state.arena_prompt = prompt_manager.random_prompt()
    state.arena_model1, state.arena_model2 = random.sample(state.study_models, 2)
    state.is_loading = True  # show the spinner, as the reload handler does
    yield

    arena_images(state.arena_prompt, state.study)
    state.is_loading = False
    yield


def reload_welcome(e: me.ClickEvent):  # pylint: disable=unused-argument
    """Handle regeneration of welcome message event"""
    app_state = me.state(AppState)
    app_state.welcome_message = generate_welcome()
    yield


def generate_welcome() -> str:
    """Generate a nice welcome message with Gemini 2.0"""
    return generate_content(WELCOME_PROMPT)


def _image_style(chosen_model: str, model_value: str) -> me.Style:
    """Return the image style: plain before a vote, green border for the winner, dimmed otherwise."""
    if not chosen_model:
        return me.Style(
            width="450px",
            margin=me.Margin(top=10),
            border_radius="35px",
        )
    if chosen_model == model_value:
        # green border
        return me.Style(
            width="450px",
            margin=me.Margin(top=10),
            border_radius="35px",
            border=me.Border().all(
                me.BorderSide(color="green", style="inset", width="5px")
            ),
        )
    # opaque
    return me.Style(
        width="450px",
        margin=me.Margin(top=10),
        border_radius="35px",
        opacity=0.5,
    )


def _skip_button():
    """Render the "skip" button that reloads the arena."""
    me.button(
        label="skip",
        type="stroked",
        on_click=on_click_reload_arena,
    )


def arena_page_content(app_state: AppState):
    """Arena Mesop Page.

    Syncs the page state with ``app_state`` (study, models, prompt location),
    lazily creates the welcome message and the first battle, then renders the
    prompt, both images and the voting controls.
    """
    page_state = me.state(PageState)
    prompt_manager.prompts_location = app_state.study_prompts_location
    page_state.study = app_state.study
    if page_state.study == "live":
        app_state.study_models = load_default_models()
    page_state.study_models = app_state.study_models
    logging.debug(
        "======> Starting Page state study models: %s", page_state.study_models
    )

    # TODO this is an initialization function that should be extracted
    if not app_state.welcome_message:
        app_state.welcome_message = generate_welcome()
    if not page_state.arena_prompt:
        page_state.arena_prompt = prompt_manager.random_prompt()
        page_state.arena_model1, page_state.arena_model2 = random.sample(
            app_state.study_models, 2
        )
        arena_images(page_state.arena_prompt, app_state.study)

    title = "Arena" + (
        f" [Active Study: {app_state.study}]" if app_state.study != "live" else ""
    )
    # Host used to turn gs:// URIs into browser-loadable URLs.
    replace_url = "https://storage.mtls.cloud.google.com/"
    if Default.PUBLIC_BUCKET:
        replace_url = "https://storage.googleapis.com/"

    with me.box(
        style=me.Style(
            display="flex",
            flex_direction="column",
            height="100%",
        ),
    ):
        with me.box(
            style=me.Style(
                background=me.theme_var("background"),
                height="100%",
                overflow_y="scroll",
                margin=me.Margin(bottom=20),
            )
        ):
            with me.box(
                style=me.Style(
                    background=me.theme_var("background"),
                    padding=me.Padding(top=24, left=24, right=24, bottom=24),
                    display="flex",
                    flex_direction="column",
                )
            ):
                header(title, "stadium")

                # welcome message (click to regenerate)
                with me.box(
                    style=me.Style(
                        flex_grow=1,
                        display="flex",
                        align_items="center",
                        justify_content="center",
                    ),
                    on_click=reload_welcome,
                ):
                    me.text(
                        app_state.welcome_message,
                        style=me.Style(
                            width="80vw",
                            font_size="12pt",
                            font_style="italic",
                            color="gray",
                        ),
                    )

                me.box(style=me.Style(height="16px"))

                with me.box(
                    style=me.Style(
                        margin=me.Margin(left="auto", right="auto"),
                        width="min(1024px, 100%)",
                        gap="24px",
                        flex_grow=1,
                        display="flex",
                        flex_wrap="wrap",
                        flex_direction="column",
                        align_items="center",
                    )
                ):
                    # Prompt
                    with me.box(
                        style=me.Style(
                            display="flex",
                            flex_direction="column",
                            align_items="center",
                            width="85%",
                        )
                    ):
                        me.text(
                            "Select the output you prefer for the given prompt",
                            style=me.Style(
                                font_weight=500,
                                font_size="20px",
                                text_transform="uppercase",
                            ),
                        )
                        me.box(style=me.Style(height=16))
                        me.text(
                            page_state.arena_prompt,
                            style=me.Style(font_size="20pt"),
                        )

                    # Image outputs
                    with me.box(style=_BOX_STYLE):
                        if page_state.is_loading:
                            with me.box(
                                style=me.Style(
                                    display="grid",
                                    justify_content="center",
                                    justify_items="center",
                                )
                            ):
                                me.progress_spinner()

                        if page_state.arena_output:
                            with me.box(
                                style=me.Style(
                                    display="grid",
                                    justify_content="center",
                                    justify_items="center",
                                )
                            ):
                                # Generated images row
                                with me.box(
                                    style=me.Style(
                                        flex_wrap="wrap", display="flex", gap="15px"
                                    )
                                ):
                                    for idx, img in enumerate(
                                        page_state.arena_output, start=1
                                    ):
                                        logging.debug("===> idx: %s, img: %s", idx, img)
                                        # Position idx belongs to arena_model{idx};
                                        # default to "" if extra images came back.
                                        model_value = getattr(
                                            page_state, f"arena_model{idx}", ""
                                        )
                                        img_url = img.replace("gs://", replace_url)
                                        with me.box(
                                            style=me.Style(
                                                align_items="center",
                                                justify_content="center",
                                                display="flex",
                                                flex_direction="column",
                                            ),
                                        ):
                                            me.image(
                                                src=f"{img_url}",
                                                style=_image_style(
                                                    page_state.chosen_model,
                                                    model_value,
                                                ),
                                            )

                                            # Model names stay hidden until a vote is cast.
                                            if page_state.chosen_model:
                                                text_style = me.Style()
                                                if page_state.chosen_model == model_value:
                                                    text_style = me.Style(
                                                        font_weight="bold"
                                                    )
                                                me.text(model_value, style=text_style)
                                            else:
                                                me.box(style=me.Style(height=18))

                                me.box(style=me.Style(height=15))

                                # Voting only makes sense with exactly two images.
                                disabled_choice = len(page_state.arena_output) != 2

                                with me.box(
                                    style=me.Style(
                                        flex_direction="row",
                                        display="flex",
                                        gap=50,
                                    )
                                ):
                                    # left choice button
                                    with me.content_button(
                                        type="flat",
                                        key="arena_model1",
                                        on_click=on_click_arena_vote,
                                        disabled=disabled_choice,
                                    ):
                                        with me.box(
                                            style=me.Style(
                                                display="flex", align_items="center"
                                            )
                                        ):
                                            me.icon("arrow_left")
                                            me.text("left")

                                    # skip button
                                    _skip_button()

                                    # right choice button
                                    with me.content_button(
                                        type="flat",
                                        key="arena_model2",
                                        on_click=on_click_arena_vote,
                                        disabled=disabled_choice,
                                    ):
                                        with me.box(
                                            style=me.Style(
                                                display="flex", align_items="center"
                                            )
                                        ):
                                            me.text("right")
                                            me.icon("arrow_right")
                        else:
                            # skip button
                            _skip_button()

                    # show user choice
                    if page_state.chosen_model:
                        me.text(f"You voted {page_state.chosen_model}")