def arena_images()

in experiments/arena/pages/arena.py [0:0]


def _submit_live_generation(executor, model_name, prompt, aspect_ratio, label):
    """Submit one live image-generation job for ``model_name``.

    Dispatches on the model name (Imagen list membership, then Gemini 2,
    Flux 1 and Stable Diffusion name prefixes, in that order) and returns
    the submitted future, or ``None`` when nothing could be submitted
    (missing endpoint configuration or an unrecognised model).

    Args:
        executor: Executor used to run the generation call.
        model_name: Model identifier selected for this arena slot.
        prompt: Text prompt to generate from.
        aspect_ratio: Aspect ratio forwarded to generators that accept one.
        label: Slot label used in log lines, e.g. ``"model 1"``.
    """
    if model_name in IMAGEN_MODELS:
        logging.info("%s: %s", label, model_name)
        return executor.submit(images_from_imagen, model_name, prompt, aspect_ratio)

    if model_name.startswith(config.MODEL_GEMINI2):
        logging.info("%s: %s", label, model_name)
        # generate_images takes only the prompt (no model or aspect ratio).
        return executor.submit(generate_images, prompt)

    if model_name.startswith(config.MODEL_FLUX1):
        if not config.MODEL_FLUX1_ENDPOINT_ID:
            logging.error("no endpoint defined for %s", model_name)
            return None
        logging.info("%s: %s", label, model_name)
        return executor.submit(images_from_flux, model_name, prompt, aspect_ratio)

    if model_name.startswith(config.MODEL_STABLE_DIFFUSION):
        if not config.MODEL_STABLE_DIFFUSION_ENDPOINT_ID:
            logging.error("no endpoint defined for %s", model_name)
            return None
        logging.info("%s: %s", label, model_name)
        return executor.submit(
            images_from_stable_diffusion, model_name, prompt, aspect_ratio
        )

    # Previously an unknown model was skipped silently; make that visible.
    logging.warning("%s: unsupported model %s; skipping", label, model_name)
    return None


def arena_images(input: str, study: str):
    """Create images for arena comparison.

    Clears ``state.arena_output``, then fills it with the images produced
    for ``state.arena_model1`` followed by those for ``state.arena_model2``.
    Both jobs run concurrently on a thread pool. Results are collected in
    submission order, so model 1's images always precede model 2's.

    Args:
        input: Prompt text. When empty, falls back to ``state.arena_prompt``
            (covers hitting "random" without editing the prompt).
        study: ``"live"`` generates images with the selected models. Any other
            value fetches existing images for each model via ``study_fetch``.
    """
    state = me.state(PageState)
    # Fall back to the stored prompt only for an empty string, as before.
    prompt = state.arena_prompt if input == "" else input
    state.arena_output.clear()

    logging.info("BATTLE: %s vs. %s", state.arena_model1, state.arena_model2)
    logging.info("prompt: %s", prompt)
    if state.image_negative_prompt_input:
        # NOTE(review): the negative prompt is only logged here; it is not
        # forwarded to any generator below. Confirm whether that is intended.
        logging.info("negative prompt: %s", state.image_negative_prompt_input)

    with ThreadPoolExecutor() as executor:
        if study == "live":
            slots = (
                ("model 1", state.arena_model1),
                ("model 2", state.arena_model2),
            )
            futures = [
                future
                for label, model_name in slots
                if (
                    future := _submit_live_generation(
                        executor,
                        model_name,
                        prompt,
                        state.image_aspect_ratio,
                        label,
                    )
                )
                is not None
            ]
        else:
            # Fetch previously stored images for each model from the study.
            futures = [
                executor.submit(study_fetch, state.arena_model1, prompt),
                executor.submit(study_fetch, state.arena_model2, prompt),
            ]

        # Iterate in submission order rather than completion order so that
        # arena_output stays aligned with model 1 / model 2. The jobs still
        # run in parallel; this only changes the order results are collected.
        for future in futures:
            try:
                # Each job is expected to return an iterable of images.
                state.arena_output.extend(future.result())
            except Exception as e:
                # One failing model must not discard the other model's images.
                logging.exception("Error during image generation: %s", e)