backend-apis/app/routers/p2_content_creator.py (280 lines of code) (raw):
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Persona 2 routers - Content Creator
"""
import asyncio
import json
import tomllib
import numpy as np
from fastapi import APIRouter, HTTPException
from google.api_core.exceptions import GoogleAPICallError
from google.cloud import firestore
from google.cloud.exceptions import NotFound
from vertexai.vision_models import Image
from app.models.p2_model import (
DetectProductCategoriesRequest,
DetectProductCategoriesResponse,
EditImageRequest,
GenerateImageRequest,
GenerateOrEditImageResponse,
GenerateTitleDescriptionRequest,
GenerateTitleDescriptionResponse,
Product,
Service,
)
from app.utils import utils_gemini, utils_palm, utils_imagen, utils_vertex_vector
# ----------------------------------------------------------------------------#
# Load configuration file (config.toml) and global configs
# The path is relative, so it is resolved against the process working
# directory at import time. tomllib only accepts files opened in binary mode.
with open("app/config.toml", "rb") as f:
    config = tomllib.load(f)

# Project id from the [global] table; passed to the embeddings client below.
project_id = config["global"]["project_id"]
# ----------------------------------------------------------------------------#
# Single module-level Firestore client shared by every handler in this router.
# NOTE(review): created without arguments, so it presumably relies on the
# environment's default project and credentials - confirm for each deployment.
db = firestore.Client()
# ----------------------------------------------------------------------------#
# Vertex Vector multimodal search
# Client used by detect_product_categories to embed the uploaded images.
embeddings_client = utils_palm.EmbeddingPredictionClient(project=project_id)
# The three settings below are read from the [multimodal] table. They are not
# referenced by any handler visible in this module (find_neighbor is called
# without them); they are kept because other code may import them from here.
index_endpoint_id = config["multimodal"]["index_endpoint_id"]
deployed_index_id = config["multimodal"]["deployed_index_id"]
vector_api_endpoint = config["multimodal"]["vector_api_endpoint"]
# ----------------------------------------------------------------------------#
# All routes declared in this module are mounted under the /p2 prefix.
router = APIRouter(prefix="/p2", tags=["P2 - Content Creator"])
# -------------------------------DELETE---------------------------------------#
@router.delete(path="/user-product/{user_id}/{product_id}")
def delete_user_product(user_id: str, product_id: str) -> str:
    """
    # Delete user product

    Deletes the document `content-creator/{user_id}/products/{product_id}`.

    ## Path parameters
    **user_id**: *string*
    - User id

    **product_id**: *string*
    - Product id

    ## Returns
    - ok

    ## Raises
    **HTTPException** - *400* - Error deleting in Firestore
    - Firestore could not delete the product
    """
    try:
        (
            db.collection("content-creator")
            .document(user_id)
            .collection("products")
            .document(product_id)
            .delete()
        )
    except GoogleAPICallError as e:
        # Keep a space between the fixed message and the upstream error text
        # (same convention as the image endpoints) so the two do not run
        # together in the response.
        raise HTTPException(
            status_code=400, detail=f"Error deleting in Firestore {e}"
        ) from e
    return "ok"
@router.delete(path="/user-service/{user_id}/{service_id}")
def delete_user_service(user_id: str, service_id: str) -> str:
    """
    # Delete user service

    Deletes the document `content-creator/{user_id}/services/{service_id}`.

    ## Path parameters
    **user_id**: *string*
    - User id

    **service_id**: *string*
    - Service id

    ## Returns
    - ok

    ## Raises
    **HTTPException** - *400* - Error deleting in Firestore
    - Firestore could not delete the service
    """
    try:
        (
            db.collection("content-creator")
            .document(user_id)
            .collection("services")
            .document(service_id)
            .delete()
        )
    except GoogleAPICallError as e:
        # Keep a space between the fixed message and the upstream error text
        # (same convention as the image endpoints) so the two do not run
        # together in the response.
        raise HTTPException(
            status_code=400, detail=f"Error deleting in Firestore {e}"
        ) from e
    return "ok"
# ---------------------------------POST---------------------------------------#
@router.post(path="/user-product/{user_id}")
def add_user_product(user_id: str, product: Product) -> str:
    """
    # Add user product

    Stores the product as a new document under
    `content-creator/{user_id}/products`.

    ## Path parameters
    **user_id**: *string*
    - User id

    ## Request body [Product]
    **title**: *string*
    - Title of the product

    **description**: *string*
    - Description of the product

    **image_urls**: *list[string]*
    - List of image urls of the product

    **labels**: *list[string]*
    - List of labels of the product

    **features**: *list[string]*
    - List of features of the product

    **categories**: *list[string]*
    - List of categories of the product

    ## Returns
    - ok

    ## Raises
    **HTTPException** - *400* - Error setting in Firestore
    - Firestore could not set the product
    """
    try:
        # document() is called without an id, so the new document gets an
        # auto-generated one; the id is not returned to the caller.
        (
            db.collection("content-creator")
            .document(user_id)
            .collection("products")
            .document()
            .set(product.model_dump())
        )
    except GoogleAPICallError as e:
        # Keep a space between the fixed message and the upstream error text
        # (same convention as the image endpoints) so the two do not run
        # together in the response.
        raise HTTPException(
            status_code=400, detail=f"Error setting in Firestore {e}"
        ) from e
    return "ok"
@router.post(path="/user-service/{user_id}")
def add_user_service(user_id: str, service: Service) -> str:
    """
    # Add user service

    Stores the service as a new document under
    `content-creator/{user_id}/services`.

    ## Path parameters
    **user_id**: *string*
    - User id

    ## Request body [Service]
    **title**: *string*
    - Title of the service

    **description**: *string*
    - Description of the service

    **image_urls**: *list[string]*
    - List of image urls of the service

    **labels**: *list[string]*
    - List of labels of the service

    **features**: *list[string]*
    - List of features of the service

    **categories**: *list[string]*
    - List of categories of the service

    ## Returns
    - ok

    ## Raises
    **HTTPException** - *400* - Error setting in Firestore
    - Firestore could not set the service
    """
    try:
        # document() is called without an id, so the new document gets an
        # auto-generated one; the id is not returned to the caller.
        (
            db.collection("content-creator")
            .document(user_id)
            .collection("services")
            .document()
            .set(service.model_dump())
        )
    except GoogleAPICallError as e:
        # Keep a space between the fixed message and the upstream error text
        # (same convention as the image endpoints) so the two do not run
        # together in the response.
        raise HTTPException(
            status_code=400, detail=f"Error setting in Firestore {e}"
        ) from e
    return "ok"
def _load_llm_json(raw_text: str):
    """Strip the wrappers the LLM puts around its JSON answer and parse it.

    Every occurrence of "</output>", "```json" and "```" is removed (in that
    order) before the remaining text is decoded with ``json.loads``.

    Raises:
        ValueError: the remaining text is not valid JSON
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    for marker in ("</output>", "```json", "```"):
        raw_text = raw_text.replace(marker, "")
    return json.loads(raw_text)


@router.post(path="/detect-product-categories")
def detect_product_categories(
    data: DetectProductCategoriesRequest,
) -> DetectProductCategoriesResponse:
    """
    # Detect Product Categories with Vision API

    ## Request body [DetectProductCategoriesRequest]
    **images_names**: *list*
    - List of image names

    ## Response body [DetectProductCategoriesResponse]
    **vision_labels**: *list*
    - List of labels from Vision API

    **images_features**: *list*
    - List of features extracted by the LLM from the Imagen captions

    **images_categories**: *list*
    - List of categories extracted by the LLM from the Imagen captions

    **similar_products**: *list*
    - List of similar products ids from Vertex Vector Search

    ## Raises
    **HTTPException** - *400* - No image provided

    **HTTPException** - *400* - Error annotating images with Vision API

    **HTTPException** - *400* - Error extracting captions

    **HTTPException** - *400* - Error extracting features and categories
    - Also raised when the LLM answer is not the expected JSON

    **HTTPException** - *400* - Error getting images embeddings

    **HTTPException** - *400* - Error getting similar products
    """
    if not data.images_names:
        raise HTTPException(
            status_code=400,
            detail="Provide at least one image to generate the categories.",
        )

    # 1) Labels from the Vision API.
    try:
        vision_labels = utils_imagen.annotate_image_names(data.images_names)
    except GoogleAPICallError as e:
        raise HTTPException(
            status_code=400,
            detail=f"Error annotating images with Vision API {e}",
        ) from e

    # 2) Load the raw images and caption them with Imagen.
    try:
        images_bytes = [
            utils_imagen.image_name_to_bytes(image_name)
            for image_name in data.images_names
        ]
        imagen_captions = asyncio.run(
            utils_imagen.run_image_captions(
                images_bytes=images_bytes,
            )
        )
    except GoogleAPICallError as e:
        raise HTTPException(
            status_code=400, detail=f"Error extracting captions {e}"
        ) from e

    # 3) Ask the LLM for features and categories based on the captions.
    # The prompts are built outside the try block so that a configuration
    # problem is not reported as an LLM failure.
    captions_text = "\n".join(imagen_captions)
    prompts = [
        config["content_creation"]["prompt_features"].format(captions_text),
        config["content_creation"]["prompt_categories"].format(captions_text),
    ]
    try:
        llm_answers = asyncio.run(
            utils_gemini.run_predict_text_llm(prompts=prompts)
        )
        # Answers are positional: [0] is features, [1] is categories.
        images_features = _load_llm_json(llm_answers[0])["product_features"]
        images_categories = _load_llm_json(llm_answers[1])[
            "product_categories"
        ]
    except (
        GoogleAPICallError,
        ValueError,  # answer is not valid JSON
        KeyError,  # JSON object lacks the expected key
        TypeError,  # JSON is not an object
        IndexError,  # fewer answers than prompts
    ) as e:
        raise HTTPException(
            status_code=400,
            detail=f"Error extracting features and categories {e}",
        ) from e

    # 4) One embedding per image.
    try:
        images_embeddings = [
            embeddings_client.get_embedding(
                image_bytes=image_bytes
            ).image_embedding
            for image_bytes in images_bytes
        ]
    except GoogleAPICallError as e:
        raise HTTPException(
            status_code=400, detail=f"Error getting images embeddings {e}"
        ) from e

    # Element-wise sum of the per-image embeddings gives one query vector.
    image_embedding = list(np.sum(np.array(images_embeddings), axis=0))

    # 5) Nearest neighbours of the query vector in Vertex Vector Search.
    try:
        neighbors = utils_vertex_vector.find_neighbor(
            feature_vector=image_embedding,
            neighbor_count=2,
        )
        similar_products = []
        if (
            neighbors.nearest_neighbors is not None
            and len(neighbors.nearest_neighbors) > 0
        ):
            # Only one query vector was sent, so only the first result set
            # is read.
            similar_products = [
                i.datapoint.datapoint_id
                for i in neighbors.nearest_neighbors[0].neighbors
            ]
    except Exception as e:
        # Deliberately broad, as in the original handler: any failure of the
        # vector search is reported as a 400.
        raise HTTPException(
            status_code=400, detail=f"Error getting similar products {e}"
        ) from e

    return DetectProductCategoriesResponse(
        vision_labels=list(vision_labels),
        images_features=images_features,
        images_categories=images_categories,
        similar_products=similar_products,
    )
@router.post(path="/edit-image")
def image_edit(data: EditImageRequest) -> GenerateOrEditImageResponse:
    """
    # Image editing with Imagen

    ## Request body [EditImageRequest]:
    **prompt**: *string*
    - Prompt for editing the image

    **base_image_name**: *string*
    - Base image name to be collected from Google Cloud Storage /images

    **mask_image_name**: *string* = ""
    - Mask image name to be collected from Google Cloud Storage /images

    **number_of_images**: *int* = 1
    - Number of images to generate

    **negative_prompt**: *string* = ""
    - Negative prompt for editing the image

    ## Response body [GenerateOrEditImageResponse]:
    **generated_images**: *list[GeneratedImage]*

    ## GeneratedImage:
    **image_name**: *string*
    - Name of the image in Cloud Storage

    **image_size**: *tuple(int, int)*
    - Size of the generated image

    **image_parameters**: *dict*
    - Parameters used to generate the image

    ## Raises
    **HTTPException** - *404* - Image not found in Cloud Storage

    **HTTPException** - *400* - Error editing image with Imagen
    """
    # Load the optional mask first, then the base image.
    try:
        mask = (
            Image(
                image_bytes=utils_imagen.image_name_to_bytes(
                    data.mask_image_name
                )
            )
            if data.mask_image_name
            else None
        )
        base_image = Image(
            image_bytes=utils_imagen.image_name_to_bytes(data.base_image_name)
        )
    except NotFound as e:
        raise HTTPException(
            status_code=404,
            detail=f"Image not found in Cloud Storage {e}",
        ) from e

    try:
        imagen_responses = utils_imagen.image_generate_model.edit_image(
            prompt=data.prompt,
            base_image=base_image,
            mask=mask,
            number_of_images=data.number_of_images,
            negative_prompt=data.negative_prompt,
        )
    except GoogleAPICallError as e:
        raise HTTPException(
            status_code=400, detail=f"Error editing image with Imagen {e}"
        ) from e

    # Upload every edited image and describe it in the response.
    generated_images = [
        GenerateOrEditImageResponse.GeneratedImage(
            image_name=utils_imagen.upload_image_to_storage(
                edited._image_bytes  # pylint: disable=protected-access
            ),
            image_size=edited._size,  # pylint: disable=protected-access
            image_parameters=edited.generation_parameters,
        )
        for edited in imagen_responses
    ]
    return GenerateOrEditImageResponse(generated_images=generated_images)
@router.post(path="/generate-image")
def image_generate(
    data: GenerateImageRequest,
) -> GenerateOrEditImageResponse:
    """
    # Image generation with Imagen

    ## Request body [GenerateImageRequest]:
    **prompt**: *string*
    - Prompt for generating the image

    **number_of_images**: *int* = 1
    - Number of images to generate

    **negative_prompt**: *string* = ""
    - Negative prompt for generating the image

    ## Response body [GenerateOrEditImageResponse]:
    **generated_images**: *list[GeneratedImage]*

    ## GeneratedImage:
    **image_name**: *string*
    - Name of the image in Cloud Storage

    **image_size**: *tuple(int, int)*
    - Size of the generated image

    **image_parameters**: *dict*
    - Parameters used to generate the image

    ## Raises
    **HTTPException** - *400* - Error generating image with Imagen
    """
    try:
        imagen_responses = utils_imagen.image_generate_model.generate_images(
            prompt=data.prompt,
            number_of_images=data.number_of_images,
            negative_prompt=data.negative_prompt,
        )
    except Exception as e:
        # Any failure of the generation call is reported as a 400.
        raise HTTPException(
            status_code=400,
            detail=f"Error generating image with Imagen {e}",
        ) from e

    # Upload every generated image and describe it in the response.
    generated_images = [
        GenerateOrEditImageResponse.GeneratedImage(
            image_name=utils_imagen.upload_image_to_storage(
                generated._image_bytes  # pylint: disable=protected-access
            ),
            image_size=generated._size,  # pylint: disable=protected-access
            image_parameters=generated.generation_parameters,
        )
        for generated in imagen_responses
    ]
    return GenerateOrEditImageResponse(generated_images=generated_images)
@router.post(path="/generate-title-description")
def generate_title_description(
    data: GenerateTitleDescriptionRequest,
) -> GenerateTitleDescriptionResponse:
    """
    # Generate Title and Description with Gemini

    ## Request body [GenerateTitleDescriptionRequest]
    **product_categories**: *list*
    - List of product categories

    **context**: *string* = ""
    - Context for the title and description

    ## Response body [GenerateTitleDescriptionResponse]
    **title**: *string*
    - Generated title

    **description**: *string*
    - Generated description

    ## Raises
    **HTTPException** - *400* - Error generating title / description
    - The model call failed, or its answer was not JSON containing
      `title` and `description`
    """
    try:
        raw_answer = utils_gemini.generate_gemini_pro_text(
            prompt=config["content_creation"][
                "prompt_title_description"
            ].format(data.product_categories, data.context),
        )
        # Strip the wrappers the model puts around its JSON answer.
        for marker in ("</output>", "```json", "```"):
            raw_answer = raw_answer.replace(marker, "")
        parsed = json.loads(raw_answer)
        # Read the keys inside the try block so an answer missing either one
        # is reported as the documented 400 instead of an unhandled error.
        title = parsed["title"]
        description = parsed["description"]
    except Exception as e:
        # The message keeps its historical "PaLM" wording; only the missing
        # separator before the upstream error text is added.
        raise HTTPException(
            status_code=400,
            detail=f"Error generating title / description with PaLM {e}",
        ) from e
    return GenerateTitleDescriptionResponse(
        title=title, description=description
    )
# ----------------------------------PUT---------------------------------------#
@router.put("/user-product/{user_id}/{product_id}")
def put_user_product(user_id: str, product_id: str, product: Product) -> str:
    """
    # Put user product

    Writes the product to the document
    `content-creator/{user_id}/products/{product_id}`.

    ## Path parameters
    **user_id**: *string*
    - User id

    **product_id**: *string*
    - Product id

    ## Request body [Product]
    - Product data to store

    ## Returns
    - ok

    ## Raises
    **HTTPException** - *400* - Error setting in Firestore
    - Firestore could not set the product
    """
    try:
        (
            db.collection("content-creator")
            .document(user_id)
            .collection("products")
            .document(product_id)
            .set(product.model_dump())
        )
    except GoogleAPICallError as e:
        # Keep a space between the fixed message and the upstream error text
        # (same convention as the image endpoints) so the two do not run
        # together in the response.
        raise HTTPException(
            status_code=400, detail=f"Error setting in Firestore {e}"
        ) from e
    return "ok"
@router.put("/user-service/{user_id}/{service_id}")
def put_user_service(user_id: str, service_id: str, service: Service) -> str:
    """
    # Put user service

    Writes the service to the document
    `content-creator/{user_id}/services/{service_id}`.

    ## Path parameters
    **user_id**: *string*
    - User id

    **service_id**: *string*
    - Service id

    ## Request body [Service]
    - Service data to store

    ## Returns
    - ok

    ## Raises
    **HTTPException** - *400* - Error setting in Firestore
    - Firestore could not set the service
    """
    try:
        (
            db.collection("content-creator")
            .document(user_id)
            .collection("services")
            .document(service_id)
            .set(service.model_dump())
        )
    except GoogleAPICallError as e:
        # Keep a space between the fixed message and the upstream error text
        # (same convention as the image endpoints) so the two do not run
        # together in the response.
        raise HTTPException(
            status_code=400, detail=f"Error setting in Firestore {e}"
        ) from e
    return "ok"