in vizseq/scorers/__init__.py [0:0]
def _batch(hypo: List[str], ref: List[List[str]], n_batches: int):
    """Yield aligned (hypotheses, references) batches.

    Args:
        hypo: Hypothesis sentences, one per sample.
        ref: Reference sets; each inner list must be as long as ``hypo``.
        n_batches: Forwarded unchanged as the ``n_batches`` keyword of the
            inner ``_batch`` call.

    Yields:
        A pair ``(batch_hypo, batch_refs)`` per batch, where ``batch_hypo``
        is a list of hypotheses and ``batch_refs`` holds one list per
        reference set, in the same order as ``ref``.

    Raises:
        AssertionError: If any reference set differs in length from ``hypo``.
    """
    expected_len = len(hypo)
    assert all(len(ref_set) == expected_len for ref_set in ref)

    # Transpose column-wise inputs into one row per sample:
    # (hypothesis, ref_0, ref_1, ...), so rows stay aligned when split.
    rows = list(zip(*([hypo] + ref)))

    # NOTE(review): this call passes a single positional argument, which does
    # not match this function's own signature -- presumably it resolves to a
    # generic batching helper defined elsewhere; confirm.
    for chunk in _batch(rows, n_batches=n_batches):
        # Transpose the chunk back into columns; the first column holds the
        # hypotheses and the remaining ones the reference sets.
        hypo_column, *ref_columns = zip(*chunk)
        yield list(hypo_column), list(map(list, ref_columns))