def main()

in src/suggest_cls_streamlit.py [0:0]


# Table background colour for each value of the "Classification" column
# ('tp', 'fp', 'fn', 'tn'); any other value falls back to white.
_CLASSIFICATION_COLORS = {
    'tp': 'green',
    'fp': 'red',
    'fn': 'lightcoral',   # 'light red'
    'tn': 'lightgreen',
}

# Column order of the results table. Declared explicitly so that an empty
# result list still yields a frame with these columns instead of none at all.
_RESULT_COLUMNS = [
    "Candidate Title",
    "Index",
    "Group Similarity",
    "Title Similarity",
    "Probability",
    "Similar?",
    "Group Name",
    "Classification",
]


def _results_to_dataframe(results):
    """Flatten ``get_classifications`` output into a display DataFrame.

    Each item ``r`` is read as ``r[0]`` = candidate title and ``r[1]`` = a
    mapping with keys "index", "group_similarity", "title_similarity",
    "proba", "similar", "group_name" and "classification".

    Returns:
        A DataFrame with columns ``_RESULT_COLUMNS`` (zero rows when
        *results* is empty).
    """
    rows = [
        {
            "Candidate Title": r[0],
            "Index": r[1]["index"],
            "Group Similarity": r[1]["group_similarity"],
            "Title Similarity": r[1]["title_similarity"],
            "Probability": r[1]["proba"],
            "Similar?": r[1]["similar"],
            "Group Name": r[1]["group_name"],
            "Classification": r[1]["classification"],
        }
        for r in results
    ]
    return pd.DataFrame(rows, columns=_RESULT_COLUMNS)


def _classification_row_style(row):
    """Return one CSS background rule per cell, coloured by the row's classification."""
    color = _CLASSIFICATION_COLORS.get(row["Classification"], "white")
    return [f'background-color: {color}' for _ in row.index]


def main():
    """Streamlit entry point for inspecting title classifications.

    Loads windows via ``load_all_windows()``, falling back to a small built-in
    demo window when none are returned. The user picks a window, one or more
    anchor titles, the classifier coefficients/intercept, a threshold and a
    SentenceTransformer model name. On "Run Classification" the model is
    loaded, ``get_classifications`` is called, and the results are shown as a
    colour-coded table plus per-classification counts.
    """
    st.title("Classification Visualization App")

    # Attempt to load windows from your CSVs
    all_windows = load_all_windows()

    # If none found, use dummy data; otherwise use the loaded windows.
    if len(all_windows) > 0:
        windows = all_windows
    else:
        windows = [
            {
                "titles": ["Sausage", "Burger", "Salads", "Github", "Cat Animal", "Dogs"],
                "group_name": ["Food", "Food", "Food", "Ungrouped-4", "Animals", "Animals"]
            }
        ]

    # UI Control: pick which window
    window_idx = st.selectbox("Select Window Index:", list(range(len(windows))))
    window = windows[window_idx]

    # Map each "Title (Group)" label to its row index in the window. Mapping to
    # the index (not just the title) keeps the chosen group correct when the
    # same title appears under several groups; identical labels collapse to
    # one option.
    label_to_index = {
        f"{t} ({g})": i
        for i, (t, g) in enumerate(zip(window["titles"], window["group_name"]))
    }

    # Multi-select anchors
    anchor_labels = st.multiselect(
        "Select Anchor Titles (Group appended)",
        options=list(label_to_index),
        default=[]
    )
    anchor_indices = [label_to_index[label] for label in anchor_labels]
    anchors = [window["titles"][i] for i in anchor_indices]

    # Classifier hyperparameters
    coef_1 = st.number_input("Coefficient for group similarity", value=28.488064)
    coef_2 = st.number_input("Coefficient for title similarity", value=17.99544)
    intercept_val = st.number_input("Intercept", value=-37.541557)
    classifier_params = [[coef_1, coef_2], intercept_val]

    # Threshold
    threshold = st.slider("Threshold", min_value=0.0, max_value=1.0, value=0.5)

    # Model name
    model_name = st.text_input(
        "SentenceTransformer Model Name",
        value="thenlper/gte-small"  # or "sentence-transformers/all-MiniLM-L6-v2"
    )

    # Run classification
    if st.button("Run Classification"):
        if not anchors:
            st.error("No anchors selected. Please select at least one title as an anchor.")
            return

        try:
            embedding_model = SentenceTransformer(model_name)
        except Exception as e:
            # UI boundary: report any model-loading failure to the user.
            st.error(f"Error loading model: {e}")
            return

        # The first anchor's group is the one reported.
        anchor_group_name = window['group_name'][anchor_indices[0]]
        anchor_groups = {window['group_name'][i] for i in anchor_indices}
        if len(anchor_groups) > 1:
            # Make a mixed-group selection visible instead of silently
            # reporting only the first anchor's group.
            st.warning(
                f"Selected anchors span multiple groups {sorted(anchor_groups)}; "
                f"showing the group of the first anchor."
            )

        st.write(f"**Anchor Titles**: {anchors}")
        st.write(f"**Anchor Group Name**: {anchor_group_name}")

        # Compute results
        results = get_classifications(
            window,
            anchors,
            embedding_model,
            classifier_params,
            threshold
        )

        df = _results_to_dataframe(results)
        if df.empty:
            # Nothing to tabulate or count; avoid rendering empty charts.
            st.warning("No candidate titles were classified.")
            return

        styled_df = df.style.apply(_classification_row_style, axis=1)

        # Make the DataFrame bigger: set width & height explicitly
        st.dataframe(styled_df, width=1500, height=500)

        # Classification counts
        classification_counts = df["Classification"].value_counts()
        st.write("### Classification Counts")
        st.bar_chart(classification_counts)
        st.write(classification_counts)