in experiments/arena/models/generate.py [0:0]
def images_from_imagen(model_name: str, prompt: str, aspect_ratio: str):
    """Generate an image with Imagen and return the GCS URIs of the output.

    A single watermarked image is requested and written under
    ``gs://{config.GENMEDIA_BUCKET}/imagen_live``. For every image in the
    response, metadata is recorded via ``add_image_metadata``; that step is
    best-effort: failures are logged and never prevent the URI from being
    returned.

    Args:
        model_name (str): imagen model name
        prompt (str): prompt for t2i model
        aspect_ratio (str): aspect ratio string

    Returns:
        list[str]: GCS URIs of the generated images, in response order.
    """
    start_time = time.time()
    arena_output = []
    logging.info(f"model: {model_name}")
    logging.info(f"prompt: {prompt}")
    logging.info(f"target output: {config.GENMEDIA_BUCKET}")

    vertexai.init(project=config.PROJECT_ID, location=config.LOCATION)
    image_model = ImageGenerationModel.from_pretrained(model_name)
    response = image_model.generate_images(
        prompt=prompt,
        add_watermark=True,
        aspect_ratio=aspect_ratio,
        number_of_images=1,
        output_gcs_uri=f"gs://{config.GENMEDIA_BUCKET}/imagen_live",
        language="auto",
        safety_filter_level="block_few",
    )

    # Covers SDK initialisation, model lookup and the generation request as a
    # whole; the same figure is reported for every image in the response.
    elapsed_time = time.time() - start_time

    for idx, img in enumerate(response.images):
        gcs_uri = img._gcs_uri
        logging.info(
            f"Generated image {idx} with model {model_name} in {elapsed_time:.2f} seconds"
        )
        # NOTE(review): _as_base64_string() is only used for its length here;
        # it presumably has to materialise the image bytes to do so -- confirm
        # whether this log detail is worth that cost.
        logging.info(
            f"Generated image: #{idx}, len {len(img._as_base64_string())} at {gcs_uri}"
        )
        arena_output.append(gcs_uri)
        logging.info(f"Image created: {gcs_uri}")

        # Metadata persistence is deliberately best-effort: the image already
        # exists in GCS, so a storage failure must not fail the request.
        try:
            add_image_metadata(gcs_uri, prompt, model_name)
        except Exception as e:
            # Match on the exception class name as well as the message text:
            # API exceptions commonly stringify as "<code> <message>", which
            # does not include the class name "DeadlineExceeded".
            is_timeout = (
                type(e).__name__ == "DeadlineExceeded"
                or "DeadlineExceeded" in str(e)
            )
            if is_timeout:
                logging.error(f"Firestore timeout: {e}")
            else:
                # Keep the traceback so unexpected failures can be diagnosed.
                logging.error(f"Error adding image metadata: {e}", exc_info=True)

    return arena_output