in datasets.py [0:0]
def get_document_predictions(chunk_data: List[List[tuple]]) -> List[List[Tuple[int, int]]]:
    """
    Combine the per-chunk links into document-level groups of spans.

    Every item of every chunk is passed to ``Graph.add_edges_from`` as an
    edge, so the graph's nodes are the endpoints of those edges. Each node is
    unpacked as a ``(start, end)`` pair. One output group is produced per
    connected component of the graph.

    Within a component, spans are visited by ascending ``start`` and, for
    equal starts, by descending ``end`` (the widest span first). A span is
    kept only if its ``start`` is not smaller than the ``end`` of the span
    kept just before it; spans that merely touch are therefore both kept.

    Args:
        chunk_data: One list per chunk; each element is an edge between two
            ``(start, end)`` spans.

    Returns:
        A list with one entry per connected component, each entry being the
        list of retained ``(start, end)`` tuples in sorted order.
    """
    # A set drops edges that are reported identically by more than one chunk.
    unique_edges = {edge for chunk in chunk_data for edge in chunk}
    span_graph = nx.Graph()
    span_graph.add_edges_from(unique_edges)

    document_groups = []
    for cluster in nx.connected_components(span_graph):
        ordered_spans = sorted(cluster, key=lambda span: (span[0], -span[1]))
        kept = []
        for start, end in ordered_spans:
            # Skip a span that begins before the last kept span has ended.
            if kept and start < kept[-1][1]:
                continue
            kept.append((start, end))
        document_groups.append(kept)
    return document_groups