webhook/main.py
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import os
from typing import Mapping, Optional

import flask
import google.auth.transport.requests
import google.oauth2.id_token
import requests
from google.cloud import logging

from bigquery import write_summarization_to_table
from document_extract import async_document_extract
from storage import upload_to_gcs
from utils import coerce_datetime_zulu, truncate_complete_text
from vertex_llm import predict_large_language_model

_FUNCTIONS_VERTEX_EVENT_LOGGER = 'summarization-by-llm'
_PROJECT_ID = os.environ["PROJECT_ID"]
_OUTPUT_BUCKET = os.environ["OUTPUT_BUCKET"]
_LOCATION = os.environ["LOCATION"]
_MODEL_NAME = "text-bison@001"
_DEFAULT_PARAMETERS = {
"temperature": 0.2,
"max_output_tokens": 256,
"top_p": 0.95,
"top_k": 40,
}
_DATASET_ID = os.environ["DATASET_ID"]
_TABLE_ID = os.environ["TABLE_ID"]


def default_marshaller(o: object) -> str:
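    """Serialize a value to a string, using ISO-8601 format for dates and datetimes."""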
if isinstance(o, (datetime.date, datetime.datetime)):
return o.isoformat()
    return str(o)


def redirect_and_reply(previous_data) -> flask.Response:
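    """Re-invoke this Cloud Function over HTTPS with the storage event fields.

    Fetches an ID token so the POST to the function's own endpoint is
    authenticated, forwards the object's name, id, bucket, and timeCreated,
    and replies 200 unless the POST fails outright.
    """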
endpoint = f'https://{_LOCATION}-{_PROJECT_ID}.cloudfunctions.net/{os.environ["K_SERVICE"]}'
logging_client = logging.Client()
logger = logging_client.logger(_FUNCTIONS_VERTEX_EVENT_LOGGER)
auth_req = google.auth.transport.requests.Request()
id_token = google.oauth2.id_token.fetch_id_token(auth_req, endpoint)
data = {
'name': previous_data["name"],
'id': previous_data["id"],
'bucket': previous_data["bucket"],
'timeCreated': previous_data["timeCreated"],
}
headers = {}
headers["Authorization"] = f"Bearer {id_token}"
logger.log(f'TRIGGERING JOB FLOW: {endpoint}')
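    # Fire-and-forget: the 1-second timeout is expected to expire, so a Timeout
    # is treated as success; the redirected invocation carries out the real work.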
try:
requests.post(
endpoint,
json=data,
timeout=1,
headers=headers,
)
except requests.exceptions.Timeout:
return flask.Response(status=200)
except Exception:
return flask.Response(status=500)
    return flask.Response(status=200)


def entrypoint(request: flask.Request) -> Mapping[str, str]:
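    """Dispatch an incoming request based on the shape of its JSON payload.

    - Cloud Storage object payloads (kind == "storage#object") are redirected
      back to this function via redirect_and_reply.
    - Payloads with a 'bucket' field run OCR plus summarization.
    - Payloads with a 'text' field skip OCR and summarize the given text.
    """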
data = request.get_json()
if data.get("kind", None) == "storage#object":
        # Entrypoint called by Pub/Sub (Eventarc)
return redirect_and_reply(data)
if 'bucket' in data:
        # Entrypoint called by REST (possibly by redirect_and_reply)
return cloud_event_entrypoint(
name=data["name"],
event_id=data["id"],
bucket=data["bucket"],
time_created=coerce_datetime_zulu(data["timeCreated"]),
)
if "text" in data:
# Entrypoint called by REST.
return summarization_entrypoint(
name=data["name"],
extracted_text=data["text"],
time_created=datetime.datetime.now(datetime.timezone.utc),
event_id="CURL_TRIGGER",
)
    return flask.Response(status=500)


def cloud_event_entrypoint(
    event_id: str, bucket: str, name: str, time_created: datetime.datetime
) -> Mapping[str, str]:
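    """Run OCR on the uploaded PDF, then hand the extracted text to the summarizer."""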
orig_pdf_uri = f"gs://{bucket}/{name}"
logging_client = logging.Client()
logger = logging_client.logger(_FUNCTIONS_VERTEX_EVENT_LOGGER)
logger.log(f"cloud_event_id({event_id}): UPLOAD {orig_pdf_uri}",
severity="INFO")
extracted_text = async_document_extract(bucket, name, output_bucket=_OUTPUT_BUCKET)
logger.log(
f"cloud_event_id({event_id}): OCR gs://{bucket}/{name}", severity="INFO"
)
return summarization_entrypoint(
name,
extracted_text,
time_created=time_created,
event_id=event_id,
bucket=bucket,
    )


def summarization_entrypoint(
    name: str,
    extracted_text: str,
    time_created: datetime.datetime,
    bucket: Optional[str] = None,
    event_id: Optional[str] = None,
) -> Mapping[str, str]:
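    """Summarize the extracted text with the LLM and persist the results.

    Uploads the full text and the summary to Cloud Storage, then records the
    summarization metadata in BigQuery.
    """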
logging_client = logging.Client()
logger = logging_client.logger(_FUNCTIONS_VERTEX_EVENT_LOGGER)
    if not extracted_text:
        logger.log(
            f"cloud_event_id({event_id}): BAD INPUT: no characters were "
            "recognized in the PDF, so it cannot be summarized. Be sure to "
            "upload a high-quality PDF that contains 'Abstract' and "
            "'Conclusion' sections.",
            severity="ERROR",
        )
        return ""
complete_text_filename = f'summaries/{name.replace(".pdf", "")}_fulltext.txt'
upload_to_gcs(
_OUTPUT_BUCKET,
complete_text_filename,
extracted_text,
)
logger.log(
f"cloud_event_id({event_id}): FULLTEXT_UPLOAD {complete_text_filename}",
severity="INFO",
)
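    # Build the prompt: a short instruction followed by the truncated document text.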
prompt = 'Summarize:'
extracted_text_trunc = truncate_complete_text(extracted_text, _FUNCTIONS_VERTEX_EVENT_LOGGER)
summary = predict_large_language_model(
project_id=_PROJECT_ID,
model_name=_MODEL_NAME,
temperature=0.2,
max_decode_steps=1024,
top_p=0.8,
top_k=40,
content=f"{prompt}\n{extracted_text_trunc}",
location="us-central1",
)
logger.log(f"cloud_event_id({event_id}): SUMMARY_COMPLETE", severity="INFO")
output_filename = f'system-test/{name.replace(".pdf", "")}_summary.txt'
upload_to_gcs(
_OUTPUT_BUCKET,
output_filename,
summary,
)
logger.log(
f"cloud_event_id({event_id}): SUMMARY_UPLOAD {upload_to_gcs}", severity="INFO"
)
    # Insert errors, if any, are caught in the bigquery module and returned here.
errors = write_summarization_to_table(
project_id=_PROJECT_ID,
dataset_id=_DATASET_ID,
table_id=_TABLE_ID,
bucket=bucket,
filename=output_filename,
complete_text=extracted_text,
complete_text_uri=complete_text_filename,
summary=summary,
summary_uri=output_filename,
timestamp=time_created,
)
if len(errors) > 0:
logger.log(
f"cloud_event_id({event_id}): DB_WRITE_ERROR: {errors}", severity="ERROR"
)
return {"errors": errors}
logger.log(f"cloud_event_id({event_id}): DB_WRITE", severity="INFO")
return {"summary": summary}