in align/data.py [0:0]
def __init__(self, path, langs, split='test'):
    """Load parallel test sentences and their gold word alignments.

    After construction ``self.sent_pairs`` is a list of ``(src, trg)``
    sentence tuples and ``self.ground_truth`` holds one gold-alignment entry
    per pair, in the same order.

    For ``ro-en``/``en-fr``/``en-hi`` each gold entry is a space-joined string
    of links: ``"i-j"`` for sure links and ``"ipj"`` for possible links. Only
    the ``ro-en``/``en-fr`` gold files carry an S/P column; every ``en-hi``
    link is written as sure. For ``de-en`` the entries are whatever
    ``self.load_std_file`` returns (presumably the same string format --
    confirm against that method).

    Args:
        path: Root directory holding one sub-directory per language pair.
        langs: One of ``'de-en'``, ``'ro-en'``, ``'en-fr'``, ``'en-hi'``.
        split: Split name. It is used as the sub-directory and file prefix
            for ``ro-en``, ``en-fr`` and ``en-hi``, and ignored for ``de-en``.

    Raises:
        ValueError: If *langs* is not supported, a sentence or gold line is
            malformed, or the source/target/gold data disagree in size or ids.
            ``ValueError`` subclasses ``Exception``, so callers catching the
            previous generic ``Exception`` keep working.
        KeyError: If a gold line references a sentence id that was not loaded.
    """
    import re

    # Sentence lines look like "<s snum=12> some text </s>".
    snum_line = re.compile(r'<s snum=([0-9]*)>(.*)</s>')

    def read_snum_file(filepath, encoding, id2s):
        """Add every ``snum -> stripped sentence`` entry of *filepath* to *id2s*."""
        with open(filepath, encoding=encoding) as f:
            for line in f:
                matching = snum_line.match(line.strip())
                if matching is None:
                    raise ValueError(
                        f'malformed sentence line in {filepath!r}: {line!r}')
                id2s[matching.group(1)] = matching.group(2).strip()

    def read_gold(filepath, snum2idx, has_sure_flag):
        """Return one space-joined link string per sentence index in *snum2idx*."""
        links = [[] for _ in snum2idx]
        with open(filepath) as f:
            for line in f:
                if has_sure_flag:
                    sid, s, t, sure = line.strip().split()
                else:
                    # These gold files have no S/P column; every link is sure.
                    sid, s, t = line.strip().split()
                    sure = 'S'
                if sure == 'S':
                    sep = '-'
                elif sure == 'P':
                    sep = 'p'
                else:
                    raise ValueError(
                        f'unknown alignment flag {sure!r} in {filepath!r}')
                links[snum2idx[sid]].append(sep.join([s, t]))
        return [' '.join(item) for item in links]

    if langs == 'de-en':
        lang_dir = os.path.join(path, langs)
        # The [:-1] slices drop the final line of each file (presumably a
        # trailing terminator/blank line -- confirm against the data).
        with open(os.path.join(lang_dir, 'de'), encoding='iso-8859-1') as f:
            src_sents = [x.strip() for x in f][:-1]
        with open(os.path.join(lang_dir, 'en'), encoding='iso-8859-1') as f:
            trg_sents = [x.strip() for x in f][:-1]
        self.ground_truth = self.load_std_file(
            os.path.join(lang_dir, 'alignmentDeEn.talp'))[:-1]
    elif langs in ('ro-en', 'en-fr', 'en-hi'):
        split_dir = os.path.join(path, langs, split)
        src_id2s = {}
        trg_id2s = {}
        if langs == 'en-hi':
            read_snum_file(os.path.join(split_dir, f'{split}.e'), 'us-ascii', src_id2s)
            read_snum_file(os.path.join(split_dir, f'{split}.h'), 'utf-8', trg_id2s)
        else:
            # FilePairs lists "<src file> <trg file>" per line.
            with open(os.path.join(split_dir, f'FilePairs.{split}')) as pairs:
                for fpair in pairs:
                    sf, tf = fpair.strip().split()
                    read_snum_file(os.path.join(split_dir, sf), 'iso-8859-1', src_id2s)
                    read_snum_file(os.path.join(split_dir, tf), 'iso-8859-1', trg_id2s)
        # Compare id sets, not just sizes: equal-sized but different id sets
        # would otherwise silently pair unrelated sentences.
        if src_id2s.keys() != trg_id2s.keys():
            raise ValueError(
                f'source/target sentence ids differ under {split_dir!r}')
        # NOTE: ids are sorted as strings (lexicographically: '1' < '10' < '2').
        # This is the historical order; kept so pair indices stay stable.
        ids = sorted(src_id2s)
        src_sents = [src_id2s[key] for key in ids]
        trg_sents = [trg_id2s[key] for key in ids]
        snum2idx = {key: i for i, key in enumerate(ids)}
        self.ground_truth = read_gold(
            os.path.join(split_dir, f'{split}.wa.nonullalign'),
            snum2idx,
            has_sure_flag=(langs != 'en-hi'),
        )
    else:
        raise ValueError('language pair not supported.')

    # zip() would silently truncate mismatched sides, so check first.
    if len(src_sents) != len(trg_sents):
        raise ValueError(
            f'{len(src_sents)} source vs {len(trg_sents)} target sentences')
    self.sent_pairs = list(zip(src_sents, trg_sents))
    if len(self.sent_pairs) != len(self.ground_truth):
        raise ValueError(
            f'{len(self.sent_pairs)} sentence pairs vs '
            f'{len(self.ground_truth)} gold alignments')