# components/llm_service/src/routes/query.py

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable = broad-except
"""
Query endpoints
"""
import traceback
from fastapi import APIRouter, Depends

from common.models import (QueryEngine, User, UserQuery, QueryDocument)
from common.models.llm_query import QE_TYPE_INTEGRATED_SEARCH
from common.schemas.batch_job_schemas import BatchJobModel
from common.utils.auth_service import validate_token
from common.utils.batch_jobs import initiate_batch_job
from common.utils.config import (JOB_TYPE_QUERY_ENGINE_BUILD,
                                 JOB_TYPE_QUERY_EXECUTE)
from common.utils.errors import (ResourceNotFoundException,
                                 ValidationError,
                                 PayloadTooLargeError)
from common.utils.http_exceptions import (InternalServerError, BadRequest,
                                          ResourceNotFound)
from common.utils.logging_handler import Logger
from config import (PROJECT_ID, DATABASE_PREFIX, PAYLOAD_FILE_SIZE,
                    ERROR_RESPONSES, ENABLE_OPENAI_LLM, ENABLE_COHERE_LLM,
                    DEFAULT_VECTOR_STORE, VECTOR_STORES, PG_HOST,
                    ONEDRIVE_CLIENT_ID, ONEDRIVE_TENANT_ID)
from schemas.llm_schema import (LLMQueryModel,
                                LLMUserAllQueriesResponse,
                                LLMUserQueryResponse,
                                UserQueryUpdateModel,
                                LLMQueryEngineModel,
                                LLMGetQueryEnginesResponse,
                                LLMQueryEngineURLResponse,
                                LLMQueryEngineResponse,
                                LLMQueryResponse,
                                LLMGetVectorStoreTypesResponse)
from services.query.query_service import (query_generate,
                                          delete_engine,
                                          update_user_query)
from utils.gcs_helper import upload_b64files_to_gcs

Logger = Logger.get_logger(__file__)
router = APIRouter(prefix="/query", tags=["Query"], responses=ERROR_RESPONSES)


@router.get(
    "",
    name="Get all Query engines",
    response_model=LLMGetQueryEnginesResponse)
def get_engine_list():
  """Get all available Query engines.

  Returns:
      LLMGetQueryEnginesResponse: list of serialized engine metadata dicts
  """
  try:
    # Fetch and serialize inside the try block so datastore errors are
    # converted to a 500 response (previously this ran before the try and
    # any exception escaped the handler entirely).
    query_engines = QueryEngine.fetch_all()
    query_engine_data = [{
        "id": qe.id,
        "name": qe.name,
        "query_engine_type": qe.query_engine_type,
        "doc_url": qe.doc_url,
        "description": qe.description,
        "read_access_group": qe.read_access_group,
        "llm_type": qe.llm_type,
        "embedding_type": qe.embedding_type,
        "vector_store": qe.vector_store,
        "params": qe.params,
        "created_time": qe.created_time,
        "last_modified_time": qe.last_modified_time,
    } for qe in query_engines]
    return {
        "success": True,
        "message": "Successfully retrieved query engine types",
        "data": query_engine_data
    }
  except Exception as e:
    raise InternalServerError(str(e)) from e


@router.get(
    "/vectorstore",
    name="Get supported vector store types",
    response_model=LLMGetVectorStoreTypesResponse)
def get_vector_store_list():
  """Get the vector store types supported by this deployment.

  Returns:
      LLMGetVectorStoreTypesResponse: the configured VECTOR_STORES list
  """
  try:
    return {
        "success": True,
        "message": "Successfully retrieved vector store types",
        "data": VECTOR_STORES
    }
  except Exception as e:
    raise InternalServerError(str(e)) from e


@router.get(
    "/urls/{query_engine_id}",
    name="Get all URLs for a query engine",
    response_model=LLMQueryEngineURLResponse)
def get_urls_for_query_engine(query_engine_id: str):
  """Get all doc/web URLs indexed by a Query Engine.

  Args:
      query_engine_id (str): id of the query engine

  Returns:
      LLMQueryEngineURLResponse: list of document URLs
  """
  try:
    Logger.info(f"Get all URLs for a Query Engine={query_engine_id}")
    q_engine = QueryEngine.find_by_id(query_engine_id)
    if q_engine is None:
      raise ResourceNotFoundException(f"Engine {query_engine_id} not found")

    query_docs = QueryDocument.find_by_query_engine_id(query_engine_id)
    url_list = [query_doc.doc_url for query_doc in query_docs]
    return {
        "success": True,
        "message": "Successfully retrieved document URLs "
                   f"for query engine {query_engine_id}",
        "data": url_list
    }
  except ResourceNotFoundException as e:
    raise ResourceNotFound(str(e)) from e
  except Exception as e:
    Logger.error(e)
    # format_exc() returns the traceback string; print_exc() returns None
    # and would log the literal text "None".
    Logger.error(traceback.format_exc())
    raise InternalServerError(str(e)) from e


@router.get(
    "/engine/{query_engine_id}",
    name="Get details for query engine",
    response_model=LLMQueryEngineResponse)
def get_query_engine(query_engine_id: str):
  """Get details for a Query Engine, including its document URL list.

  Args:
      query_engine_id (str): id of the query engine

  Returns:
      LLMQueryEngineResponse: engine fields plus "url_list"
  """
  try:
    Logger.info(f"Get details for a Query Engine={query_engine_id}")
    # get engine model
    q_engine = QueryEngine.find_by_id(query_engine_id)
    if q_engine is None:
      raise ResourceNotFoundException(f"Engine {query_engine_id} not found")

    # get query docs
    query_docs = QueryDocument.find_by_query_engine_id(query_engine_id)
    url_list = [query_doc.doc_url for query_doc in query_docs]

    response_data = q_engine.get_fields(reformat_datetime=True)
    response_data["url_list"] = url_list
    return {
        "success": True,
        "message": "Successfully retrieved details "
                   f"for query engine {query_engine_id}",
        "data": response_data
    }
  except ResourceNotFoundException as e:
    raise ResourceNotFound(str(e)) from e
  except Exception as e:
    Logger.error(e)
    # was print_exc(), which returns None; format_exc() logs the real trace
    Logger.error(traceback.format_exc())
    raise InternalServerError(str(e)) from e


@router.get(
    "/user",
    name="Get all Queries for current logged-in user",
    response_model=LLMUserAllQueriesResponse)
def get_query_list(skip: int = 0,
                   limit: int = 20,
                   user_data: dict = Depends(validate_token)):
  """Get user queries for the authenticated user.

  Query data does not include history to slim the payload.  To retrieve
  query history use the get-single-query endpoint.

  Args:
      skip (int): Number of queries to be skipped
      limit (int): Size of the queries array to be returned
      user_data (dict): token claims of the authenticated user

  Returns:
      LLMUserAllQueriesResponse
  """
  try:
    user_email = user_data.get("email")
    Logger.info(f"Get all Queries for a user={user_email}")
    if skip < 0:
      raise ValidationError("Invalid value passed to \"skip\" query parameter")
    if limit < 1:
      raise ValidationError("Invalid value passed to \"limit\" query parameter")

    user = User.find_by_email(user_email)
    if user is None:
      raise ResourceNotFoundException(f"User {user_email} not found ")

    user_queries = UserQuery.find_by_user(user.id, skip=skip, limit=limit)

    query_list = []
    for user_query in user_queries:
      query_data = user_query.get_fields(reformat_datetime=True)
      query_data["id"] = user_query.id
      # don't include chat history to slim return payload
      del query_data["history"]
      query_list.append(query_data)

    Logger.info(f"Successfully retrieved {len(query_list)} user queries.")
    return {
        "success": True,
        "message": f"Successfully retrieved user queries for user {user.id}",
        "data": query_list
    }
  except ValidationError as e:
    raise BadRequest(str(e)) from e
  except ResourceNotFoundException as e:
    raise ResourceNotFound(str(e)) from e
  except Exception as e:
    raise InternalServerError(str(e)) from e


@router.get(
    "/{query_id}",
    name="Get user query",
    response_model=LLMUserQueryResponse)
def get_query(query_id: str):
  """Get a specific user query by id.

  Args:
      query_id (str): id of the user query

  Returns:
      LLMUserQueryResponse
  """
  try:
    Logger.info(f"Get a specific user query by id={query_id}")
    user_query = UserQuery.find_by_id(query_id)
    # guard against a missing query: without this check get_fields() on
    # None raised AttributeError and surfaced as a 500 instead of a 404
    if user_query is None:
      raise ResourceNotFoundException(f"Query {query_id} not found")
    query_data = user_query.get_fields(reformat_datetime=True)
    query_data["id"] = user_query.id

    Logger.info(f"Successfully retrieved user query {query_id}")
    return {
        "success": True,
        "message": f"Successfully retrieved user query {query_id}",
        "data": query_data
    }
  except ValidationError as e:
    raise BadRequest(str(e)) from e
  except ResourceNotFoundException as e:
    raise ResourceNotFound(str(e)) from e
  except Exception as e:
    raise InternalServerError(str(e)) from e
@router.put(
    "/{query_id}",
    name="Update user query"
)
def update_query(query_id: str, input_query: UserQueryUpdateModel):
  """Update a user query.

  Args:
      query_id (str): Query ID
      input_query (UserQueryUpdateModel): fields in body of query to update.
          The only field that can be updated is the title.

  Raises:
      ResourceNotFoundException: If the UserQuery does not exist
      HTTPException: 500 Internal Server Error if something fails

  Returns:
      [JSON]: {'success': 'True'} if the user query is updated,
      NotFoundErrorResponseModel if the user query not found,
      InternalServerErrorResponseModel if the update raises an exception
  """
  Logger.info(f"Update a user query by id={query_id}")
  try:
    # Lookup happens inside the try so the ResourceNotFoundException handler
    # below actually fires (it was previously dead code: the exception was
    # raised before the try block).
    existing_query = UserQuery.find_by_id(query_id)
    if existing_query is None:
      raise ResourceNotFoundException(f"Query {query_id} not found")

    input_query_dict = {**input_query.dict()}
    # apply only the fields the caller actually supplied
    for key in input_query_dict:
      if input_query_dict.get(key) is not None:
        setattr(existing_query, key, input_query_dict.get(key))
    existing_query.update()
    return {
        "success": True,
        "message": f"Successfully updated user query {query_id}",
    }
  except ResourceNotFoundException as re:
    raise ResourceNotFound(str(re)) from re
  except Exception as e:
    Logger.error(e)
    raise InternalServerError(str(e)) from e


@router.delete(
    "/{query_id}",
    name="Delete user query"
)
def delete_query(query_id: str, hard_delete: bool = True):
  """Delete a user query.  By default we do a hard delete.

  Args:
      query_id (str): Query ID
      hard_delete (bool): permanently remove the query when True,
          soft-delete otherwise

  Raises:
      ResourceNotFoundException: If the UserQuery does not exist
      HTTPException: 500 Internal Server Error if something fails

  Returns:
      [JSON]: {'success': 'True'} if the user query is deleted,
      NotFoundErrorResponseModel if the user query not found,
      InternalServerErrorResponseModel if the update raises an exception
  """
  Logger.info(f"Delete a user query by id={query_id} hard_delete={hard_delete}")
  try:
    # inside the try so the handler below converts a missing query to a 404
    existing_query = UserQuery.find_by_id(query_id)
    if existing_query is None:
      raise ResourceNotFoundException(f"Query {query_id} not found")

    if hard_delete:
      UserQuery.delete_by_id(existing_query.id)
    else:
      UserQuery.soft_delete_by_id(existing_query.id)
    return {
        "success": True,
        "message": f"Successfully deleted user query {query_id}",
    }
  except ResourceNotFoundException as re:
    raise ResourceNotFound(str(re)) from re
  except Exception as e:
    Logger.error(e)
    # format_exc() returns the traceback string; print_exc() returns None
    Logger.error(traceback.format_exc())
    raise InternalServerError(str(e)) from e


@router.put(
    "/engine/{query_engine_id}",
    name="Update a query engine")
def update_query_engine(query_engine_id: str,
                        data_config: LLMQueryEngineModel):
  """Update a query engine.  Only the description and read access group
  can be updated.

  Args:
      query_engine_id (str): id of the query engine
      data_config (LLMQueryEngineModel): fields to apply

  Returns:
      [JSON]: {'success': 'True'} if the query engine is updated,
      ResourceNotFoundException if the query engine not found,
      InternalServerErrorResponseModel if the update raises an exception
  """
  if query_engine_id is None or query_engine_id == "":
    # raise, not return: returning an exception instance fails FastAPI
    # response serialization instead of producing a 400
    raise BadRequest("Missing or invalid payload parameters: query_engine_id")

  q_engine = QueryEngine.find_by_id(query_engine_id)
  if q_engine is None:
    raise ResourceNotFoundException(f"Engine {query_engine_id} not found")

  data_dict = {**data_config.dict()}
  try:
    Logger.info(f"Updating q_engine=[{q_engine.name}]")
    q_engine.description = data_dict["description"]
    q_engine.read_access_group = data_dict["read_access_group"]
    q_engine.save()
    Logger.info(f"Successfully updated q_engine=[{q_engine.name}]")
  except Exception as e:
    Logger.error(e)
    raise InternalServerError(str(e)) from e

  return {
      "success": True,
      # fixed copy/paste bug: message previously said "deleted"
      "message": f"Successfully updated query engine {query_engine_id}",
  }


@router.delete(
    "/engine/{query_engine_id}",
    name="Delete a query engine")
def delete_query_engine(query_engine_id: str, hard_delete: bool = True):
  """Delete a query engine.  By default we do a hard delete.

  Args:
      query_engine_id (str): id of the query engine
      hard_delete (bool): permanently remove the engine when True

  Returns:
      [JSON]: {'success': 'True'} if the query engine is deleted,
      ResourceNotFoundException if the query engine not found,
      InternalServerErrorResponseModel if the deletion raises an exception
  """
  if query_engine_id is None or query_engine_id == "":
    # raise, not return: a returned exception instance is not a response
    raise BadRequest("Missing or invalid payload parameters: query_engine_id")

  q_engine = QueryEngine.find_by_id(query_engine_id)
  if q_engine is None:
    raise ResourceNotFoundException(f"Engine {query_engine_id} not found")

  try:
    Logger.info(
        f"Deleting q_engine=[{q_engine.name}] hard_delete=[{hard_delete}]")
    delete_engine(q_engine, hard_delete=hard_delete)
    Logger.info(f"Successfully deleted q_engine=[{q_engine.name}]")
  except Exception as e:
    Logger.error(e)
    raise InternalServerError(str(e)) from e

  return {
      "success": True,
      "message": f"Successfully deleted query engine {query_engine_id}",
  }


@router.post(
    "/engine",
    name="Create a query engine",
    response_model=BatchJobModel)
async def query_engine_create(gen_config: LLMQueryEngineModel,
                              user_data: dict = Depends(validate_token)):
  """Start a query engine build job.

  Args:
      gen_config (LLMQueryEngineModel): build configuration
      user_data (dict): token claims of the authenticated user

  Returns:
      BatchJobModel: metadata for the launched build job
  """
  genconfig_dict = {**gen_config.dict()}
  Logger.info(f"Create a query engine with {genconfig_dict}")

  doc_url = genconfig_dict.get("doc_url")
  documents = genconfig_dict.get("documents")
  query_engine_type = genconfig_dict.get("query_engine_type", None)

  # integrated-search engines aggregate other engines and need no doc_url
  if query_engine_type != QE_TYPE_INTEGRATED_SEARCH:
    if documents:
      # inline base64 documents are staged to GCS first
      doc_url = f"gs://{(await upload_b64files_to_gcs(documents)).name}"

    # validate doc_url; raise (not return) so FastAPI emits a 400
    if doc_url is None or doc_url == "":
      raise BadRequest("Missing or invalid payload parameters: doc_url")
    if not doc_url.startswith(
        ("gs://", "http://", "https://", "bq://", "shpt://")):
      raise BadRequest(
          "doc_url must start with gs://, http:// or https://, bq://, shpt://"
          f" instead got {doc_url}")
    if doc_url.endswith(".pdf"):
      raise BadRequest(
          "doc_url must point to a GCS bucket/folder or website, "
          "not a document")

  query_engine = genconfig_dict.get("query_engine")
  if query_engine is None or query_engine == "":
    raise BadRequest("Missing or invalid payload parameters: query_engine")

  q_engine = QueryEngine.find_by_name(query_engine)
  if q_engine:
    raise BadRequest(f"Query engine already exists: {query_engine}")

  user_id = user_data.get("user_id")
  params = genconfig_dict.get("params", {})

  try:
    data = {
        "doc_url": doc_url,
        "query_engine": query_engine,
        "user_id": user_id,
        "query_engine_type": query_engine_type,
        "llm_type": genconfig_dict.get("llm_type", None),
        "embedding_type": genconfig_dict.get("embedding_type", None),
        "vector_store": genconfig_dict.get("vector_store", None),
        "description": genconfig_dict.get("description", None),
        "params": params,
    }
    env_vars = {
        "DATABASE_PREFIX": DATABASE_PREFIX,
        "PROJECT_ID": PROJECT_ID,
        "ENABLE_OPENAI_LLM": str(ENABLE_OPENAI_LLM),
        "ENABLE_COHERE_LLM": str(ENABLE_COHERE_LLM),
        "DEFAULT_VECTOR_STORE": str(DEFAULT_VECTOR_STORE),
        "PG_HOST": PG_HOST,
        "ONEDRIVE_CLIENT_ID": ONEDRIVE_CLIENT_ID,
        "ONEDRIVE_TENANT_ID": ONEDRIVE_TENANT_ID,
    }
    response = initiate_batch_job(data, JOB_TYPE_QUERY_ENGINE_BUILD, env_vars)
    Logger.info(f"Batch job response: {response}")
    return response
  except Exception as e:
    Logger.error(e)
    Logger.error(traceback.format_exc())
    raise InternalServerError(str(e)) from e


@router.post(
    "/engine/{query_engine_id}",
    name="Make a query to a query engine",
    response_model=LLMQueryResponse)
async def query(query_engine_id: str,
                gen_config: LLMQueryModel,
                user_data: dict = Depends(validate_token)):
  """Send a query to a query engine and return the response.

  Args:
      query_engine_id (str): id of the query engine to query
      gen_config (LLMQueryModel): prompt and query options
      user_data (dict): token claims of the authenticated user

  Returns:
      LLMQueryResponse
  """
  Logger.info(f"Using query engine with "
              f"query_engine_id=[{query_engine_id}] and {gen_config}")
  q_engine = QueryEngine.find_by_id(query_engine_id)
  if q_engine is None:
    raise ResourceNotFoundException(f"Engine {query_engine_id} not found")

  genconfig_dict = {**gen_config.dict()}
  prompt = genconfig_dict.get("prompt")
  if prompt is None or prompt == "":
    # raise, not return: a returned exception instance is not a response
    raise BadRequest("Missing or invalid payload parameters")
  if len(prompt) > PAYLOAD_FILE_SIZE:
    raise PayloadTooLargeError(
        f"Prompt must be less than {PAYLOAD_FILE_SIZE}")

  llm_type = genconfig_dict.get("llm_type")
  rank_sentences = genconfig_dict.get("rank_sentences", False)
  Logger.info(f"rank_sentences = {rank_sentences}")
  query_filter = genconfig_dict.get("query_filter")
  Logger.info(f"query_filter = {query_filter}")

  # get the User GENIE stores
  user = User.find_by_email(user_data.get("email"))

  run_as_batch_job = genconfig_dict.get("run_as_batch_job", False)
  Logger.info(f"run_as_batch_job = {run_as_batch_job}")

  user_query = None
  if run_as_batch_job:
    # create user query object to hold the query state
    # NOTE(review): user.user_id here vs user.id below — looks inconsistent;
    # confirm which field the User model exposes
    user_query = UserQuery(user_id=user.user_id,
                           prompt=prompt,
                           query_engine_id=q_engine.id)
    user_query.save()
    user_query.update_history(prompt=prompt)
    query_data = user_query.get_fields(reformat_datetime=True)
    query_data["id"] = user_query.id

    # launch batch job to perform query
    try:
      data = {
          "query_engine_id": query_engine_id,
          "prompt": prompt,
          "llm_type": llm_type,
          "user_id": user.id,
          "user_query_id": user_query.id,
          "rank_sentences": rank_sentences,
          "query_filter": query_filter
      }
      env_vars = {
          "DATABASE_PREFIX": DATABASE_PREFIX,
          "PROJECT_ID": PROJECT_ID,
          "ENABLE_OPENAI_LLM": str(ENABLE_OPENAI_LLM),
          "ENABLE_COHERE_LLM": str(ENABLE_COHERE_LLM),
          "DEFAULT_VECTOR_STORE": str(DEFAULT_VECTOR_STORE),
          "PG_HOST": PG_HOST,
      }
      response = initiate_batch_job(data, JOB_TYPE_QUERY_EXECUTE, env_vars)
      Logger.info(f"Batch job response: {response}")
      return {
          "success": True,
          "message": "Successfully ran query in batch mode",
          "data": {
              "query": query_data,
              "batch_job": response["data"],
          },
      }
    except Exception as e:
      Logger.error(e)
      Logger.error(traceback.format_exc())
      raise InternalServerError(str(e)) from e

  # perform normal synchronous query
  try:
    query_result, query_references = await query_generate(user.id,
                                                          prompt,
                                                          q_engine,
                                                          user_data,
                                                          llm_type,
                                                          user_query,
                                                          rank_sentences,
                                                          query_filter)
    Logger.info(f"Query response="
                f"[{query_result.response}]")

    # save user query history
    user_query, query_reference_dicts = \
        update_user_query(prompt, query_result.response, user.id,
                          q_engine, query_references, None, query_filter)

    query_result_dict = query_result.get_fields(reformat_datetime=True)
    return {
        "success": True,
        "message": "Successfully generated text",
        "data": {
            "user_query_id": user_query.id,
            "query_result": query_result_dict,
            "query_references": query_reference_dicts
        }
    }
  except Exception as e:
    Logger.error(e)
    Logger.error(traceback.format_exc())
    raise InternalServerError(str(e)) from e


@router.post(
    "/{user_query_id}",
    name="Continue chat with a prior user query",
    response_model=LLMQueryResponse)
async def query_continue(
    user_query_id: str,
    gen_config: LLMQueryModel,
    user_data: dict = Depends(validate_token)):
  """Continue a prior user query.  Perform a new search and add those
  references along with prior query/chat history as context.

  Args:
      user_query_id (str): id of previous user query
      gen_config (LLMQueryModel): prompt and query options
      user_data (dict): token claims of the authenticated user

  Returns:
      LLMQueryResponse
  """
  Logger.info("Using query engine based on a prior user query "
              f"user_query_id={user_query_id}, gen_config={gen_config}")
  user_query = UserQuery.find_by_id(user_query_id)
  if user_query is None:
    raise ResourceNotFoundException(f"Query {user_query_id} not found")

  genconfig_dict = {**gen_config.dict()}
  prompt = genconfig_dict.get("prompt")
  if prompt is None or prompt == "":
    # raise, not return: a returned exception instance is not a response
    raise BadRequest("Missing or invalid payload parameters")
  if len(prompt) > PAYLOAD_FILE_SIZE:
    raise PayloadTooLargeError(
        f"Prompt must be less than {PAYLOAD_FILE_SIZE}")

  llm_type = genconfig_dict.get("llm_type")
  rank_sentences = genconfig_dict.get("rank_sentences", False)
  Logger.info(f"rank_sentences = {rank_sentences}")
  query_filter = genconfig_dict.get("query_filter")
  Logger.info(f"query_filter = {query_filter}")

  q_engine = QueryEngine.find_by_id(user_query.query_engine_id)

  run_as_batch_job = genconfig_dict.get("run_as_batch_job", False)
  if run_as_batch_job:
    # launch batch job to perform query
    try:
      data = {
          "query_engine_id": q_engine.id,
          "prompt": prompt,
          "llm_type": llm_type,
          "user_id": user_query.user_id,
          "user_query_id": user_query.id,
          "rank_sentences": rank_sentences,
          # pass query_filter for parity with the query endpoint's batch job
          "query_filter": query_filter
      }
      env_vars = {
          "DATABASE_PREFIX": DATABASE_PREFIX,
          "PROJECT_ID": PROJECT_ID,
          "ENABLE_OPENAI_LLM": str(ENABLE_OPENAI_LLM),
          "ENABLE_COHERE_LLM": str(ENABLE_COHERE_LLM),
          "DEFAULT_VECTOR_STORE": str(DEFAULT_VECTOR_STORE),
          "PG_HOST": PG_HOST,
      }
      response = initiate_batch_job(data, JOB_TYPE_QUERY_EXECUTE, env_vars)
      Logger.info(f"Batch job response: {response}")
      query_data = user_query.get_fields(reformat_datetime=True)
      query_data["id"] = user_query.id
      return {
          "success": True,
          "message": "Successfully ran query in batch mode",
          "data": {
              "query": query_data,
              "batch_job": response["data"],
          },
      }
    except Exception as e:
      Logger.error(e)
      Logger.error(traceback.format_exc())
      raise InternalServerError(str(e)) from e

  # perform normal synchronous query
  try:
    query_result, query_references = await query_generate(user_query.user_id,
                                                          prompt,
                                                          q_engine,
                                                          user_data,
                                                          llm_type,
                                                          user_query,
                                                          rank_sentences,
                                                          query_filter)
    # save user query history
    # TODO(review): unlike query(), this call passes neither the existing
    # user_query nor query_filter to update_user_query — confirm whether the
    # continuation should update the existing query object
    _, query_reference_dicts = \
        update_user_query(prompt, query_result.response,
                          user_query.user_id, q_engine, query_references)
    Logger.info(f"Generated query response="
                f"[{query_result.response}], "
                f"query_result={query_result} "
                f"query_references={[repr(qe) for qe in query_references]}")
    query_result_dict = query_result.get_fields(reformat_datetime=True)
    return {
        "success": True,
        "message": "Successfully generated text",
        "data": {
            "user_query_id": user_query.id,
            "query_result": query_result_dict,
            "query_references": query_reference_dicts
        }
    }
  except Exception as e:
    Logger.error(e)
    Logger.error(traceback.format_exc())
    raise InternalServerError(str(e)) from e