in experiments/arena/pages/arena.py [0:0]
def _submit_live_generation(
    executor: ThreadPoolExecutor,
    slot: int,
    model: str,
    prompt: str,
    aspect_ratio,
):
    """Queue a live image-generation task for one arena contestant.

    Args:
        executor: Thread pool on which the generation call is scheduled.
        slot: 1 or 2; only used to label the log line ("model 1" / "model 2").
        model: Model name used to pick the generator function.
        prompt: Text prompt forwarded to the generator.
        aspect_ratio: Aspect-ratio setting forwarded to every generator except
            the Gemini one, which is called with the prompt only.

    Returns:
        The submitted future, or ``None`` when nothing was scheduled (the
        model's endpoint is not configured, or the model matches no known
        family).
    """
    if model in IMAGEN_MODELS:
        generator, args = images_from_imagen, (model, prompt, aspect_ratio)
    elif model.startswith(config.MODEL_GEMINI2):
        generator, args = generate_images, (prompt,)
    elif model.startswith(config.MODEL_FLUX1):
        if not config.MODEL_FLUX1_ENDPOINT_ID:
            logging.error("no endpoint defined for %s", model)
            return None
        generator, args = images_from_flux, (model, prompt, aspect_ratio)
    elif model.startswith(config.MODEL_STABLE_DIFFUSION):
        if not config.MODEL_STABLE_DIFFUSION_ENDPOINT_ID:
            logging.error("no endpoint defined for %s", model)
            return None
        generator, args = (
            images_from_stable_diffusion,
            (model, prompt, aspect_ratio),
        )
    else:
        # Previously an unrecognised model was skipped without any trace.
        logging.error("no image generator available for %s", model)
        return None

    logging.info("model %d: %s", slot, model)
    return executor.submit(generator, *args)


def arena_images(input: str, study: str) -> None:
    """Create images for arena comparison.

    Clears ``state.arena_output`` and refills it with the results produced for
    ``state.arena_model1`` and ``state.arena_model2``. Both requests run
    concurrently on a thread pool; the call blocks until both have finished.

    Args:
        input: Prompt text. When it is an empty string and
            ``state.arena_prompt`` is not, the stored arena prompt is used
            instead.
        study: ``"live"`` generates images with the selected models; any other
            value fetches the images through ``study_fetch``.

    Note:
        ``state.image_negative_prompt_input`` is only logged here; it is not
        forwarded to any generator.
    """
    state = me.state(PageState)

    # Handle the case where someone hits "random" but doesn't modify the
    # prompt: fall back to the prompt already stored in the page state.
    prompt = input
    if prompt == "" and state.arena_prompt != "":
        prompt = state.arena_prompt

    state.arena_output.clear()
    logging.info("BATTLE: %s vs. %s", state.arena_model1, state.arena_model2)
    logging.info("prompt: %s", prompt)
    if state.image_negative_prompt_input:
        logging.info("negative prompt: %s", state.image_negative_prompt_input)

    contestants = (state.arena_model1, state.arena_model2)

    with ThreadPoolExecutor() as executor:
        if study == "live":
            submitted = [
                _submit_live_generation(
                    executor, slot, model, prompt, state.image_aspect_ratio
                )
                for slot, model in enumerate(contestants, start=1)
            ]
            futures = [task for task in submitted if task is not None]
        else:
            # Fetch images from study
            futures = [
                executor.submit(study_fetch, model, prompt)
                for model in contestants
            ]

        # Collect in submission order (model 1, then model 2) rather than
        # completion order, so positions in arena_output do not depend on
        # which request happens to finish first. Both tasks are already
        # running, so waiting in order costs no extra wall-clock time.
        # NOTE(review): presumably consumers pair arena_output positions with
        # arena_model1 / arena_model2 -- confirm against the rendering code.
        for future in futures:
            try:
                # Each task is expected to return an iterable of results.
                state.arena_output.extend(future.result())
            except Exception as e:
                # Best effort: one failing model must not discard the other's
                # output.
                logging.error("Error during image generation: %s", e)