def _batch()

in vizseq/scorers/__init__.py [0:0]


    def _batch(hypo: List[str], ref: List[List[str]], n_batches: int):
        """Yield aligned ``(hypo, ref)`` partitions of the given samples.

        Each hypothesis is zipped with the entry at the same index of every
        reference list, the resulting rows are handed to the ``_batch``
        callable in scope (called as ``_batch(rows, n_batches=n_batches)``),
        and every returned group of rows is transposed back into the
        original column layout.

        NOTE(review): the inner ``_batch`` call uses a different signature
        than this function, so it presumably resolves to a same-named
        module-level helper -- confirm against the enclosing file.

        Args:
            hypo: Hypotheses, one per sample.
            ref: Reference lists; each must be as long as ``hypo``.
            n_batches: Forwarded unchanged to the batching helper.

        Yields:
            A ``(hypo_part, ref_parts)`` pair per group, where ``hypo_part``
            is a list and ``ref_parts`` is a list with one list per entry
            of ``ref``.
        """
        expected_len = len(hypo)
        # Every reference list has to line up one-to-one with the hypotheses.
        assert not any(len(r) != expected_len for r in ref)

        # Row-major view: one tuple per sample, hypothesis first.
        columns = [hypo] + ref
        rows = list(zip(*columns))

        for group in _batch(rows, n_batches=n_batches):
            # Transpose the group back to column-major form.
            transposed = [list(column) for column in zip(*group)]
            hypo_part, *ref_parts = transposed
            yield hypo_part, ref_parts