in src/weakly_sup.py [0:0]
import collections


def get_test_lexicon(
    test_set, test_lexicon, criss, lexicon_inducer, info, configs, best_threshold, best_n_cand
):
    induced_lexicon = list()
    # Highest predicted probability observed for each (source word, target word) pair.
    pred_test_lexicon = collections.defaultdict(collections.Counter)
    probs = extract_probs(
        test_set, criss, lexicon_inducer, info, configs
    )
    for i, (x, y) in enumerate(test_set):
        pred_test_lexicon[x][y] = max(pred_test_lexicon[x][y], probs[i].item())

    # Score every candidate translation for each test source word; the last field
    # marks whether the pair appears in the gold test lexicon.
    possible_predictions = list()
    for tsw in set(x[0] for x in test_lexicon):
        ssw = to_simplified(tsw)
        for stw in pred_test_lexicon[ssw]:
            ttw = to_traditional(stw)
            pos = 1 if (tsw, ttw) in test_lexicon else 0
            possible_predictions.append([tsw, ttw, pred_test_lexicon[ssw][stw], pos])
    # Rank candidates by predicted probability, highest first.
    possible_predictions = sorted(possible_predictions, key=lambda x: -x[-2])

    word_cnt = collections.Counter()
    correct_predictions = 0
    for i, item in enumerate(possible_predictions):
        if item[-2] < best_threshold:
            # Stop once scores drop below the tuned threshold and report test F1.
            prec = correct_predictions / (sum(word_cnt.values()) + 1) * 100.0
            rec = correct_predictions / len(test_lexicon) * 100.0
            # Guard against division by zero when no prediction is correct.
            f1 = 2 * prec * rec / (rec + prec) if (rec + prec) > 0 else 0.0
            print(f'Test F1: {f1:.2f}')
            break
        if word_cnt[item[0]] == best_n_cand:
            # Keep at most best_n_cand candidates per source word.
            continue
        word_cnt[item[0]] += 1
        if item[-1] == 1:
            correct_predictions += 1
        induced_lexicon.append(item[:2])
    eval_result = evaluate(induced_lexicon, test_lexicon)
    return induced_lexicon, eval_result
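For reference, a minimal self-contained sketch of the same greedy selection rule on toy data. It only illustrates the threshold cutoff, per-word candidate cap, and F1 computation; the names `select_candidates`, `toy_candidates`, and `gold` are illustrative and are not part of the repository.

```python
import collections

def select_candidates(scored_pairs, gold, threshold, n_cand):
    """Greedy selection as above: sort by score, cap candidates per source
    word, and stop once scores fall below the threshold."""
    scored_pairs = sorted(scored_pairs, key=lambda x: -x[2])
    word_cnt = collections.Counter()
    selected, correct = [], 0
    for src, tgt, score in scored_pairs:
        if score < threshold:
            break
        if word_cnt[src] == n_cand:
            continue
        word_cnt[src] += 1
        if (src, tgt) in gold:
            correct += 1
        selected.append((src, tgt))
    prec = correct / max(len(selected), 1)
    rec = correct / len(gold)
    f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0
    return selected, prec, rec, f1

# Toy example: two source words, three scored candidates each.
gold = {('cat', 'chat'), ('dog', 'chien')}
toy_candidates = [
    ('cat', 'chat', 0.9), ('cat', 'chien', 0.4), ('cat', 'lapin', 0.2),
    ('dog', 'chien', 0.8), ('dog', 'chat', 0.6), ('dog', 'lapin', 0.1),
]
print(select_candidates(toy_candidates, gold, threshold=0.5, n_cand=1))
# -> ([('cat', 'chat'), ('dog', 'chien')], 1.0, 1.0, 1.0)
```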