in yourbench/utils/dataset_engine.py [0:0]
def _iter_spread_pairs(n: int):
    """Yield every unordered index pair ``(i, j)`` over ``range(n)`` exactly once.

    Pairs are produced in rounds of increasing circular distance ``d``, pairing
    ``i`` with ``(i + d) % n``. Unlike ``itertools.combinations``, whose first
    ``n - 1`` pairs all contain index 0, the first round already touches every
    index. Truncating the sequence early therefore still spreads pairs across
    all inputs. Memory use is O(1) regardless of ``n``.
    """
    for distance in range(1, n // 2 + 1):
        # For even n, distance n/2 links each index to its antipode: (i, i + n/2)
        # and (i + n/2, i) are the same pair, so only the first half is emitted.
        upper = n // 2 if 2 * distance == n else n
        for i in range(upper):
            yield i, (i + distance) % n


def _sample_valid_multihop_chunks(doc: dict, rng: random.Random, k: int) -> list:
    """Draw up to ``k`` multi-hop chunks from ``doc`` and drop malformed ones.

    Sampling happens before validation, so which chunks are drawn does not
    depend on how many are malformed. A chunk is kept only if it is a dict that
    contains both ``chunk_ids`` and ``chunks_text``. Each malformed chunk is
    logged once.
    """
    population = doc["multihop_chunks"]
    sampled = rng.sample(population, min(len(population), k))
    valid = []
    for chunk in sampled:
        if isinstance(chunk, dict) and "chunk_ids" in chunk and "chunks_text" in chunk:
            valid.append(chunk)
        else:
            logger.warning(f"Skipping malformed chunk in doc {doc['document_id']}: {chunk}")
    return valid


def _create_cross_document_dataset(dataset: Dataset, stage_cfg: dict[str, object]) -> Dataset:
    """Create a cross-document Dataset by pairing multi-hop chunks from different documents.

    Documents are the rows with a non-empty ``multihop_chunks`` list. They are
    shuffled with a fixed seed (42) and paired. For each pair, up to
    ``chunks_per_document`` multi-hop chunks are sampled from each side, and
    every (left, right) combination is concatenated into one new row.
    Generation stops once ``max_combinations`` rows exist.

    Args:
        dataset: A HuggingFace Dataset where each row may contain a
            ``multihop_chunks`` list of dicts with ``chunk_ids`` and
            ``chunks_text`` entries.
        stage_cfg: Stage-specific config. Reads ``max_combinations``
            (default 100) and ``chunks_per_document`` (default 1).

    Returns:
        A new Dataset whose rows have three fields:

        - ``document_id``: ``"cross_<id1>_<id2>"``.
        - ``chunks``: always an empty list.
        - ``multihop_chunks``: a single combined chunk.

        Other source columns are not carried over. An empty Dataset is
        returned when the ``multihop_chunks`` column is missing, when fewer
        than two documents qualify, or when either limit is not positive.
    """
    max_combinations = int(stage_cfg.get("max_combinations", 100))
    chunks_per_document = int(stage_cfg.get("chunks_per_document", 1))
    if max_combinations <= 0 or chunks_per_document <= 0:
        # The cap is checked after appending, so a non-positive cap would
        # still let one row through. rng.sample also rejects a negative size.
        logger.warning(
            f"max_combinations ({max_combinations}) and chunks_per_document "
            f"({chunks_per_document}) must both be positive. Cross-document generation aborted."
        )
        return Dataset.from_list([])

    if "multihop_chunks" not in dataset.column_names:
        logger.warning("Dataset is missing 'multihop_chunks'. Cross-document generation aborted.")
        return Dataset.from_list([])

    docs = []
    for idx, row in enumerate(dataset):
        multihop_chunks = row.get("multihop_chunks", [])
        if isinstance(multihop_chunks, list) and multihop_chunks:
            docs.append({
                # Use `or` rather than a .get() default, so that a present but
                # null/empty id also falls back to a positional id.
                "document_id": row.get("document_id") or f"doc_{idx}",
                "multihop_chunks": multihop_chunks,
            })

    if len(docs) < 2:
        logger.warning(f"Found only {len(docs)} document(s) with 'multihop_chunks'. Need at least 2.")
        return Dataset.from_list([])

    # A fixed seed makes the generated combinations reproducible across runs.
    rng = random.Random(42)
    rng.shuffle(docs)

    cross_rows = []
    # Spread pairs over all documents, so that hitting max_combinations early
    # does not leave every row anchored on the same document.
    for i, j in _iter_spread_pairs(len(docs)):
        doc1, doc2 = docs[i], docs[j]
        left = _sample_valid_multihop_chunks(doc1, rng, chunks_per_document)
        right = _sample_valid_multihop_chunks(doc2, rng, chunks_per_document)
        for chunk1 in left:
            for chunk2 in right:
                # Assumes chunk_ids / chunks_text are lists, so `+` concatenates
                # them in order. TODO confirm against the multihop chunk producer.
                combined = {
                    "chunk_ids": chunk1["chunk_ids"] + chunk2["chunk_ids"],
                    "chunks_text": chunk1["chunks_text"] + chunk2["chunks_text"],
                }
                cross_rows.append({
                    "document_id": f"cross_{doc1['document_id']}_{doc2['document_id']}",
                    "chunks": [],
                    "multihop_chunks": [combined],
                })
                if len(cross_rows) >= max_combinations:
                    logger.debug(f"Reached max_combinations: {max_combinations}")
                    return Dataset.from_list(cross_rows)
    return Dataset.from_list(cross_rows)