chunking/chunker_factory.py (54 lines of code) (raw):
import logging
import os
from .chunkers.doc_analysis_chunker import DocAnalysisChunker
from .chunkers.multimodal_chunker import MultimodalChunker
from .chunkers.langchain_chunker import LangChainChunker
from .chunkers.spreadsheet_chunker import SpreadsheetChunker
from .chunkers.transcription_chunker import TranscriptionChunker
from .chunkers.json_chunker import JSONChunker
from .chunkers.nl2sql_chunker import NL2SQLChunker
from tools import DocumentIntelligenceClient
from utils import get_filename_from_data, get_file_extension
class ChunkerFactory:
    """Factory class to create appropriate chunker based on file extension."""

    # Extension groups shared by ``get_chunker`` (routing) and
    # ``get_supported_extensions`` (reporting) so the two cannot drift apart.
    _TRANSCRIPTION_EXTENSIONS = ('vtt',)
    _JSON_EXTENSIONS = ('json',)
    _SPREADSHEET_EXTENSIONS = ('xlsx', 'xls')
    _DOCUMENT_EXTENSIONS = ('pdf', 'png', 'jpeg', 'jpg', 'bmp', 'tiff')
    _OFFICE_EXTENSIONS = ('docx', 'pptx')
    _NL2SQL_EXTENSION = 'nl2sql'

    def __init__(self):
        """
        Read the configuration that influences chunker selection.

        A ``DocumentIntelligenceClient`` is created only to read its
        ``docint_40_api`` flag. The ``MULTIMODAL`` environment variable
        (default ``"false"``) enables multimodal chunking when it is
        "true", "1" or "yes" (case-insensitive, surrounding whitespace ignored).
        """
        docint_client = DocumentIntelligenceClient()
        # Gate for docx/pptx support in get_chunker (presumably truthy when
        # Document Intelligence API 4.0 is in use — confirm in tools).
        self.docint_40_api = docint_client.docint_40_api
        _multimodality = os.getenv("MULTIMODAL", "false").strip().lower()
        self.multimodality = _multimodality in ("true", "1", "yes")

    def _document_chunker(self, data):
        """Return a MultimodalChunker when multimodality is enabled, else a DocAnalysisChunker."""
        if self.multimodality:
            return MultimodalChunker(data)
        return DocAnalysisChunker(data)

    def get_chunker(self, data):
        """
        Get the appropriate chunker for the document described by ``data``.

        The file name is obtained with ``get_filename_from_data`` and its
        extension with ``get_file_extension``. Routing:

        - ``vtt`` -> TranscriptionChunker
        - ``json`` -> JSONChunker
        - ``xlsx`` / ``xls`` -> SpreadsheetChunker
        - ``pdf`` / ``png`` / ``jpeg`` / ``jpg`` / ``bmp`` / ``tiff`` ->
          MultimodalChunker if multimodality is enabled, else DocAnalysisChunker
        - ``docx`` / ``pptx`` -> same as above, but only when Doc Intelligence
          4.0 is available
        - ``nl2sql`` -> NL2SQLChunker
        - anything else -> LangChainChunker

        Args:
            data (dict): The data containing document information.

        Returns:
            BaseChunker: An instance of a chunker class.

        Raises:
            RuntimeError: If a ``docx`` or ``pptx`` file is given and
                Doc Intelligence 4.0 is not available.
        """
        filename = get_filename_from_data(data)
        logging.info(f"[chunker_factory][{filename}] Creating chunker")
        extension = get_file_extension(filename)

        if extension in self._TRANSCRIPTION_EXTENSIONS:
            return TranscriptionChunker(data)
        if extension in self._JSON_EXTENSIONS:
            return JSONChunker(data)
        if extension in self._SPREADSHEET_EXTENSIONS:
            return SpreadsheetChunker(data)
        if extension in self._DOCUMENT_EXTENSIONS:
            return self._document_chunker(data)
        if extension in self._OFFICE_EXTENSIONS:
            if not self.docint_40_api:
                logging.error(f"[chunker_factory][{filename}] Processing 'pptx' and 'docx' files requires Doc Intelligence 4.0.")
                raise RuntimeError("Processing 'pptx' and 'docx' files requires Doc Intelligence 4.0.")
            return self._document_chunker(data)
        # Exact match on purpose: ``extension in ('nl2sql')`` would be a
        # substring test on a plain string and wrongly accept '', 'sql', 'nl', ...
        if extension == self._NL2SQL_EXTENSION:
            return NL2SQLChunker(data)
        return LangChainChunker(data)

    @staticmethod
    def get_supported_extensions():
        """
        Get a comma-separated list of supported file extensions.

        Returns:
            str: A comma-separated list of supported file extensions.
        """
        # NOTE(review): 'nl2sql' is routed by get_chunker but is not reported
        # here; confirm whether callers should see it.
        extensions = (
            ChunkerFactory._TRANSCRIPTION_EXTENSIONS
            + ChunkerFactory._SPREADSHEET_EXTENSIONS
            + ChunkerFactory._DOCUMENT_EXTENSIONS
            + ChunkerFactory._OFFICE_EXTENSIONS
            + ChunkerFactory._JSON_EXTENSIONS
        )
        return ', '.join(extensions)