def _create_cross_document_dataset()

in yourbench/utils/dataset_engine.py


def _create_cross_document_dataset(dataset: Dataset, stage_cfg: dict[str, object]) -> Dataset:
    """Creates a cross-document Dataset by combining multi-hop chunks from different documents.

    Args:
        dataset: A HuggingFace Dataset where each row may contain a 'multihop_chunks' list.
        stage_cfg: Stage-specific config containing 'max_combinations' and 'chunks_per_document'.

    Returns:
        A new Dataset of cross-document rows with 'document_id', 'chunks', and
        'multihop_chunks' columns; empty if the column is missing or fewer than
        two usable documents are found.
    """
    max_combinations = int(stage_cfg.get("max_combinations", 100))
    chunks_per_document = int(stage_cfg.get("chunks_per_document", 1))

    if "multihop_chunks" not in dataset.column_names:
        logger.warning("Dataset is missing 'multihop_chunks'. Cross-document generation aborted.")
        return Dataset.from_list([])

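    # Keep only documents that actually carry multi-hop chunks; anything else
    # cannot contribute to a cross-document pair.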
    docs = []
    for idx, row in enumerate(dataset):
        multihop_chunks = row.get("multihop_chunks", [])
        if isinstance(multihop_chunks, list) and multihop_chunks:
            docs.append({
                "document_id": row.get("document_id", f"doc_{idx}"),
                "multihop_chunks": multihop_chunks,
            })

    if len(docs) < 2:
        logger.warning(f"Found only {len(docs)} document(s) with 'multihop_chunks'. Need at least 2.")
        return Dataset.from_list([])

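    # A fixed seed keeps the shuffled document order and chunk sampling
    # reproducible across runs.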
    rng = random.Random(42)
    rng.shuffle(docs)

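    # Pair every two documents, sample up to `chunks_per_document` multi-hop
    # chunks from each, and merge the samples pairwise until the cap is hit.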
    cross_rows = []
    for doc1, doc2 in combinations(docs, 2):
        samp1 = rng.sample(doc1["multihop_chunks"], min(len(doc1["multihop_chunks"]), chunks_per_document))
        samp2 = rng.sample(doc2["multihop_chunks"], min(len(doc2["multihop_chunks"]), chunks_per_document))

        for chunk1 in samp1:
            for chunk2 in samp2:
                if not all(k in chunk1 for k in ("chunk_ids", "chunks_text")):
                    logger.warning(f"Skipping malformed chunk in doc {doc1['document_id']}: {chunk1}")
                    continue
                if not all(k in chunk2 for k in ("chunk_ids", "chunks_text")):
                    logger.warning(f"Skipping malformed chunk in doc {doc2['document_id']}: {chunk2}")
                    continue

                combined = {
                    "chunk_ids": chunk1["chunk_ids"] + chunk2["chunk_ids"],
                    "chunks_text": chunk1["chunks_text"] + chunk2["chunks_text"],
                }

                cross_rows.append({
                    "document_id": f"cross_{doc1['document_id']}_{doc2['document_id']}",
                    "chunks": [],
                    "multihop_chunks": [combined],
                })

                if len(cross_rows) >= max_combinations:
                    logger.debug(f"Reached max_combinations: {max_combinations}")
                    return Dataset.from_list(cross_rows)

    return Dataset.from_list(cross_rows)
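
Usage sketch (hypothetical data, not taken from the repository): the function only needs rows with 'document_id' and 'multihop_chunks', so a toy Dataset is enough to see the combined output. This assumes the module-level imports of dataset_engine.py (random, itertools.combinations, the logger) are already in scope.

from datasets import Dataset

toy = Dataset.from_list([
    {
        "document_id": "doc_a",
        "multihop_chunks": [
            {"chunk_ids": ["a_0", "a_1"], "chunks_text": ["alpha one", "alpha two"]},
        ],
    },
    {
        "document_id": "doc_b",
        "multihop_chunks": [
            {"chunk_ids": ["b_0"], "chunks_text": ["beta one"]},
        ],
    },
])

cross = _create_cross_document_dataset(toy, {"max_combinations": 10, "chunks_per_document": 1})
# Yields a single row whose document_id is "cross_<doc>_<doc>" (order depends on
# the seeded shuffle) and whose one multihop_chunk merges the chunk_ids and
# chunks_text sampled from both source documents.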