def panel_ml_visualization()

in afa/app/app.py [0:0]


def _option_index(options, value):
    """Return the position of *value* in *options*, or 0 (the blank option) if absent."""
    try:
        return options.index(value)
    except ValueError:
        return 0


def _initial_xrange(y_ts, yp_ts, horiz, freq, n_horizons=8):
    """Return the initial ``[start, end]`` x-axis window for the forecast chart.

    The window spans ``horiz * n_horizons`` periods of frequency *freq* ending at
    the last forecast timestamp (or the last actual timestamp when the selection
    has no forecast rows), clipped so it never starts before the first actual
    timestamp. Returns ``None`` when neither series has a timestamp.
    """
    end_ts = yp_ts.max() if len(yp_ts) > 0 else y_ts.max()

    # With no timestamps at all, `end_ts` is NaN/NaT and `pd.date_range` would raise.
    if pd.isna(end_ts):
        return None

    window = pd.date_range(end=end_ts, periods=horiz * n_horizons, freq=freq)
    range_start = window[0]

    if len(y_ts) > 0:
        range_start = max(range_start, y_ts.min())

    return [range_start, window[-1]]


def panel_ml_visualization():
    """Render the ML forecast visualization panel.

    Reads the input demand data from ``state.report["data"]["df"]`` and the ML
    forecast outputs (``df_results``, ``df_preds``, ``df_backtests``, ``freq``,
    ``horiz``) from ``state.report["afc"]``. It then draws a channel / family /
    item filter form and a Plotly chart of the actual, forecast and (when
    available) backtest demand for the current selection, followed by the
    elapsed rendering time.

    The filter is pre-selected with the (channel, family, item_id) combination
    that has the largest total demand in the input data.

    Returns early without rendering anything when the input data, the ML
    results or the ML predictions are not available yet. Backtests are optional.
    """
    df = state.report["data"].get("df", None)
    df_ml_results = state.report["afc"].get("df_results", None)
    df_ml_preds = state.report["afc"].get("df_preds", None)
    df_ml_backtests = state.report["afc"].get("df_backtests", None)

    if df is None or df_ml_results is None or df_ml_preds is None:
        return

    freq = state.report["afc"]["freq"]
    horiz = state.report["afc"]["horiz"]
    start = time.time()

    # The leading "" is the blank option in each selectbox. How `make_mask`
    # interprets it (presumably "no filter") is defined elsewhere — TODO confirm.
    channel_vals = [""] + sorted(df_ml_results["channel"].unique())
    family_vals = [""] + sorted(df_ml_results["family"].unique())
    item_id_vals = [""] + sorted(df_ml_results["item_id"].unique())

    # Rank series by total demand so the biggest one is shown by default.
    df_top = (
        df.groupby(["channel", "family", "item_id"], as_index=False)
        .agg({"demand": "sum"})
        .sort_values(by="demand", ascending=False)
    )

    # Fall back to the blank option when the input is empty or its top series
    # is absent from the ML results, instead of raising IndexError/ValueError.
    if len(df_top) > 0:
        channel_index = _option_index(channel_vals, df_top["channel"].iloc[0])
        family_index = _option_index(family_vals, df_top["family"].iloc[0])
        item_id_index = _option_index(item_id_vals, df_top["item_id"].iloc[0])
    else:
        channel_index = family_index = item_id_index = 0

    with st.beta_expander("👁️  Visualization", expanded=True):
        with st.form("ml_viz_form"):
            st.markdown("#### Filter By")
            _cols = st.beta_columns(3)

            with _cols[0]:
                channel_choice = st.selectbox("Channel", channel_vals, index=channel_index, key="ml_results_channel")

            with _cols[1]:
                family_choice = st.selectbox("Family", family_vals, index=family_index, key="ml_results_family")

            with _cols[2]:
                item_id_choice = st.selectbox("Item ID", item_id_vals, index=item_id_index, key="ml_results_item")

            # The form only batches widget changes; the current selectbox values
            # are used on every rerun, so the submit flag itself is not needed.
            st.form_submit_button("Apply")

        pred_mask = make_mask(df_ml_preds, channel_choice, family_choice, item_id_choice)
        df_plot = df_ml_preds[pred_mask]

        # Backtests are not part of the early-return guard above, so they may
        # legitimately be missing; in that case the backtest trace is skipped.
        if df_ml_backtests is not None:
            backtest_mask = make_mask(df_ml_backtests, channel_choice, family_choice, item_id_choice)
            _df_backtests = df_ml_backtests[backtest_mask]
        else:
            _df_backtests = None

        if len(df_plot) > 0:
            # Split the predictions once into the actual and forecast series.
            df_actual = df_plot.query("type == 'actual'")
            df_fcast = df_plot.query("type == 'fcast'")

            y, y_ts = df_actual["demand"], df_actual["timestamp"]
            yp, yp_ts = df_fcast["demand"], df_fcast["timestamp"]

            fig = go.Figure()
            fig.add_trace(go.Scatter(
                x=y_ts, y=y, mode='lines+markers', name="actual",
                fill="tozeroy", line={"width": 3}, marker=dict(size=4)
            ))

            fig.add_trace(go.Scatter(
                x=yp_ts, y=np.round(yp, 0), mode='lines+markers', name="forecast",
                fill="tozeroy", marker=dict(size=4)
            ))

            if _df_backtests is not None:
                fig.add_trace(go.Scatter(
                    x=_df_backtests["timestamp"],
                    y=np.round(_df_backtests.demand, 0), mode="lines",
                    name="backtest", line_dash="dot", line_color="black"
                ))

            fig.update_layout(
                margin={"t": 0, "b": 0, "r": 0, "l": 0},
                height=250,
                legend={"orientation": "h", "yanchor": "bottom", "y": 1.0, "xanchor": "left", "x": 0.0}
            )

            fig.update_xaxes(rangeslider_visible=True)

            # Zoom the initial view onto the most recent periods; the range
            # slider still exposes the full history.
            initial_range = _initial_xrange(y_ts, yp_ts, horiz, freq)
            if initial_range is not None:
                fig["layout"]["xaxis"].update(range=initial_range)

            st.plotly_chart(fig, use_container_width=True)

        plot_duration = time.time() - start
        st.text(f"(completed in {format_timespan(plot_duration)})")

    return