demo-python/code/indexers/document-intelligence-custom-skill/api/functions/function_app.py (253 lines of code) (raw):
import azure.functions as func
import json
import logging
import os
from azure.ai.documentintelligence.models import AnalyzeDocumentRequest
from azure.ai.documentintelligence.aio import DocumentIntelligenceClient
from azure.identity.aio import DefaultAzureCredential
import base64
import io
from azure.core.exceptions import HttpResponseError
from typing import Any
from langchain_text_splitters.markdown import MarkdownHeaderTextSplitter
from langchain_text_splitters.character import RecursiveCharacterTextSplitter
app = func.FunctionApp()
@app.function_name(name="markdownsplit")
@app.route(route="markdownsplit")
async def SplitMarkdownDocument(req: func.HttpRequest) -> func.HttpResponse:
input = {}
result_content = []
try:
req_body = req.get_json()
# Read input values
# Either file_data or metadata_storage_path and metadata_storage_sas_token
if "values" in req_body:
for value in req_body["values"]:
record_id = value["recordId"]
chunk_size = 512
if "chunkSize" in value["data"]:
try:
chunk_size = int(value["data"]["chunkSize"])
except Exception as e:
result_content.append(
{
"recordId": record_id,
"data": {},
"errors": [
{
"message": "'chunkSize' must be an int"
}
]
}
)
continue
chunk_overlap = 128
if "chunkOverlap" in value["data"]:
try:
chunk_overlap = int(value["data"]["chunkOverlap"])
except Exception as e:
result_content.append(
{
"recordId": record_id,
"data": {},
"errors": [
{
"message": "'chunkOverlap' must be an int"
}
]
}
)
continue
encoder_model_name = "text-embedding-3-large"
if "encoderModelName" in value["data"]:
encoder_model_name = value["data"]["encoderModelName"]
if encoder_model_name not in ["text-embedding-ada-002", "text-embedding-3-large", "text-embedding-3-small"]:
result_content.append(
{
"recordId": record_id,
"data": {},
"errors": [
{
"message": f"Unknown encoder model {encoder_model_name}"
}
]
}
)
continue
if "content" in value["data"]:
input[record_id] = { "content": value["data"]["content"], "chunkSize": chunk_size, "chunkOverlap": chunk_overlap, "encoderModelName": encoder_model_name }
else:
result_content.append(
{
"recordId": record_id,
"data": {},
"errors": [
{
"message": "Expected 'content'"
}
]
}
)
except Exception as e:
logging.exception("Could not get input data request body")
return func.HttpResponse(
"Invalid input",
status_code=400
)
# Split the document into chunks based on markdown headers.
# Max chunking goes 8 headers deep.
headers_to_split_on = [
("#", "1"),
("##", "2"),
("###", "3"),
("####", "4"),
("#####", "5"),
("######", "6"),
("#######", "7"),
("########", "8")
]
text_splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on, return_each_line=False, strip_headers=False)
for record_id, data in input.items():
encoder_model_name = data["encoderModelName"]
character_splitter: RecursiveCharacterTextSplitter
try:
if encoder_model_name:
character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(model_name=encoder_model_name, chunk_size = data["chunkSize"], chunk_overlap = data["chunkOverlap"])
else:
character_splitter = RecursiveCharacterTextSplitter(chunk_size = data["chunkSize"], chunk_overlap = data["chunkOverlap"])
except Exception as e:
logging.exception("Failed to load text splitter")
result_content.append({
"recordId": record_id,
"data": {},
"errors": [
{
"message": f"Failed to split text: {e}"
}
]
})
continue
# Split markdown content into chunks based on headers
md_chunks = text_splitter.split_text(data["content"])
# Further split the markdown chunks into the desired
char_chunks = character_splitter.split_documents(md_chunks)
# Return chunk content and headers
chunks = [{ "content": document.page_content, "headers": [document.metadata[header] for header in sorted(document.metadata.keys())] } for document in char_chunks]
result_content.append({ "recordId": record_id, "data": { "chunks": chunks } })
response = { "values": result_content }
return func.HttpResponse(body=json.dumps(response), mimetype="application/json", status_code=200)
@app.function_name(name="read")
@app.route(route="read")
async def ReadDocument(req: func.HttpRequest) -> func.HttpResponse:
input = {}
result_content = []
try:
req_body = req.get_json()
# Read input values
# Either file_data or metadata_storage_path and metadata_storage_sas_token
if "values" in req_body:
for value in req_body["values"]:
record_id = value["recordId"]
# Check if using markdown or text mode
if "mode" in value["data"]:
mode = value["data"]["mode"]
if mode not in ["markdown", "text"]:
result_content.append(
{
"recordId": record_id,
"data": {},
"errors": [
{
"message": "'mode' must be either 'text' or 'markdown'"
}
]
}
)
continue
else:
mode = "text"
if "file_data" in value["data"]:
input[record_id] = { "type": "file", "file_data": value["data"]["file_data"], "mode": mode }
elif "metadata_storage_path" in value["data"] and "metadata_storage_sas_token" in value["data"]:
input[record_id] = { "type": "sas", "sas_uri": f"{value['data']['metadata_storage_path']}{value['data']['metadata_storage_sas_token']}", "mode": mode }
else:
result_content.append(
{
"recordId": record_id,
"data": {},
"errors": [
{
"message": "Expected either 'file_data' or 'metadata_storage_path' and 'metadata_storage_sas_token'"
}
]
}
)
except Exception as e:
logging.exception("Could not get input data request body")
return func.HttpResponse(
"Invalid input",
status_code=400
)
async with DocumentIntelligenceClient(endpoint=os.environ["AZURE_DOCUMENTINTELLIGENCE_ENDPOINT"], credential=DefaultAzureCredential()) as client:
for record_id, data in input.items():
if "file_data" in data:
result_content.append(await process_file(client, record_id, data["file_data"], data["mode"])) # type: ignore
else:
result_content.append(await process_sas_uri(client, record_id, data["sas_uri"], data["mode"])) # type: ignore
response = { "values": result_content }
return func.HttpResponse(body=json.dumps(response), mimetype="application/json", status_code=200)
async def process_file(client: DocumentIntelligenceClient, record_id: str, file: Any, mode: str):
if not isinstance(file, dict) or \
file["$type"] != "file" or \
"data" not in file:
return {
"recordId": record_id,
"data": {},
"errors": [
{
"message": "file_data is not in correct format"
}
]
}
try:
file_bytes = base64.b64decode(file["data"])
file_data = io.BytesIO(file_bytes)
except:
return {
"recordId": record_id,
"data": {},
"errors": [
{
"message": "Failed to decode file content"
}
]
}
try:
poller = await client.begin_analyze_document(
"prebuilt-layout", analyze_request=file_data, content_type="application/octet-stream", output_content_format=mode, features=["ocrHighResolution"]
)
result = await poller.result()
return {
"recordId": record_id, "data": { "content": result.content }
}
except HttpResponseError as e:
logging.exception("Failed to read document")
return {
"recordId": record_id,
"data": {},
"errors": [
{
"message": f"Failed to process file content: {e.message}"
}
]
}
async def process_sas_uri(client: DocumentIntelligenceClient, record_id: str, sas_uri: str, mode: str):
try:
poller = await client.begin_analyze_document(
"prebuilt-layout", analyze_request=AnalyzeDocumentRequest(url_source=sas_uri), output_content_format=mode, features=["ocrHighResolution"]
)
result = await poller.result()
return {
"recordId": record_id, "data": { "content": result.content }
}
except HttpResponseError as e:
logging.exception("Failed to read document")
return {
"recordId": record_id,
"data": {},
"errors": [
{
"message": f"Failed to process file content: {e.message}"
}
]
}