in align/test.py [0:0]
def test(configs, criss, dataset, simaligns, threshold=0.5):
    """Run the trained word aligner over every sentence pair in *dataset*.

    For each (source, target) sentence pair, builds per-word-pair features
    (optionally augmented with CRISS embedding similarities), scores them
    with the saved ``WordAligner`` model, converts the 3-class softmax into
    an expected-label score, and collects the alignment produced by
    ``inference`` seeded from ``simaligns``.

    Args:
        configs: experiment config object; must provide ``save_path``,
            ``bitext_path``, ``align_path``, ``src_lang``, ``trg_lang``,
            ``reversed``, ``use_criss``, ``hiddens`` and ``device``.
        criss: CRISS embedding model with an ``embed(tokens, langcode=...)``
            method, or None to skip embedding features.
        dataset: object exposing ``sent_pairs`` — iterable of
            (source_sentence, target_sentence) string pairs.
        simaligns: per-sentence prior alignments, indexed in lockstep with
            ``dataset.sent_pairs``.
        threshold: alignment decision threshold passed to ``inference``.

    Returns:
        list: one ``inference`` result per sentence pair.
    """
    setup_configs(configs)
    # Portable directory creation; avoids shelling out (and any shell
    # interpretation of characters in save_path).
    os.makedirs(configs.save_path, exist_ok=True)
    torch.save(configs, os.path.join(configs.save_path, 'configs.pt'))
    info = collect_bitext_stats(
        configs.bitext_path, configs.align_path, configs.save_path,
        configs.src_lang, configs.trg_lang, configs.reversed
    )
    # Feature width: 5 base features, plus cosine + inner-product when
    # CRISS embeddings are enabled.
    aligner = WordAligner(5 + (2 if configs.use_criss else 0), configs.hiddens, 3, 5).to(configs.device)
    model_path = os.path.join(configs.save_path, 'model.pt')
    results = []
    # map_location lets a GPU-saved checkpoint load on a CPU-only host.
    aligner.load_state_dict(torch.load(model_path, map_location=configs.device))
    # Loop-invariant: expected-value weights over the 3 output classes.
    arrange = torch.arange(3).to(configs.device).view(1, 1, -1)
    # Pure inference — no gradients needed; saves memory, outputs unchanged.
    with torch.no_grad():
        for idx, batch in enumerate(tqdm(dataset.sent_pairs)):
            ss, ts = batch
            ss = ss.split()
            ts = ts.split()
            if criss is not None:
                semb = criss.embed(ss, langcode=configs.src_lang)
                temb = criss.embed(ts, langcode=configs.trg_lang)
                # Pairwise cosine and inner-product between every source and
                # target token embedding, shaped (|ss|, |ts|, 1, 2) after cat.
                cos_matrix = cos(semb.unsqueeze(1), temb.unsqueeze(0)).unsqueeze(-1).unsqueeze(-1)
                ip_matrix = (semb.unsqueeze(1) * temb.unsqueeze(0)).sum(-1).unsqueeze(-1).unsqueeze(-1)
                feat_matrix = torch.cat((cos_matrix, ip_matrix), dim=-1)
            word_pairs = []
            criss_features = []
            for i, sw in enumerate(ss):
                for j, tw in enumerate(ts):
                    word_pairs.append((sw, tw))
                    # NOTE(review): if criss is None, feat_matrix is undefined
                    # here and this raises NameError — presumably criss is
                    # always provided when use_criss is set; confirm upstream.
                    criss_features.append(feat_matrix[i, j])
            scores = extract_scores(word_pairs, criss_features, aligner, info, configs).reshape(len(ss), len(ts), -1)
            scores = scores.softmax(-1)
            # Expected class label in [0, 2] as a soft alignment score.
            scores = (scores * arrange).sum(-1)
            result = inference(simaligns[idx], scores, threshold)
            results.append(result)
    return results