backend-apis/app/routers/p2_content_creator.py (280 lines of code) (raw):

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Persona 2 routers - Content Creator
"""

import asyncio
import json
import tomllib

import numpy as np
from fastapi import APIRouter, HTTPException
from google.api_core.exceptions import GoogleAPICallError
from google.cloud import firestore
from google.cloud.exceptions import NotFound
from vertexai.vision_models import Image

from app.models.p2_model import (
    DetectProductCategoriesRequest,
    DetectProductCategoriesResponse,
    EditImageRequest,
    GenerateImageRequest,
    GenerateOrEditImageResponse,
    GenerateTitleDescriptionRequest,
    GenerateTitleDescriptionResponse,
    Product,
    Service,
)
from app.utils import utils_gemini, utils_palm, utils_imagen, utils_vertex_vector

# ----------------------------------------------------------------------------#
# Load configuration file (config.toml) and global configs.
# The path is relative, so it is resolved against the process working
# directory at import time.
with open("app/config.toml", "rb") as config_file:
    config = tomllib.load(config_file)

project_id = config["global"]["project_id"]

# ----------------------------------------------------------------------------#
# Firestore client shared by every endpoint of this router.
db = firestore.Client()

# ----------------------------------------------------------------------------#
# Vertex Vector multimodal search
embeddings_client = utils_palm.EmbeddingPredictionClient(project=project_id)

# Settings read from the [multimodal] table of config.toml.
multimodal_config = config["multimodal"]
index_endpoint_id = multimodal_config["index_endpoint_id"]
deployed_index_id = multimodal_config["deployed_index_id"]
vector_api_endpoint = multimodal_config["vector_api_endpoint"]
# ----------------------------------------------------------------------------#
router = APIRouter(prefix="/p2", tags=["P2 - Content Creator"])


# ------------------------------HELPERS---------------------------------------#
def _parse_llm_json(raw_text: str):
    """Remove the wrapper markup from an LLM answer and parse it as JSON.

    Every `</output>` tag and Markdown code fence (with or without the
    `json` language tag) is deleted from *raw_text*, then the remainder is
    passed to `json.loads`.

    Raises:
        json.JSONDecodeError: The cleaned text is not valid JSON
            (a subclass of `ValueError`).
    """
    cleaned = raw_text
    # Order matters: the tagged fence must be removed before the bare fence,
    # otherwise a stray "json" would be left in front of the payload.
    for marker in ("</output>", "```json", "```"):
        cleaned = cleaned.replace(marker, "")
    return json.loads(cleaned)


def _store_generated_images(imagen_responses) -> GenerateOrEditImageResponse:
    """Persist Imagen results and wrap them in the endpoint response model.

    Each image's bytes are handed to `utils_imagen.upload_image_to_storage`;
    the returned name, the image size and the generation parameters are
    collected into one `GeneratedImage` per result.
    """
    generated_images = []
    for image in imagen_responses:
        image_name = utils_imagen.upload_image_to_storage(
            image._image_bytes  # pylint: disable=protected-access
        )
        generated_images.append(
            GenerateOrEditImageResponse.GeneratedImage(
                image_name=image_name,
                image_size=image._size,  # pylint: disable=protected-access
                image_parameters=image.generation_parameters,
            )
        )
    return GenerateOrEditImageResponse(generated_images=generated_images)


# -------------------------------DELETE---------------------------------------#
@router.delete(path="/user-product/{user_id}/{product_id}")
def delete_user_product(user_id: str, product_id: str) -> str:
    """
    # Delete user product

    ## Path parameters
    **user_id**: *string*
    - User id

    **product_id**: *string*
    - Product id

    ## Returns
    - ok

    ## Raises
    **HTTPException** - *400* - Error deleting in Firestore
    - Firestore could not delete the product
    """
    try:
        db.collection("content-creator").document(user_id).collection(
            "products"
        ).document(product_id).delete()
    except GoogleAPICallError as e:
        raise HTTPException(
            status_code=400, detail=f"Error deleting in Firestore {e}"
        ) from e

    return "ok"


@router.delete(path="/user-service/{user_id}/{service_id}")
def delete_user_service(user_id: str, service_id: str) -> str:
    """
    # Delete user service

    ## Path parameters
    **user_id**: *string*
    - User id

    **service_id**: *string*
    - Service id

    ## Returns
    - ok

    ## Raises
    **HTTPException** - *400* - Error deleting in Firestore
    - Firestore could not delete the service
    """
    try:
        db.collection("content-creator").document(user_id).collection(
            "services"
        ).document(service_id).delete()
    except GoogleAPICallError as e:
        raise HTTPException(
            status_code=400, detail=f"Error deleting in Firestore {e}"
        ) from e

    return "ok"


# ---------------------------------POST---------------------------------------#
@router.post(path="/user-product/{user_id}")
def add_user_product(user_id: str, product: Product) -> str:
    """
    # Add user product

    The product is stored under a new, auto-generated document id.

    ## Path parameters
    **user_id**: *string*
    - User id

    ## Request body [Product]
    **title**: *string*
    - Title of the product

    **description**: *string*
    - Description of the product

    **image_urls**: *list[string]*
    - List of image urls of the product

    **labels**: *list[string]*
    - List of labels of the product

    **features**: *list[string]*
    - List of features of the product

    **categories**: *list[string]*
    - List of categories of the product

    ## Returns
    - ok

    ## Raises
    **HTTPException** - *400* - Error setting in Firestore
    - Firestore could not set the product
    """
    try:
        db.collection("content-creator").document(user_id).collection(
            "products"
        ).document().set(product.model_dump())
    except GoogleAPICallError as e:
        raise HTTPException(
            status_code=400, detail=f"Error setting in Firestore {e}"
        ) from e

    return "ok"


@router.post(path="/user-service/{user_id}")
def add_user_service(user_id: str, service: Service) -> str:
    """
    # Add user service

    The service is stored under a new, auto-generated document id.

    ## Path parameters
    **user_id**: *string*
    - User id

    ## Request body [Service]
    **title**: *string*
    - Title of the service

    **description**: *string*
    - Description of the service

    **image_urls**: *list[string]*
    - List of image urls of the service

    **labels**: *list[string]*
    - List of labels of the service

    **features**: *list[string]*
    - List of features of the service

    **categories**: *list[string]*
    - List of categories of the service

    ## Returns
    - ok

    ## Raises
    **HTTPException** - *400* - Error setting in Firestore
    - Firestore could not set the service
    """
    try:
        db.collection("content-creator").document(user_id).collection(
            "services"
        ).document().set(service.model_dump())
    except GoogleAPICallError as e:
        raise HTTPException(
            status_code=400, detail=f"Error setting in Firestore {e}"
        ) from e

    return "ok"


@router.post(path="/detect-product-categories")
def detect_product_categories(
    data: DetectProductCategoriesRequest,
) -> DetectProductCategoriesResponse:
    """
    # Detect Product Categories with Vision API

    ## Request body [DetectProductCategoriesRequest]
    **images_names**: *list*
    - List of image names (at least one is required)

    ## Response body [DetectProductCategoriesResponse]
    **vision_labels**: *list*
    - List of labels from Vision API

    **images_features**: *list*
    - List of features from Imagen Captions and Gemini

    **images_categories**: *list*
    - List of categories from Imagen Captions and Gemini

    **similar_products**: *list*
    - List of similar products ids from Vertex Vector Search

    ## Raises
    **HTTPException** - *400* - Provide at least one image to generate
    the categories

    **HTTPException** - *400* - Error annotating images with Vision API

    **HTTPException** - *400* - Error extracting captions

    **HTTPException** - *400* - Error extracting features and categories
    - Also raised when the model answer is not the expected JSON

    **HTTPException** - *400* - Error getting images embeddings

    **HTTPException** - *400* - Error getting similar products
    """
    if not data.images_names:
        raise HTTPException(
            status_code=400,
            detail="Provide at least one image to generate the categories.",
        )

    try:
        vision_labels = utils_imagen.annotate_image_names(data.images_names)
    except GoogleAPICallError as e:
        raise HTTPException(
            status_code=400,
            detail=f"Error annotating images with Vision API {e}",
        ) from e

    try:
        images_bytes = [
            utils_imagen.image_name_to_bytes(image_name)
            for image_name in data.images_names
        ]
        # Extract labels with Imagen Captioning
        imagen_captions = asyncio.run(
            utils_imagen.run_image_captions(
                images_bytes=images_bytes,
            )
        )
    except GoogleAPICallError as e:
        raise HTTPException(
            status_code=400, detail=f"Error extracting captions {e}"
        ) from e

    # Both prompts receive the same newline-joined captions.
    captions_text = "\n".join(imagen_captions)
    try:
        images_feat_cat = asyncio.run(
            utils_gemini.run_predict_text_llm(
                prompts=[
                    config["content_creation"]["prompt_features"].format(
                        captions_text
                    ),
                    config["content_creation"]["prompt_categories"].format(
                        captions_text
                    ),
                ],
            )
        )
    except GoogleAPICallError as e:
        raise HTTPException(
            status_code=400,
            detail=f"Error extracting features and categories {e}",
        ) from e

    # The model answer is free text: it may be invalid JSON, lack the expected
    # key, be a non-object, or be missing altogether. Report all of these as
    # the documented 400 instead of letting them escape as a 500.
    try:
        images_features = _parse_llm_json(images_feat_cat[0])[
            "product_features"
        ]
        images_categories = _parse_llm_json(images_feat_cat[1])[
            "product_categories"
        ]
    except (ValueError, KeyError, TypeError, IndexError) as e:
        raise HTTPException(
            status_code=400,
            detail=f"Error extracting features and categories {e}",
        ) from e

    try:
        images_embeddings = [
            embeddings_client.get_embedding(
                image_bytes=image_bytes
            ).image_embedding
            for image_bytes in images_bytes
        ]
    except GoogleAPICallError as e:
        raise HTTPException(
            status_code=400, detail=f"Error getting images embeddings {e}"
        ) from e

    # Element-wise sum of all image embeddings gives the single query vector.
    image_embedding = list(np.sum(np.array(images_embeddings), axis=0))

    try:
        neighbors = utils_vertex_vector.find_neighbor(
            feature_vector=image_embedding,
            neighbor_count=2,
        )
        similar_products = []
        if neighbors.nearest_neighbors:
            # Only one query vector was sent, so only the first result set
            # is read.
            similar_products = [
                i.datapoint.datapoint_id
                for i in neighbors.nearest_neighbors[0].neighbors
            ]
    except Exception as e:
        raise HTTPException(
            status_code=400, detail=f"Error getting similar products {e}"
        ) from e

    return DetectProductCategoriesResponse(
        vision_labels=list(vision_labels),
        images_features=images_features,
        images_categories=images_categories,
        similar_products=similar_products,
    )


@router.post(path="/edit-image")
def image_edit(data: EditImageRequest) -> GenerateOrEditImageResponse:
    """
    # Image editing with Imagen

    ## Request body [EditImageRequest]:
    **prompt**: *string*
    - Prompt for editing the image

    **base_image_name**: *string*
    - Base image name to be collected from Google Cloud Storage /images

    **mask_image_name**: *string* = ""
    - Mask image name to be collected from Google Cloud Storage /images
    - When empty, the image is edited without a mask

    **number_of_images**: *int* = 1
    - Number of images to generate

    **negative_prompt**: *string* = ""
    - Negative prompt for editing the image

    ## Response body [GenerateOrEditImageResponse]:
    **generated_images**: *list[GeneratedImage]*

    ## GeneratedImage:
    **image_name**: *string*
    - Name of the image in Cloud Storage

    **image_size**: *tuple(int, int)*
    - Size of the generated image

    **image_parameters**: *dict*
    - Parameters used to generate the image

    ## Raises
    **HTTPException** - *404* - Image not found in Cloud Storage

    **HTTPException** - *400* - Error editing image with Imagen
    """
    try:
        if not data.mask_image_name:
            mask = None
        else:
            mask = Image(
                image_bytes=utils_imagen.image_name_to_bytes(
                    data.mask_image_name
                )
            )
        base_image = Image(
            image_bytes=utils_imagen.image_name_to_bytes(data.base_image_name)
        )
    except NotFound as e:
        raise HTTPException(
            status_code=404,
            detail="Image not found in Cloud Storage " + str(e),
        ) from e

    try:
        imagen_responses = utils_imagen.image_generate_model.edit_image(
            prompt=data.prompt,
            base_image=base_image,
            mask=mask,
            number_of_images=data.number_of_images,
            negative_prompt=data.negative_prompt,
        )
    except GoogleAPICallError as e:
        raise HTTPException(
            status_code=400, detail="Error editing image with Imagen " + str(e)
        ) from e

    return _store_generated_images(imagen_responses)


@router.post(path="/generate-image")
def image_generate(
    data: GenerateImageRequest,
) -> GenerateOrEditImageResponse:
    """
    # Image generation with Imagen

    ## Request body [GenerateImageRequest]:
    **prompt**: *string*
    - Prompt for generating the image

    **number_of_images**: *int* = 1
    - Number of images to generate

    **negative_prompt**: *string* = ""
    - Negative prompt for generating the image

    ## Response body [GenerateOrEditImageResponse]:
    **generated_images**: *list[GeneratedImage]*

    ## GeneratedImage:
    **image_name**: *string*
    - Name of the image in Cloud Storage

    **image_size**: *tuple(int, int)*
    - Size of the generated image

    **image_parameters**: *dict*
    - Parameters used to generate the image

    ## Raises
    **HTTPException** - *400* - Error generating image with Imagen
    """
    try:
        imagen_responses = utils_imagen.image_generate_model.generate_images(
            prompt=data.prompt,
            number_of_images=data.number_of_images,
            negative_prompt=data.negative_prompt,
        )
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail="Error generating image with Imagen " + str(e),
        ) from e

    return _store_generated_images(imagen_responses)


@router.post(path="/generate-title-description")
def generate_title_description(
    data: GenerateTitleDescriptionRequest,
) -> GenerateTitleDescriptionResponse:
    """
    # Generate Title and Description

    The text is produced by `utils_gemini.generate_gemini_pro_text`.

    ## Request body [GenerateTitleDescriptionRequest]
    **product_categories**: *list*
    - List of product categories

    **context**: *string* = ""
    - Context for the title and description

    ## Response body [GenerateTitleDescriptionResponse]
    **title**: *string*
    - Generated title

    **description**: *string*
    - Generated description

    ## Raises
    **HTTPException** - *400* - Error generating title / description with PaLM
    - Also raised when the model answer is not JSON or lacks
      `title` / `description`
    """
    try:
        response = _parse_llm_json(
            utils_gemini.generate_gemini_pro_text(
                prompt=config["content_creation"][
                    "prompt_title_description"
                ].format(data.product_categories, data.context),
            )
        )
        # Read the keys inside the try so a missing field is reported as the
        # documented 400 rather than an unhandled KeyError.
        title = response["title"]
        description = response["description"]
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail=f"Error generating title / description with PaLM {e}",
        ) from e

    return GenerateTitleDescriptionResponse(
        title=title, description=description
    )


# ----------------------------------PUT---------------------------------------#
@router.put("/user-product/{user_id}/{product_id}")
def put_user_product(user_id: str, product_id: str, product: Product) -> str:
    """
    # Put user product

    Writes the product to the given document id with `set()` and no merge
    option.

    ## Path parameters
    **user_id**: *string*
    - User id

    **product_id**: *string*
    - Product id

    ## Request body [Product]
    - Product to store

    ## Returns
    - ok

    ## Raises
    **HTTPException** - *400* - Error setting in Firestore
    - Firestore could not set the product
    """
    try:
        db.collection("content-creator").document(user_id).collection(
            "products"
        ).document(product_id).set(product.model_dump())
    except GoogleAPICallError as e:
        # Trailing space keeps the message and the exception text apart.
        raise HTTPException(
            status_code=400, detail=f"Error setting in Firestore {e}"
        ) from e

    return "ok"


@router.put("/user-service/{user_id}/{service_id}")
def put_user_service(user_id: str, service_id: str, service: Service) -> str:
    """
    # Put user service

    Writes the service to the given document id with `set()` and no merge
    option.

    ## Path parameters
    **user_id**: *string*
    - User id

    **service_id**: *string*
    - Service id

    ## Request body [Service]
    - Service to store

    ## Returns
    - ok

    ## Raises
    **HTTPException** - *400* - Error setting in Firestore
    - Firestore could not set the service
    """
    try:
        db.collection("content-creator").document(user_id).collection(
            "services"
        ).document(service_id).set(service.model_dump())
    except GoogleAPICallError as e:
        # Trailing space keeps the message and the exception text apart.
        raise HTTPException(
            status_code=400, detail=f"Error setting in Firestore {e}"
        ) from e

    return "ok"