def _get_data()

in vizseq/ipynb/fairseq_viz.py [0:0]


def _get_data(log_path_or_paths: Union[str, List[str]]):
    if isinstance(log_path_or_paths, str):
        log_path_or_paths = [log_path_or_paths]
    ids, src, ref, hypo = None, None, None, {}
    names = Counter()
    for k, log_path in enumerate(log_path_or_paths):
        assert op.isfile(log_path)
        cur_src, cur_ref, cur_hypo = {}, {}, {}
        with open(log_path) as f:
            for l in f:
                line = l.strip()
                if line.startswith('H-'):
                    _id, _, sent = line.split('\t', 2)
                    cur_hypo[_id[2:]] = sent
                elif line.startswith('T-'):
                    _id, sent = line.split('\t', 1)
                    cur_ref[_id[2:]] = sent
                elif line.startswith('S-'):
                    _id, sent = line.split('\t', 1)
                    cur_src[_id[2:]] = sent
        cur_ids = sorted(cur_src.keys())
        assert set(cur_ids) == set(cur_ref.keys()) == set(cur_hypo.keys())
        cur_src = [cur_src[i] for i in cur_ids]
        cur_ref = [cur_ref[i] for i in cur_ids]
        if k == 0:
            ids, src, ref = cur_ids, cur_src, cur_ref
        else:
            assert set(ids) == set(cur_ids) and set(src) == set(cur_src)
            assert set(ref) == set(cur_ref)
        name = op.splitext(op.basename(log_path))[0]
        names.update([name])
        if names[name] > 1:
            name += f'.{names[name]}'
        hypo[name] = [cur_hypo[i] for i in cur_ids]
    return {'0': src}, {'0': ref}, hypo