in src/suggest_cls_streamlit.py [0:0]
# Background colour per classification outcome (tp/fp/fn/tn).
_CLASSIFICATION_COLORS = {
    "tp": "green",
    "fp": "red",
    "fn": "lightcoral",  # light red
    "tn": "lightgreen",
}

# Column order of the results table; also gives an empty result set a schema.
_RESULT_COLUMNS = [
    "Candidate Title",
    "Index",
    "Group Similarity",
    "Title Similarity",
    "Probability",
    "Similar?",
    "Group Name",
    "Classification",
]


def _anchor_label_index(window):
    """Map each "Title (Group)" label of *window* to its first position.

    Keeping the position (not just the title) lets a selected label be
    resolved to the correct group even when the same title occurs under
    several groups.
    """
    label_to_index = {}
    pairs = zip(window["titles"], window["group_name"])
    for idx, (title, group) in enumerate(pairs):
        label_to_index.setdefault(f"{title} ({group})", idx)
    return label_to_index


def _classification_row_style(row):
    """Return one CSS rule per cell, colouring the row by its classification.

    Unknown classification values fall back to a white background.
    """
    color = _CLASSIFICATION_COLORS.get(row["Classification"], "white")
    return [f"background-color: {color}" for _ in row.index]


def _results_to_dataframe(results):
    """Flatten ``get_classifications`` output into a display DataFrame.

    Each result is indexed as ``r[0]`` (candidate title) and ``r[1]`` (a dict
    with the keys read below) — shape inferred from usage here; confirm
    against ``get_classifications``.
    """
    rows = [
        {
            "Candidate Title": r[0],
            "Index": r[1]["index"],
            "Group Similarity": r[1]["group_similarity"],
            "Title Similarity": r[1]["title_similarity"],
            "Probability": r[1]["proba"],
            "Similar?": r[1]["similar"],
            "Group Name": r[1]["group_name"],
            "Classification": r[1]["classification"],
        }
        for r in results
    ]
    return pd.DataFrame(rows, columns=_RESULT_COLUMNS)


def main():
    """Streamlit entry point: pick a window and anchors, then classify titles.

    Windows come from ``load_all_windows()``; when it returns none, a small
    built-in demo window is used instead. The user selects anchor titles,
    classifier coefficients/intercept, a probability threshold and a
    SentenceTransformer model. On "Run Classification" the app renders a
    colour-coded results table and per-classification counts.
    """
    st.title("Classification Visualization App")

    # Attempt to load windows from your CSVs; fall back to dummy data if none.
    all_windows = load_all_windows()
    windows = [
        {
            "titles": ["Sausage", "Burger", "Salads", "Github", "Cat Animal", "Dogs"],
            "group_name": ["Food", "Food", "Food", "Ungrouped-4", "Animals", "Animals"],
        }
    ]
    if len(all_windows) > 0:
        windows = all_windows

    # UI Control: pick which window
    window_idx = st.selectbox("Select Window Index:", list(range(len(windows))))
    window = windows[window_idx]

    # Resolve selected labels to positions rather than titles: titles.index()
    # would always return the first occurrence of a duplicated title and thus
    # possibly the wrong group.
    label_to_index = _anchor_label_index(window)
    anchor_labels = st.multiselect(
        "Select Anchor Titles (Group appended)",
        options=list(label_to_index),
        default=[],
    )
    anchor_indices = [label_to_index[label] for label in anchor_labels]
    anchors = [window["titles"][idx] for idx in anchor_indices]

    # Classifier hyperparameters. format= only affects display: it shows the
    # defaults at full precision instead of Streamlit's default 2 decimals.
    coef_1 = st.number_input(
        "Coefficient for group similarity", value=28.488064, format="%.6f"
    )
    coef_2 = st.number_input(
        "Coefficient for title similarity", value=17.99544, format="%.6f"
    )
    intercept_val = st.number_input("Intercept", value=-37.541557, format="%.6f")
    classifier_params = [[coef_1, coef_2], intercept_val]

    # Threshold
    threshold = st.slider("Threshold", min_value=0.0, max_value=1.0, value=0.5)

    # Model name
    model_name = st.text_input(
        "SentenceTransformer Model Name",
        value="thenlper/gte-small",  # or "sentence-transformers/all-MiniLM-L6-v2"
    )

    # Run classification
    if st.button("Run Classification"):
        if not anchors:
            st.error("No anchors selected. Please select at least one title as an anchor.")
            return

        try:
            embedding_model = SentenceTransformer(model_name)
        except Exception as e:  # UI boundary: report any load failure to the user
            st.error(f"Error loading model: {e}")
            return

        # Only the first anchor's group is used as the reference group.
        anchor_group_name = window["group_name"][anchor_indices[0]]
        st.write(f"**Anchor Titles**: {anchors}")
        st.write(f"**Anchor Group Name**: {anchor_group_name}")
        anchor_groups = {window["group_name"][idx] for idx in anchor_indices}
        if len(anchor_groups) > 1:
            st.warning(
                "Selected anchors span multiple groups; only the first anchor's "
                f"group ({anchor_group_name}) is shown as the anchor group."
            )

        # Compute results
        results = get_classifications(
            window,
            anchors,
            embedding_model,
            classifier_params,
            threshold,
        )
        if not results:
            # Nothing to tabulate or count; avoid indexing an empty frame.
            st.info("No candidate titles were classified.")
            return

        df = _results_to_dataframe(results)
        styled_df = df.style.apply(_classification_row_style, axis=1)
        # Make the DataFrame bigger: set width & height explicitly
        st.dataframe(styled_df, width=1500, height=500)

        # Classification counts
        classification_counts = df["Classification"].value_counts()
        st.write("### Classification Counts")
        st.bar_chart(classification_counts)
        st.write(classification_counts)