# scripts/process_zip/process_zip.py
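"""Process a zip file dump into a DataStore.

Extracts the archive into a local "dump" directory, turns each file into a
Document (optionally screening for PII and/or extracting metadata with a
language model), upserts the documents in batches, then cleans up.
"""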
import argparse
import asyncio
import json
import os
import shutil
import uuid
import zipfile

from loguru import logger

from datastore.datastore import DataStore
from datastore.factory import get_datastore
from models.models import Document, DocumentMetadata, Source
from services.extract_metadata import extract_metadata_from_document
from services.file import extract_text_from_filepath
from services.pii_detection import screen_text_for_pii
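
# Number of documents to send to the datastore in each upsert call.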
DOCUMENT_UPSERT_BATCH_SIZE = 50
async def process_file_dump(
filepath: str,
datastore: DataStore,
custom_metadata: dict,
screen_for_pii: bool,
extract_metadata: bool,
):
# create a ZipFile object and extract all the files into a directory named 'dump'
with zipfile.ZipFile(filepath) as zip_file:
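        # files are extracted into a fixed "dump" directory relative to the
        # current working directory, so concurrent runs would collide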
zip_file.extractall("dump")
documents = []
skipped_files = []
# use os.walk to traverse the dump directory and its subdirectories
for root, dirs, files in os.walk("dump"):
for filename in files:
            # log progress every 20 documents
            if documents and len(documents) % 20 == 0:
                logger.info(f"Processed {len(documents)} documents")
filepath = os.path.join(root, filename)
try:
extracted_text = extract_text_from_filepath(filepath)
logger.info(f"extracted_text from {filepath}")
# create a metadata object with the source and source_id fields
metadata = DocumentMetadata(
source=Source.file,
source_id=filename,
)
# update metadata with custom values
for key, value in custom_metadata.items():
if hasattr(metadata, key):
setattr(metadata, key, value)
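                # keys that are not DocumentMetadata fields are silently ignored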
# screen for pii if requested
if screen_for_pii:
pii_detected = screen_text_for_pii(extracted_text)
                    # if PII is detected, log a warning and skip the document
                    if pii_detected:
                        logger.warning(f"PII detected in {filepath}, skipping")
skipped_files.append(
filepath
) # add the skipped file to the list
continue
# extract metadata if requested
if extract_metadata:
# extract metadata from the document text
extracted_metadata = extract_metadata_from_document(
f"Text: {extracted_text}; Metadata: {str(metadata)}"
)
# get a Metadata object from the extracted metadata
metadata = DocumentMetadata(**extracted_metadata)
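                    # note: this replaces the metadata built above, so custom
                    # overrides from --custom_metadata are discarded here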
# create a document object with a random id, text and metadata
document = Document(
id=str(uuid.uuid4()),
text=extracted_text,
metadata=metadata,
)
documents.append(document)
except Exception as e:
# log the error and continue with the next file
logger.error(f"Error processing {filepath}: {e}")
skipped_files.append(filepath) # add the skipped file to the list
    # upsert the documents in batches; the datastore's upsert method already
    # batches internally, but doing it here allows more descriptive logging
for i in range(0, len(documents), DOCUMENT_UPSERT_BATCH_SIZE):
        # get the documents in the current batch
        batch_documents = documents[i : i + DOCUMENT_UPSERT_BATCH_SIZE]
        logger.info(
            f"Upserting batch of {len(batch_documents)} documents, "
            f"batch {i // DOCUMENT_UPSERT_BATCH_SIZE + 1}"
        )
        logger.debug(f"documents: {batch_documents}")
await datastore.upsert(batch_documents)
    # delete the dump directory and everything in it
    shutil.rmtree("dump")
    # log the skipped files
logger.info(f"Skipped {len(skipped_files)} files due to errors or PII detection")
for file in skipped_files:
logger.info(file)
async def main():
# parse the command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--filepath", required=True, help="The path to the file dump")
parser.add_argument(
"--custom_metadata",
default="{}",
help="A JSON string of key-value pairs to update the metadata of the documents",
)
    parser.add_argument(
        "--screen_for_pii",
        action="store_true",
        help="If set, screen each document for PII using a language model and skip documents where PII is detected",
    )
    parser.add_argument(
        "--extract_metadata",
        action="store_true",
        help="If set, try to extract metadata from each document using a language model",
    )
args = parser.parse_args()
# get the arguments
filepath = args.filepath
custom_metadata = json.loads(args.custom_metadata)
screen_for_pii = args.screen_for_pii
extract_metadata = args.extract_metadata
    # initialize the datastore instance once
datastore = await get_datastore()
# process the file dump
await process_file_dump(
filepath, datastore, custom_metadata, screen_for_pii, extract_metadata
)
if __name__ == "__main__":
asyncio.run(main())
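
# Example invocation (hypothetical archive path and metadata values):
#   python scripts/process_zip/process_zip.py --filepath ./dump.zip \
#       --custom_metadata '{"author": "example"}' --screen_for_pii --extract_metadata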