def plot_heatmap()

in UI/utils.py [0:0]


def plot_heatmap(df, x_cols, y_cols=None, z_cols=None):
    """Plot a heatmap of the correlation matrix of the selected columns.

    The three column selections are merged into one list (repeats removed,
    first-seen order kept), restricted to numeric/boolean columns, and the
    pairwise correlation matrix of what remains is rendered with Plotly
    through Streamlit.

    Args:
        df: DataFrame to read from; it is indexed with the merged column
            list, so every selected label must exist in it.
        x_cols: Column labels to include (``None`` is treated as empty).
        y_cols: Optional extra column labels.
        z_cols: Optional extra column labels.

    Returns:
        None. A message is written instead of a chart when fewer than two
        (numeric) columns are available.
    """
    # The same column may be picked in more than one selector; keeping it
    # twice would duplicate rows/columns in the matrix and let a single
    # distinct column slip past the "at least two" guard below.
    columns = list(
        dict.fromkeys([*(x_cols or []), *(y_cols or []), *(z_cols or [])])
    )

    if len(columns) < 2:
        st.write("Please select at least two columns.")
        return

    selected = df[columns]

    # DataFrame.corr() stopped dropping non-numeric columns by default in
    # pandas 2.0 and raises on them, so filter explicitly to get the same
    # result on old and new pandas versions.
    numeric = selected.select_dtypes(include=["number", "bool"])
    skipped = [
        str(label) for label in selected.columns if label not in numeric.columns
    ]
    if skipped:
        st.write(f"Skipping non-numeric columns: {', '.join(skipped)}")

    if numeric.shape[1] < 2:
        st.write("Please select at least two numeric columns.")
        return

    # Compute the correlation matrix
    correlation_matrix = numeric.corr()

    # Plot the heatmap
    fig = px.imshow(correlation_matrix, text_auto=True, title="Heatmap")
    st.plotly_chart(fig)