in vizseq/ipynb/fairseq_viz.py [0:0]
def _get_data(log_path_or_paths: Union[str, List[str]]):
    """Parse generation log(s) into sources, references and hypotheses.

    Each log is scanned for tab-separated lines whose first field starts with
    ``S-`` (source), ``T-`` (reference) or ``H-`` (hypothesis). The rest of
    the first field is the sentence ID. For ``H-`` lines the second field is
    skipped (presumably the model score -- confirm against the log producer)
    and the hypothesis text is everything after it. All other lines are
    ignored. A line whose sentence field is missing yields an empty sentence.

    Args:
        log_path_or_paths: Path to a single log file, or a list of paths.
            Every log must contain the same sentence IDs with the same
            source and reference sentence for each ID.

    Returns:
        A ``(sources, references, hypotheses)`` tuple. ``sources`` and
        ``references`` are ``{'0': sentences}`` dicts whose sentences are
        ordered by sorted (string) sentence ID; the value is ``None`` when no
        path is given. ``hypotheses`` maps a name derived from each log's
        file name (extension removed, ``.N`` appended on the N-th occurrence
        of the same name) to that log's hypotheses in the same order.

    Raises:
        AssertionError: If a path is not an existing file, if the
            source/reference/hypothesis IDs within a log differ, or if the
            logs disagree on IDs, sources or references.
    """
    if isinstance(log_path_or_paths, str):
        log_path_or_paths = [log_path_or_paths]
    ids, src, ref, hypo = None, None, None, {}
    names = Counter()
    for k, log_path in enumerate(log_path_or_paths):
        # Explicit raises (rather than ``assert``) keep the validation active
        # under ``python -O`` while preserving the exception type.
        if not op.isfile(log_path):
            raise AssertionError(f'Log file not found: {log_path}')
        cur_src, cur_ref, cur_hypo = {}, {}, {}
        # Read as UTF-8 explicitly so the result does not depend on the
        # platform's default encoding.
        with open(log_path, encoding='utf-8') as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line.startswith(('H-', 'T-', 'S-')):
                    continue
                # partition() never fails on a missing field, so an empty
                # sentence (whose trailing tab was stripped above) parses as
                # '' instead of raising ValueError on unpacking.
                tag, _, rest = line.partition('\t')
                sent_id = tag[2:]
                if tag.startswith('H-'):
                    # Drop the second field; the hypothesis follows it.
                    _, _, sent = rest.partition('\t')
                    cur_hypo[sent_id] = sent
                elif tag.startswith('T-'):
                    cur_ref[sent_id] = rest
                else:
                    cur_src[sent_id] = rest
        cur_ids = sorted(cur_src)
        if not set(cur_ids) == set(cur_ref) == set(cur_hypo):
            raise AssertionError(
                f'Source, reference and hypothesis IDs differ in {log_path}'
            )
        cur_src = [cur_src[i] for i in cur_ids]
        cur_ref = [cur_ref[i] for i in cur_ids]
        if k == 0:
            ids, src, ref = cur_ids, cur_src, cur_ref
        else:
            # Compare the ID-ordered lists, not sets: set equality would also
            # accept logs that pair the same sentences with different IDs,
            # which would silently misalign the hypotheses.
            if ids != cur_ids or src != cur_src:
                raise AssertionError(
                    f'Sentence IDs or sources in {log_path} do not match '
                    f'{log_path_or_paths[0]}'
                )
            if ref != cur_ref:
                raise AssertionError(
                    f'References in {log_path} do not match '
                    f'{log_path_or_paths[0]}'
                )
        # Disambiguate logs sharing a base name: 'x', 'x.2', 'x.3', ...
        name = op.splitext(op.basename(log_path))[0]
        names.update([name])
        if names[name] > 1:
            name += f'.{names[name]}'
        hypo[name] = [cur_hypo[i] for i in cur_ids]
    return {'0': src}, {'0': ref}, hypo