webhook/main.py
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import json
import logging
import multiprocessing
import os
from collections.abc import Iterable
from datetime import datetime
import functions_framework
from cloudevents.http import CloudEvent
from google import genai # type: ignore
from google.genai.types import GenerateContentConfig # type: ignore
from google.api_core.client_options import ClientOptions
from google.api_core.retry import Retry
from google.cloud import aiplatform
from google.cloud import documentai
from google.cloud import firestore # type: ignore
from google.cloud import storage # type: ignore
from google.cloud.aiplatform_v1.types import IndexDatapoint
DOCAI_LOCATION = os.environ.get("DOCAI_LOCATION", "us")
@functions_framework.cloud_event
def on_cloud_event(event: CloudEvent) -> None:
"""Process a new document from an Eventarc event.
Args:
event: CloudEvent object.
"""
try:
process_document(
event_id=event.data["id"],
input_bucket=event.data["bucket"],
filename=event.data["name"],
mime_type=event.data["contentType"],
time_uploaded=datetime.fromisoformat(event.data["timeCreated"]),
project=os.environ["PROJECT_ID"],
location=os.environ["LOCATION"],
docai_processor_id=os.environ["DOCAI_PROCESSOR"],
database=os.environ["DATABASE"],
output_bucket=os.environ["OUTPUT_BUCKET"],
index_id=os.environ["INDEX_ID"],
)
except Exception as e:
logging.exception(e, stack_info=True)
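# Illustrative sketch of the inputs on_cloud_event expects (not executed code).
# The handler reads PROJECT_ID, LOCATION, DOCAI_PROCESSOR, DATABASE,
# OUTPUT_BUCKET, and INDEX_ID from the environment, and is triggered by a
# Cloud Storage object event delivered through Eventarc. The bucket and file
# names below are made up:
#
#   event.data == {
#       "id": "my-input-bucket/documents/report.pdf/1700000000000000",
#       "bucket": "my-input-bucket",
#       "name": "documents/report.pdf",
#       "contentType": "application/pdf",
#       "timeCreated": "2024-01-01T12:00:00.000Z",
#   }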
def process_document(
event_id: str,
input_bucket: str,
filename: str,
mime_type: str,
time_uploaded: datetime,
project: str,
location: str,
docai_processor_id: str,
database: str,
output_bucket: str,
index_id: str,
) -> None:
"""Process a new document.
Args:
event_id: ID of the event.
input_bucket: Name of the input bucket.
filename: Name of the input file.
mime_type: MIME type of the input file.
time_uploaded: Time the input file was uploaded.
project: Google Cloud project ID.
location: Google Cloud location.
docai_processor_id: ID of the Document AI processor.
database: Name of the Firestore database.
output_bucket: Name of the output bucket.
index_id: ID of the Vector Search index.
"""
aiplatform.init(project=project, location=location)
db = firestore.Client(project=project, database=database)
doc = db.document("documents", filename.replace("/", "-"))
event_entry = {
"event_id": event_id,
"bucket": input_bucket,
"filename": filename,
"mime_type": mime_type,
"time_uploaded": time_uploaded,
}
    doc_snapshot = doc.get()
    entry = doc_snapshot.to_dict() or {}
    if entry.get("event_id") == event_id:
        # We've already processed this event; this is probably an event retry.
        return
    if doc_snapshot.exists:
        doc.update(event_entry)
    else:
        doc.create(event_entry)
input_gcs_uri = f"gs://{input_bucket}/{filename}"
print(f"📖 {event_id}: Getting document text")
pages = list(get_document_text(input_gcs_uri, mime_type, docai_processor_id, output_bucket))
doc.update({"pages": pages})
print(f"🗂️ {event_id}: Indexing pages into Vector Search")
embeddings = get_pages_embeddings(project, location, pages)
index_pages(index_id, filename, embeddings)
print(f"🔍 {event_id}: Generating Q&As with model ({len(pages)} pages)")
with multiprocessing.Pool(len(pages)) as pool:
event_pages = [
{
"project": project,
"location": location,
"filename": filename,
"page_number": i,
"text": page,
}
for i, page in enumerate(pages)
]
page_entries = pool.map(process_page, event_pages)
document_entries = list(itertools.chain.from_iterable(page_entries))
print(f"🗃️ {event_id}: Saving Q&As to Firestore ({len(document_entries)} entries)")
for entry in document_entries:
doc = db.document("dataset", entry["question"].replace("/", " "))
if doc.get().exists:
doc.update(entry)
else:
doc.create(entry)
print(f"📝 {event_id}: Writing tuning dataset: gs://{output_bucket}/dataset.jsonl")
dataset_size = write_tuning_dataset(db, output_bucket)
print(f"✅ {event_id}: Done! {dataset_size=}")
def process_page(event_page: dict) -> list[dict[str, str]]:
"""Generate questions and answers for a single page of a document.
Args:
        event_page: Dictionary containing the page information for the event.
    Returns: A list of dictionaries containing the questions and answers.
"""
project = event_page["project"]
location = event_page["location"]
filename = event_page["filename"]
page_number = event_page["page_number"]
text = event_page["text"]
entries = generate_questions(project, location, text)
try:
return [
{
"question": entry["question"],
"answer": entry["answer"],
"filename": filename,
"page_number": page_number,
}
for entry in entries
]
except KeyError:
logging.exception(f"Q&A generation failed: {entries}", stack_info=True)
return []
def get_document_text(
input_file: str,
mime_type: str,
processor_id: str,
temp_bucket: str,
) -> Iterable[str]:
"""Perform Optical Character Recognition (OCR) with Document AI on a Cloud Storage files.
For more information, see:
https://cloud.google.com/document-ai/docs/process-documents-ocr
Args:
input_file: GCS URI of the document file.
mime_type: MIME type of the document file.
processor_id: ID of the Document AI processor.
temp_bucket: GCS bucket to store Document AI temporary files.
    Yields: The text of each page of the document.
"""
# You must set the `api_endpoint` if you use a location other than "us".
documentai_client = documentai.DocumentProcessorServiceClient(
client_options=ClientOptions(api_endpoint=f"{DOCAI_LOCATION}-documentai.googleapis.com")
)
# We're using batch_process_documents instead of process_document because
# process_document has a quota limit of 15 pages per document, while
# batch_process_documents has a quota limit of 500 pages per request.
# https://cloud.google.com/document-ai/quotas#general_processors
operation = documentai_client.batch_process_documents(
request=documentai.BatchProcessRequest(
name=processor_id,
input_documents=documentai.BatchDocumentsInputConfig(
gcs_documents=documentai.GcsDocuments(
documents=[
documentai.GcsDocument(
gcs_uri=input_file,
mime_type=mime_type,
),
],
),
),
document_output_config=documentai.DocumentOutputConfig(
gcs_output_config=documentai.DocumentOutputConfig.GcsOutputConfig(
gcs_uri=f"gs://{temp_bucket}/ocr/{input_file.split('gs://')[-1]}",
),
),
),
)
operation.result()
# Read the results of the Document AI operation from Cloud Storage.
storage_client = storage.Client()
metadata = documentai.BatchProcessMetadata(operation.metadata)
output_gcs_path = metadata.individual_process_statuses[0].output_gcs_destination
(output_bucket, output_prefix) = output_gcs_path.removeprefix("gs://").split("/", 1)
for blob in storage_client.list_blobs(output_bucket, prefix=output_prefix):
blob_contents = blob.download_as_bytes()
document = documentai.Document.from_json(blob_contents, ignore_unknown_fields=True)
for page in document.pages:
segments = [
(segment.start_index, segment.end_index)
for segment in page.layout.text_anchor.text_segments
]
yield "\n".join([document.text[start:end] for (start, end) in segments])
def get_pages_embeddings(
project: str,
location: str,
pages: Iterable[str],
) -> Iterable[list[float]]:
"""Get embeddings for a list of pages.
For more information, see:
https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/get-text-embeddings
Args:
project: Google Cloud project ID.
location: Google Cloud location.
        pages: A list of the text in each page of the document.
    Yields: The embedding vector for each page of the document.
    """
genai_client = genai.Client(vertexai=True, project=project, location=location)
max_input_texts = 5
for batch in itertools.batched(pages, max_input_texts):
response = genai_client.models.embed_content(
model="text-embedding-005",
contents=batch,
)
embeddings = response.embeddings or []
for embedding in embeddings:
yield embedding.values or []
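# Illustrative usage of get_pages_embeddings (hypothetical project and pages):
#
#   vectors = list(get_pages_embeddings(
#       project="my-project",
#       location="us-central1",
#       pages=["Text of page one.", "Text of page two."],
#   ))
#   # vectors is a list with one embedding (a list of floats) per page.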
def index_pages(
index_id: str,
filename: str,
embeddings: Iterable[list[float]],
) -> None:
"""Index pages into Vertex AI's Vector Search.
Args:
index_id: ID of the Vector Search index.
filename: Name of the input file.
embeddings: A list of embeddings for each page of the document.
"""
points = [
IndexDatapoint(
datapoint_id=f"{filename}:{page_number}",
feature_vector=embedding,
)
for page_number, embedding in enumerate(embeddings)
]
index = aiplatform.MatchingEngineIndex(index_id)
index.remove_datapoints(["null"])
index.upsert_datapoints(points).wait()
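# Note: each datapoint is keyed as "<filename>:<page_number>", for example
# "documents/report.pdf:0" for the first page of a hypothetical file, so a
# nearest-neighbor match can be mapped back to a specific page of a document.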
@Retry(lambda _: True)  # Retry on any exception, since model outputs are non-deterministic.
def generate_questions(project: str, location: str, text: str) -> list[dict[str, str]]:
"""Extract questions & answers using a large language model (LLM).
For more information, see:
https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models
Args:
project: Google Cloud project ID.
location: Google Cloud location.
        text: The text to generate questions and answers for.
    Returns: A list of dictionaries with "question" and "answer" keys.
"""
# Ask the model to generate the questions and answers.
genai_client = genai.Client(vertexai=True, project=project, location=location)
    response = genai_client.models.generate_content(
        model="gemini-2.0-flash",
        contents=f"List 20 self-contained questions and answers that can be answered from the following text:\n\n{text}",
config=GenerateContentConfig(
# https://cloud.google.com/vertex-ai/generative-ai/docs/learn/prompts/system-instructions
system_instruction=[
"Use simple language and words that are easy to understand.",
"Avoid technical terms in the answers.",
],
# https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/control-generated-output
response_mime_type="application/json",
response_schema={
"type": "ARRAY",
"items": {
"type": "OBJECT",
"properties": {
"question": {"type": "STRING"},
"answer": {"type": "STRING"},
},
"required": ["question", "answer"],
},
},
),
)
    response_text = response.text or ""
    # The response is sometimes wrapped in a code block, so we need to extract it.
    code_block_start = response_text.find("```")
    if code_block_start == -1:
        code_block = response_text
    else:
        code_block = "\n".join(response_text[code_block_start:].splitlines()[1:-1])
# Parse the response as JSON.
try:
return json.loads(code_block)
except json.decoder.JSONDecodeError:
logging.debug(f"Failed to parse response:\n{response}")
raise
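# Illustrative shape of a successful generate_questions result, as constrained
# by the response_schema above (the question and answer text is made up):
#
#   [
#       {"question": "What is the report about?", "answer": "Quarterly sales."},
#       {"question": "Who wrote the report?", "answer": "The finance team."},
#   ]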
def write_tuning_dataset(db: firestore.Client, output_bucket: str) -> int:
"""Write the tuning dataset to Cloud Storage.
For more information on the tuning dataset file format:
https://cloud.google.com/vertex-ai/generative-ai/docs/models/gemini-supervised-tuning-about
Args:
db: Firestore client.
output_bucket: Name of the output bucket.
Returns: The number of entries in the tuning dataset.
"""
storage_client = storage.Client()
documents = [doc.to_dict() or {} for doc in db.collection("documents").stream()]
doc_pages = {doc["filename"]: doc["pages"] for doc in documents}
dataset_size = 0
with storage_client.get_bucket(output_bucket).blob("dataset.jsonl").open("w") as f:
for doc in db.collection("dataset").stream():
entry = doc.to_dict() or {}
context = doc_pages[entry["filename"]][entry["page_number"]]
row = {
"systemInstruction": {
"parts": [{"text": "Answer the question based on the following text"}],
},
"contents": [
{
"role": "user",
"parts": [
{"text": f"Text: {context}"},
{"text": entry["question"]},
],
},
{
"role": "model",
"parts": [{"text": entry["answer"]}],
},
],
}
f.write(f"{json.dumps(row)}\n")
dataset_size += 1
return dataset_size
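# Illustrative single line of the resulting dataset.jsonl file (re-indented for
# readability; the page text, question, and answer are made up):
#
#   {"systemInstruction": {"parts": [{"text": "Answer the question based on the following text"}]},
#    "contents": [
#        {"role": "user",
#         "parts": [{"text": "Text: Quarterly sales grew 10%."},
#                   {"text": "How much did sales grow?"}]},
#        {"role": "model", "parts": [{"text": "Sales grew by 10 percent."}]}]}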