in UI/utils.py [0:0]
def plot_heatmap(df, x_cols, y_cols=None, z_cols=None):
    """Plot a heatmap of the correlation matrix of the selected columns.

    The column names from ``x_cols``, ``y_cols`` and ``z_cols`` are combined
    (duplicates removed, first occurrence kept), and the pairwise correlation
    (``DataFrame.corr`` default method) of the numeric ones is rendered in
    Streamlit with Plotly.

    Args:
        df: DataFrame holding the data.
        x_cols: Column names to include (any iterable; ``None`` is treated
            as empty).
        y_cols: Optional additional column names.
        z_cols: Optional additional column names.

    Returns:
        None. When fewer than two usable columns are selected, a message is
        written instead of a plot.
    """
    # Accept any iterable (or None) and drop duplicates while keeping order;
    # a column picked in two groups would otherwise appear twice in the matrix.
    columns = list(
        dict.fromkeys([*(x_cols or []), *(y_cols or []), *(z_cols or [])])
    )
    if len(columns) < 2:
        st.write("Please select at least two columns.")
        return

    # DataFrame.corr() raises on non-numeric data in pandas >= 2.0, so keep
    # only numeric/boolean columns and tell the user what was skipped.
    numeric = df[columns].select_dtypes(include=["number", "bool"])
    skipped = [c for c in columns if c not in numeric.columns]
    if skipped:
        st.write(f"Ignoring non-numeric columns: {', '.join(map(str, skipped))}")
    if numeric.shape[1] < 2:
        st.write("Please select at least two numeric columns.")
        return

    # Compute the correlation matrix
    correlation_matrix = numeric.corr()

    # Pin the color range to the full [-1, 1] correlation domain with a
    # diverging scale so 0 is neutral and plots are comparable to each other.
    fig = px.imshow(
        correlation_matrix,
        text_auto=".2f",
        zmin=-1,
        zmax=1,
        color_continuous_scale="RdBu_r",
        title="Heatmap",
    )
    st.plotly_chart(fig)