# chunking/chunkers/json_chunker.py
import json
import logging
import os

from .base_chunker import BaseChunker


class JSONChunker(BaseChunker):
"""
JSONChunker is a custom chunker for well-formatted JSON files.
It parses the JSON and recursively partitions the structure into smaller valid JSON chunks
that do not exceed the maximum token limit.
"""

    def __init__(self, data, max_chunk_size=None, token_overlap=None, minimum_chunk_size=None):
        super().__init__(data)
        self.max_chunk_size = int(max_chunk_size or os.getenv("NUM_TOKENS", "2048"))
        # token_overlap is accepted for interface parity with other chunkers;
        # chunks here follow the JSON structure, so no overlap is currently applied.
        self.token_overlap = int(token_overlap or os.getenv("TOKEN_OVERLAP", "100"))
        self.minimum_chunk_size = int(minimum_chunk_size or os.getenv("MIN_CHUNK_SIZE", "100"))
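        # Example (illustrative): with NUM_TOKENS=1024 exported and no explicit
        # arguments, these resolve to max_chunk_size=1024, token_overlap=100,
        # and minimum_chunk_size=100.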

    def get_chunks(self):
        """
        Split the JSON content into chunks while ensuring each chunk is valid JSON.

        The method:
        1. Decodes the document bytes to text.
        2. Parses the JSON.
        3. Recursively partitions the parsed JSON into valid pieces whose
           pretty-printed form stays within the token limit.
        4. Builds chunk dictionaries from the resulting pieces.
        """
        if not self.document_bytes:
            logging.error(f"[json_chunker][{self.filename}] No document bytes provided.")
            return []

        text = self.decode_to_utf8(self.document_bytes)
        try:
            parsed_json = json.loads(text)
        except json.JSONDecodeError as e:
            logging.error(f"[json_chunker][{self.filename}] Error parsing JSON: {e}")
            return []

        # Recursively partition the parsed JSON.
        partitioned = self._recursive_chunk_json(parsed_json)

        # Pretty-print each partition and drop pieces below the minimum token count.
        chunk_texts = []
        for part in partitioned:
            dumped = json.dumps(part, indent=2, ensure_ascii=False)
            token_count = self.token_estimator.estimate_tokens(dumped)
            if token_count >= self.minimum_chunk_size:
                chunk_texts.append(dumped)
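        # For reference, json.dumps([{"id": 1}], indent=2) yields:
        #   [
        #     {
        #       "id": 1
        #     }
        #   ]
        # so even tiny partitions carry some pretty-printing overhead in the
        # token estimate.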

        chunk_dicts = []
        chunk_id = 0
        for chunk_text in chunk_texts:
            token_count = self.token_estimator.estimate_tokens(chunk_text)
            if token_count > self.max_chunk_size:
                logging.warning(
                    f"[json_chunker][{self.filename}] A chunk still exceeds the max token "
                    f"limit ({token_count} > {self.max_chunk_size}); this can happen when "
                    "a single element is very large."
                )
                # Such oversized chunks are kept as-is; they could instead be passed
                # to a plain-text splitter if a hard limit is required.
            chunk_id += 1
            chunk_dict = self._create_chunk(chunk_id, chunk_text)
            chunk_dicts.append(chunk_dict)

        logging.info(f"[json_chunker][{self.filename}] Created {len(chunk_dicts)} chunk(s).")
        return chunk_dicts

    def _recursive_chunk_json(self, obj):
        """
        Recursively partition a JSON value (list or dict) so that each partition's
        pretty-printed string does not exceed self.max_chunk_size tokens.
        Returns a list of JSON-compatible Python objects.
        """
        def token_count_of(data):
            dumped = json.dumps(data, indent=2, ensure_ascii=False)
            return self.token_estimator.estimate_tokens(dumped)

        # If obj is a list, partition its items.
        if isinstance(obj, list):
            partitions = []
            current = []
            for item in obj:
                candidate = current + [item]
                if token_count_of(candidate) <= self.max_chunk_size:
                    current.append(item)
                else:
                    if current:
                        # Recurse into the current partition in case its single element is too large.
                        if token_count_of(current) > self.max_chunk_size and len(current) == 1:
                            partitions.extend(self._recursive_chunk_json(current[0]))
                        else:
                            partitions.append(current)
                    # If the item itself is too big, try to partition it further.
                    if token_count_of([item]) > self.max_chunk_size and isinstance(item, (list, dict)):
                        partitions.extend(self._recursive_chunk_json(item))
                        # Reset so items already emitted above are not duplicated later.
                        current = []
                    else:
                        current = [item]
            if current:
                partitions.append(current)
            return partitions
        # If obj is a dict, partition its key-value pairs.
        elif isinstance(obj, dict):
            partitions = []
            current = {}
            for key, value in obj.items():
                candidate = current.copy()
                candidate[key] = value
                if token_count_of(candidate) <= self.max_chunk_size:
                    current[key] = value
                else:
                    if current:
                        # If a single key-value pair is too large, try to partition its value.
                        if token_count_of(current) > self.max_chunk_size and len(current) == 1:
                            k, v = next(iter(current.items()))
                            if isinstance(v, (list, dict)):
                                for sub in self._recursive_chunk_json(v):
                                    partitions.append({k: sub})
                            else:
                                partitions.append(current)
                        else:
                            partitions.append(current)
                    # Try partitioning the new key-value pair if it is itself too large.
                    single = {key: value}
                    if token_count_of(single) > self.max_chunk_size and isinstance(value, (list, dict)):
                        for sub in self._recursive_chunk_json(value):
                            partitions.append({key: sub})
                        # Reset so pairs already emitted above are not duplicated later.
                        current = {}
                    else:
                        current = {key: value}
            if current:
                partitions.append(current)
            return partitions
        # For primitives, return the value as a single partition.
        else:
            return [obj]
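
# Usage sketch (illustrative, not a documented contract): the exact shape of the
# `data` payload is defined by BaseChunker, so the "fileName" / "documentBytes"
# keys below are an assumption made for this example only.
#
#     data = {
#         "fileName": "catalog.json",
#         "documentBytes": open("catalog.json", "rb").read(),
#     }
#     chunker = JSONChunker(data, max_chunk_size=512, minimum_chunk_size=20)
#     for chunk in chunker.get_chunks():
#         print(chunk)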