def get_classifications()

in src/suggest_cls_streamlit.py [0:0]


def get_classifications(window, anchors, embedding_model, classifier_params, threshold):
    """
    Score every non-anchor title in ``window`` against the anchors and label it.

    Each candidate gets two features:
      - cosine similarity between its embedding and the embedding of the
        first anchor's group name, and
      - its mean cosine similarity to the anchor title embeddings.
    The features are combined by ``prod`` with the classifier parameters into a
    probability. The probability is thresholded into a prediction, which is
    compared with the "true" label (candidate group == first anchor's group) to
    give a confusion-matrix class: 'tp', 'fp', 'fn' or 'tn'.

    window: dict with at least
      - 'titles': list of strings
      - 'group_name': list of strings (same length as titles)
    anchors: list of anchor titles. Each must match a window title after
      ``preprocess_text``. An empty list raises IndexError.
    embedding_model: a SentenceTransformer model (its ``.encode`` result must
      support ``.reshape``)
    classifier_params: list or tuple [[coef_1, coef_2], intercept]
    threshold: float in [0,1]; a candidate is "similar" when proba > threshold

    Returns a list of ``(preprocessed_title, info)`` pairs sorted by
    ``info['proba']`` descending. ``info`` has the keys 'index',
    'group_similarity', 'title_similarity', 'proba', 'similar', 'group_name'
    and 'classification'.

    Raises ValueError if an anchor title is not present in the window.

    NOTE(review): results are keyed by preprocessed title. Candidates whose
    titles collide after preprocessing therefore collapse into one entry: the
    last occurrence's values are kept, at the first occurrence's position.
    Confirm this de-duplication is intended.
    """
    anchors = [preprocess_text(a) for a in anchors]
    window_titles = [preprocess_text(t) for t in window['titles']]
    group_names = window['group_name']

    # Map each anchor title to its (first) position in the window.
    anchor_indices = []
    for a in anchors:
        try:
            anchor_indices.append(window_titles.index(a))
        except ValueError:
            raise ValueError(f"anchor title {a!r} not found in window titles") from None
    # The first anchor's group defines both the group feature and the "true" label.
    anchor_group_name = group_names[anchor_indices[0]]

    # A set gives O(1) membership tests when filtering out the anchors.
    anchor_index_set = set(anchor_indices)
    candidate_indices = [i for i in range(len(window_titles)) if i not in anchor_index_set]

    # Loop invariants, computed once instead of once per candidate.
    coef, intercept = classifier_params[0], classifier_params[1]
    group_embedding = embedding_model.encode(anchor_group_name).reshape(1, -1)
    anchor_title_embeddings = [embedding_model.encode(a).reshape(1, -1) for a in anchors]

    ct = {}
    for ci in candidate_indices:
        c = window_titles[ci]
        cg = group_names[ci]
        candidate_title_embedding = embedding_model.encode(c).reshape(1, -1)

        group_similarity = cosine_similarity(group_embedding, candidate_title_embedding)[0][0]
        # Mean similarity to all anchor titles. The sum runs in anchor order.
        title_similarity = sum(
            cosine_similarity(candidate_title_embedding, anchor_title_embedding)[0][0]
            for anchor_title_embedding in anchor_title_embeddings
        ) / len(anchor_title_embeddings)

        # Probability from the logistic formula implemented by `prod`.
        proba = prod(group_similarity, title_similarity, coef, intercept)
        similar = proba > threshold

        # "True" label: the candidate shares the first anchor's group.
        same_group = cg == anchor_group_name
        if similar:
            classification = 'tp' if same_group else 'fp'
        else:
            classification = 'fn' if same_group else 'tn'

        ct[c] = {
            'index': ci,
            'group_similarity': group_similarity,
            'title_similarity': title_similarity,
            'proba': proba,
            'similar': similar,
            'group_name': cg,
            'classification': classification,
        }

    # Sort descending by probability. The sort is stable, so ties keep window order.
    return sorted(ct.items(), key=lambda item: item[1]['proba'], reverse=True)