in src/suggest_cls_streamlit.py [0:0]
def get_classifications(window, anchors, embedding_model, classifier_params, threshold):
    """Score every non-anchor title of a window against a set of anchor titles.

    Each candidate (a window title whose position is not an anchor position)
    gets two cosine similarities: one against the embedding of the first
    anchor's group name, and one averaged over all anchor title embeddings.
    Both are passed to ``prod`` together with the classifier parameters to
    obtain a score, which is compared with ``threshold`` and then labelled
    against the candidate's own group name.

    Args:
        window: dict with at least
            - 'titles': list of strings
            - 'group_name': list of strings (same length as titles)
        anchors: list of anchor titles. Each one must equal some window title
            once ``preprocess_text`` has been applied to both sides.
        embedding_model: object whose ``encode(text)`` result supports
            ``reshape(1, -1)`` (described by the original author as a
            SentenceTransformer model).
        classifier_params: list or tuple ``[[coef_1, coef_2], intercept]``;
            forwarded to ``prod`` unchanged.
        threshold: float in [0, 1]. A candidate counts as similar when its
            score is strictly greater than this value.

    Returns:
        A list of ``(preprocessed_title, info)`` tuples sorted by descending
        ``info['proba']``. ``info`` holds the keys 'index',
        'group_similarity', 'title_similarity', 'proba', 'similar',
        'group_name' and 'classification' (one of 'tp', 'fp', 'fn', 'tn',
        where "true" means the candidate shares the first anchor's group).

    Raises:
        ValueError: if a preprocessed anchor is not among the preprocessed
            window titles.
        IndexError: if ``anchors`` is empty.

    Note:
        Results are keyed by preprocessed title, so candidates that share the
        same preprocessed title collapse into a single entry whose values come
        from the last such candidate.
    """
    anchors = [preprocess_text(a) for a in anchors]
    window_titles = [preprocess_text(t) for t in window['titles']]
    group_names = window['group_name']

    # Convert anchor titles to indices (first matching title wins).
    anchor_indices = []
    for anchor in anchors:
        try:
            anchor_indices.append(window_titles.index(anchor))
        except ValueError:
            # Same exception type as list.index, but say which anchor failed.
            raise ValueError(
                f"anchor title {anchor!r} not found in window['titles']"
            ) from None

    # The first anchor's group is the reference ("true") group.
    anchor_group_name = group_names[anchor_indices[0]]
    # Set membership keeps the candidate filter O(1) per title.
    anchor_positions = set(anchor_indices)

    # Loop-invariant inputs: reference embeddings as (1, dim) rows, and the
    # classifier parameters.
    group_embedding = embedding_model.encode(anchor_group_name).reshape(1, -1)
    anchor_title_embeddings = [
        embedding_model.encode(a).reshape(1, -1) for a in anchors
    ]
    coef, intercept = classifier_params[0], classifier_params[1]

    results = {}
    for index, title in enumerate(window_titles):
        if index in anchor_positions:
            continue
        group = group_names[index]
        title_embedding = embedding_model.encode(title).reshape(1, -1)

        group_similarity = cosine_similarity(group_embedding, title_embedding)[0][0]

        # Average similarity to all anchor titles.
        similarity_total = 0
        for anchor_embedding in anchor_title_embeddings:
            similarity_total += cosine_similarity(
                title_embedding, anchor_embedding
            )[0][0]
        title_similarity = similarity_total / len(anchor_title_embeddings)

        # Score from the project's ``prod`` helper (the original comment calls
        # this the probability from a logistic formula).
        proba = prod(group_similarity, title_similarity, coef, intercept)
        similar = proba > threshold

        # Sharing the anchor's group is treated as the "true" label.
        same_group = group == anchor_group_name
        if similar:
            classification = 'tp' if same_group else 'fp'
        else:
            classification = 'fn' if same_group else 'tn'

        results[title] = {
            'index': index,
            'group_similarity': group_similarity,
            'title_similarity': title_similarity,
            'proba': proba,
            'similar': similar,
            'group_name': group,
            'classification': classification,
        }

    # Return sorted list (descending by probability).
    return sorted(results.items(), key=lambda item: item[1]['proba'], reverse=True)