# hacks/genai-intro/artifacts/function/main.py
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import json
import os
from itertools import islice
from typing import Iterator
import functions_framework
import vertexai
from google.cloud import bigquery
from google.cloud import storage
from google.cloud import vision
from vertexai.generative_models import GenerativeModel
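
# PROJECT_ID and REGION are read from environment variables that are assumed
# to be set on the deployed Cloud Function; the remaining settings are derived
# from them or fixed for this hack.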
PROJECT_ID = os.getenv("GCP_PROJECT_ID")
REGION = os.getenv("GCP_REGION")
STAGING_BUCKET = f"{PROJECT_ID}-staging"
BQ_DATASET = "articles"
BQ_TABLE = "summaries"
MODEL_NAME = "gemini-2.0-flash"

vertexai.init(project=PROJECT_ID, location=REGION)


def extract_text_from_document(src_bucket: str, file_name: str, dst_bucket: str) -> str:
"""Extracts the contents of the PDF document and stores the results in a folder in GCS.
    To extract the contents of the PDF document, OCR is applied and the results,
    consisting of JSON files, are stored in the destination bucket in a folder that
    has the same name as the source file.
Do not edit.
Args:
        src_bucket: source bucket without the gs:// prefix, e.g. my-uploaded-docs-bucket
        file_name: source file name, e.g. my-file.pdf
        dst_bucket: destination bucket without the gs:// prefix, e.g. my-staging-bucket
Returns:
destination folder, name of the folder in the staging bucket where the JSON
files are stored for the PDF document
"""
src_uri = f"gs://{src_bucket}/{file_name}"
dst_uri = f"gs://{dst_bucket}/{file_name}/"
mime_type = "application/pdf"
batch_size = 2
# Perform Vision OCR
client = vision.ImageAnnotatorClient()
feature = vision.Feature(type_=vision.Feature.Type.DOCUMENT_TEXT_DETECTION)
gcs_source = vision.GcsSource(uri=src_uri)
input_config = vision.InputConfig(gcs_source=gcs_source, mime_type=mime_type)
gcs_destination = vision.GcsDestination(uri=dst_uri)
output_config = vision.OutputConfig(
gcs_destination=gcs_destination, batch_size=batch_size
)
async_request = vision.AsyncAnnotateFileRequest(
features=[feature], input_config=input_config, output_config=output_config
)
operation = client.async_batch_annotate_files(requests=[async_request])
operation.result(timeout=420)
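    # The operation writes sharded JSON results under dst_uri, with batch_size
    # pages per shard (files typically named like output-1-to-2.json); these
    # shards are read back by collate_pages.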
return f"{file_name}/"


def collate_pages(bucket: str, folder: str) -> str:
"""Collates all pages, stored as JSON files in the provided bucket & folder,
parses them, extracts the relevant parts and concatenates them into a single string.
Do not edit.
Args:
        bucket: bucket without the gs:// prefix, e.g. my-staging-bucket
folder: folder name, e.g. my-file/
Returns:
        complete text of the PDF document as a single plain-text string
"""
storage_client = storage.Client(project=PROJECT_ID)
    gcs_bucket = storage_client.get_bucket(bucket)
    blob_list = list(gcs_bucket.list_blobs(prefix=folder))
complete_text = ""
for output in blob_list:
json_string = output.download_as_bytes().decode("utf-8")
response = json.loads(json_string)
for page in response["responses"]:
complete_text += page["fullTextAnnotation"]["text"]
return complete_text


def get_prompt_for_title_extraction() -> str:
"""Returns a prompt for title extraction.
To be modified for Challenge 2.
Returns:
prompt for title extraction, with placeholders for substitution
"""
# TODO Challenge 2, provide the prompt, you can use {} references for substitution
# See https://www.w3schools.com/python/ref_string_format.asp
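    # An illustrative template shape (an assumption, not the required answer):
    #   "Extract the exact title of the document below.\n"
    #   "Return only the title and nothing else.\n\n"
    #   "Document:\n{text}"
    # where {text} is substituted by the caller via str.format.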
return ""


def extract_title_from_text(text: str) -> str:
"""Given the full text of the PDF document, extracts the title.
To be modified for Challenge 2.
Args:
text: full text of the PDF document
Returns:
title of the PDF document
"""
model = GenerativeModel(MODEL_NAME)
prompt_template = get_prompt_for_title_extraction()
prompt = prompt_template.format() # TODO Challenge 2, set placeholder values in format
if not prompt:
return "" # return empty title for empty prompt
if model.count_tokens(prompt).total_tokens > 2500:
raise ValueError("Too many tokens used")
response = model.generate_content(prompt)
return response.text


def pages(text: str, batch_size: int) -> Iterator[str]:
    """Yield successive batch_size-sized chunks from text.
Do not edit.
Args:
text: full text of the PDF document
batch_size: size of the chunks
Yields:
        successive batch_size-sized chunks of text
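    Example:
        list(pages("abcdef", 4)) == ["abcd", "ef"]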
"""
it = iter(text)
while batch := tuple(islice(it, batch_size)):
yield "".join(batch)


def get_prompt_for_page_summary_with_context() -> str:
"""Returns a prompt for the page summary with context.
To be modified for Challenge 3.
Returns:
prompt for page summary with context, with placeholders for substitution
"""
# TODO Challenge 3, provide the prompt, you can use {} references for substitution
# See https://www.w3schools.com/python/ref_string_format.asp
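    # An illustrative rolling-summary shape (an assumption, not the required
    # answer) folds the summary so far into each per-page request:
    #   "Summary so far:\n{summary}\n\nNext part of the document:\n{page}\n\n"
    #   "Write an updated summary covering everything read so far."
    # where {summary} and {page} are substituted by the caller via str.format.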
return ""


def extract_summary_from_text(text: str) -> str:
"""Given the full text of the PDF document, extracts the summary.
To be modified for Challenge 3.
Args:
text: full text of the PDF document
Returns:
summary of the PDF document
"""
model = GenerativeModel(MODEL_NAME)
rolling_prompt_template = get_prompt_for_page_summary_with_context()
if not rolling_prompt_template:
return "" # return empty summary for empty prompts
summary = ""
for page in pages(text, 16000):
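        # The placeholders are presumably the current page and the running summary
        # (e.g. {page} and {summary}; names assumed here, to be set in Challenge 3).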
prompt = rolling_prompt_template.format() # TODO Challenge 3, set placeholder values in format
summary = model.generate_content(prompt).text
return summary


def store_results_in_bq(dataset: str, table: str, columns: dict[str, str]) -> bool:
"""Stores the results of title extraction and summary generation in BigQuery.
Do not edit.
Args:
dataset: the name of the BigQuery dataset
table: the name of the BigQuery table where the results will be stored
columns: a dictionary of columns and their values
Returns:
True if successful, False otherwise (if there were any errors)
"""
client = bigquery.Client(project=PROJECT_ID)
table_uri = f"{dataset}.{table}"
rows_to_insert = [columns]
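    # insert_rows_json streams the rows via BigQuery's insertAll API; the
    # generated UUID row ids give best-effort de-duplication on retries.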
errors = client.insert_rows_json(
table_uri, rows_to_insert, row_ids=bigquery.AutoRowIDs.GENERATE_UUID
)
if errors:
print("Errors while storing data in BQ:", errors)
return not errors


@functions_framework.cloud_event
def on_document_added(event):
"""Triggered from a message on a Cloud Pub/Sub topic.
Do not edit until Challenge 4.
Args:
        event: CloudEvent whose data payload carries the Pub/Sub message.
"""
pubsub_message = json.loads(base64.b64decode(event.data["message"]["data"]).decode("utf-8"))
src_bucket = pubsub_message["bucket"]
src_fname = pubsub_message["name"]
print(f"Processing file: {src_fname}")
if pubsub_message["contentType"] != "application/pdf":
raise ValueError("Only PDF files are supported, aborting")
dst_bucket = STAGING_BUCKET
dst_folder = extract_text_from_document(src_bucket, src_fname, dst_bucket)
print("Completed the text extraction")
complete_text = collate_pages(dst_bucket, dst_folder)
print(f"Completed collation, #characters: {len(complete_text)}")
title = extract_title_from_text(complete_text)
print(f"Title: {title}")
summary = extract_summary_from_text(complete_text)
print(f"Summary: {summary}")
# TODO Challenge 4, uncomment the next two lines
# columns = {"uri": f"gs://{src_bucket}/{src_fname}", "name": src_fname, "title": title, "summary": summary}
# store_results_in_bq(dataset=BQ_DATASET, table=BQ_TABLE, columns=columns)