def test()

in align/test.py [0:0]


def test(configs, criss, dataset, simaligns, threshold=0.5):
    """Score all word pairs of each sentence pair with a trained aligner and decode alignments.

    Loads the ``WordAligner`` checkpoint stored at ``<configs.save_path>/model.pt``.
    For every sentence pair in ``dataset.sent_pairs`` it:

    1. scores each (source word, target word) pair with ``extract_scores``;
    2. applies a softmax over the per-pair outputs;
    3. collapses them into an expected class index in ``[0, 2]``;
    4. passes that ``(len(src), len(trg))`` score matrix, together with
       ``simaligns[idx]``, to ``inference``.

    Args:
        configs: experiment configuration. Attributes read here are
            ``save_path``, ``bitext_path``, ``align_path``, ``src_lang``,
            ``trg_lang``, ``reversed``, ``use_criss``, ``hiddens`` and
            ``device`` (``setup_configs`` may read/modify others).
        criss: model exposing ``embed(words, langcode=...)`` used to build
            cosine / inner-product features, or ``None`` to run without them.
        dataset: object whose ``sent_pairs`` yields
            ``(source_sentence, target_sentence)`` strings; each sentence is
            tokenised with ``str.split()``.
        simaligns: per-sentence-pair alignments, indexed in the same order as
            ``dataset.sent_pairs``; forwarded to ``inference``.
        threshold: decision threshold forwarded to ``inference``.

    Returns:
        list: one ``inference`` result per sentence pair, in dataset order.

    Raises:
        ValueError: if ``configs.use_criss`` is set but ``criss`` is ``None``
            (the aligner's input width includes the CRISS features, so it
            cannot be fed without them).
    """
    setup_configs(configs)
    if configs.use_criss and criss is None:
        raise ValueError('configs.use_criss is set but no CRISS model was given')

    # Equivalent of `mkdir -p` without spawning a shell on an unquoted path.
    os.makedirs(configs.save_path, exist_ok=True)
    torch.save(configs, os.path.join(configs.save_path, 'configs.pt'))
    info = collect_bitext_stats(
        configs.bitext_path, configs.align_path, configs.save_path,
        configs.src_lang, configs.trg_lang, configs.reversed,
    )

    # NOTE(review): 5 base features plus 2 CRISS features (cosine and inner
    # product, built below) per word pair; the meaning of the trailing 3 and 5
    # arguments is defined by WordAligner -- confirm there.
    num_features = 5 + (2 if configs.use_criss else 0)
    aligner = WordAligner(num_features, configs.hiddens, 3, 5).to(configs.device)
    model_path = os.path.join(configs.save_path, 'model.pt')
    # map_location lets a checkpoint saved on another device (e.g. a GPU) load onto configs.device.
    aligner.load_state_dict(torch.load(model_path, map_location=configs.device))
    # Inference mode: disables dropout / uses running batch-norm stats, if the aligner has any.
    aligner.eval()

    # Class indices 0, 1, 2 used to turn each pair's softmax into an expected
    # class value; loop-invariant, so built once. Assumes the aligner emits 3
    # outputs per word pair -- TODO confirm against WordAligner.
    class_values = torch.arange(3, device=configs.device).view(1, 1, -1)

    results = []
    # No gradients are needed at test time; avoids building autograd graphs per sentence.
    with torch.no_grad():
        for idx, (src_sent, trg_sent) in enumerate(tqdm(dataset.sent_pairs)):
            ss = src_sent.split()
            ts = trg_sent.split()

            feat_matrix = None
            if criss is not None:
                semb = criss.embed(ss, langcode=configs.src_lang)
                temb = criss.embed(ts, langcode=configs.trg_lang)
                # Pairwise (source x target) similarity features, indexed [i, j] below.
                cos_matrix = cos(semb.unsqueeze(1), temb.unsqueeze(0)).unsqueeze(-1).unsqueeze(-1)
                ip_matrix = (semb.unsqueeze(1) * temb.unsqueeze(0)).sum(-1).unsqueeze(-1).unsqueeze(-1)
                feat_matrix = torch.cat((cos_matrix, ip_matrix), dim=-1)

            # Row-major (source-major) order, matching the reshape below.
            word_pairs = []
            criss_features = []
            for i, sw in enumerate(ss):
                for j, tw in enumerate(ts):
                    word_pairs.append((sw, tw))
                    if feat_matrix is not None:
                        criss_features.append(feat_matrix[i, j])
            # NOTE(review): without a CRISS model criss_features stays empty;
            # this assumes extract_scores ignores it when configs.use_criss is
            # False -- confirm in extract_scores.

            scores = extract_scores(word_pairs, criss_features, aligner, info, configs)
            scores = scores.reshape(len(ss), len(ts), -1).softmax(-1)
            # Expected class index per word pair -> shape (len(ss), len(ts)).
            scores = (scores * class_values).sum(-1)
            results.append(inference(simaligns[idx], scores, threshold))
    return results