in afa/app/app.py [0:0]
def _ml_viz_default_index(options, value):
    """Return the position of *value* in *options*, or 0 if it is absent.

    Position 0 is the blank ("") entry that ``panel_ml_visualization``
    prepends to every selectbox option list, so a default that does not
    appear in the options falls back to that blank entry instead of raising
    ``ValueError``.
    """
    try:
        return options.index(value)
    except ValueError:
        return 0


def _make_ml_viz_figure(df_plot, df_backtests, freq, horiz):
    """Build the actual / forecast / backtest line chart for one selection.

    Args:
        df_plot: prediction rows for the selection; rows with
            ``type == 'actual'`` are plotted as actuals and rows with
            ``type == 'fcast'`` as the forecast. Must be non-empty.
        df_backtests: backtest rows for the selection, or ``None`` to omit
            the backtest trace.
        freq: pandas frequency string used to size the initial x-range.
        horiz: forecast horizon (number of periods).

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    df_actual = df_plot.query("type == 'actual'")
    df_fcast = df_plot.query("type == 'fcast'")
    y, y_ts = df_actual["demand"], df_actual["timestamp"]
    yp, yp_ts = df_fcast["demand"], df_fcast["timestamp"]

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=y_ts, y=y, mode='lines+markers', name="actual",
        fill="tozeroy", line={"width": 3}, marker=dict(size=4)
    ))
    fig.add_trace(go.Scatter(
        x=yp_ts, y=np.round(yp, 0), mode='lines+markers', name="forecast",
        fill="tozeroy", marker=dict(size=4)
    ))

    # Backtests are optional in the report; only draw them when present.
    if df_backtests is not None:
        fig.add_trace(go.Scatter(
            x=df_backtests["timestamp"], y=np.round(df_backtests.demand, 0),
            mode="lines", name="backtest", line_dash="dot",
            line_color="black"
        ))

    fig.update_layout(
        margin={"t": 0, "b": 0, "r": 0, "l": 0},
        height=250,
        legend={"orientation": "h", "yanchor": "bottom", "y": 1.0,
                "xanchor": "left", "x": 0.0}
    )
    fig.update_xaxes(rangeslider_visible=True)

    # Initially zoom to the last `horiz * 8` periods ending at the final
    # forecast point, clipped to start no earlier than the first actual.
    # pd.date_range raises on a NaT `end`, so when the selection has no
    # forecast rows fall back to the latest plotted timestamp.
    x_end = yp_ts.max() if len(yp_ts) > 0 else df_plot["timestamp"].max()
    initial_range = pd.date_range(end=x_end, periods=horiz * 8, freq=freq)
    x_start = initial_range[0]
    if len(y_ts) > 0:
        x_start = max(x_start, y_ts.min())
    fig["layout"]["xaxis"].update(range=[x_start, initial_range[-1]])

    return fig


def panel_ml_visualization():
    """Render the ML forecast visualization panel.

    Reads the input data (``state.report["data"]["df"]``) and the ML
    forecasting outputs from ``state.report["afc"]`` (``df_results``,
    ``df_preds``, ``df_backtests``, ``freq``, ``horiz``). It then shows a
    channel/family/item filter form and plots the actual demand, the forecast
    and, when available, the backtests for the selected slice.

    The filter defaults to the channel/family/item with the highest total
    demand in the input data. A default that is not present in the results
    falls back to the blank option.

    Returns early, rendering nothing, if the input data, the results or the
    predictions are missing.
    """
    df = state.report["data"].get("df", None)
    df_ml_results = state.report["afc"].get("df_results", None)
    df_ml_preds = state.report["afc"].get("df_preds", None)
    # May legitimately be missing; handled when plotting.
    df_ml_backtests = state.report["afc"].get("df_backtests", None)

    if df is None or df_ml_results is None or df_ml_preds is None:
        return

    freq = state.report["afc"]["freq"]
    horiz = state.report["afc"]["horiz"]

    start = time.time()

    # Total demand per (channel, family, item_id), highest first; the top
    # row supplies the default filter selection.
    df_top = (
        df.groupby(["channel", "family", "item_id"], as_index=False)
        .agg({"demand": "sum"})
        .sort_values(by="demand", ascending=False)
    )

    channel_vals = [""] + sorted(df_ml_results["channel"].unique())
    family_vals = [""] + sorted(df_ml_results["family"].unique())
    item_id_vals = [""] + sorted(df_ml_results["item_id"].unique())

    if len(df_top) > 0:
        top = df_top.iloc[0]
        channel_index = _ml_viz_default_index(channel_vals, top["channel"])
        family_index = _ml_viz_default_index(family_vals, top["family"])
        item_id_index = _ml_viz_default_index(item_id_vals, top["item_id"])
    else:
        channel_index = family_index = item_id_index = 0

    with st.beta_expander("👁️ Visualization", expanded=True):
        with st.form("ml_viz_form"):
            st.markdown("#### Filter By")
            _cols = st.beta_columns(3)

            with _cols[0]:
                channel_choice = st.selectbox("Channel", channel_vals, index=channel_index, key="ml_results_channel")

            with _cols[1]:
                family_choice = st.selectbox("Family", family_vals, index=family_index, key="ml_results_family")

            with _cols[2]:
                item_id_choice = st.selectbox("Item ID", item_id_vals, index=item_id_index, key="ml_results_item")

            # Rendering the submit button is required for the form; the
            # chart below is rebuilt from the current choices on every run.
            st.form_submit_button("Apply")

            pred_mask = \
                make_mask(df_ml_preds, channel_choice, family_choice, item_id_choice)
            df_plot = df_ml_preds[pred_mask]

            if df_ml_backtests is not None:
                backtest_mask = \
                    make_mask(df_ml_backtests, channel_choice, family_choice, item_id_choice)
                df_backtests = df_ml_backtests[backtest_mask]
            else:
                df_backtests = None

            if len(df_plot) > 0:
                fig = _make_ml_viz_figure(df_plot, df_backtests, freq, horiz)
                st.plotly_chart(fig, use_container_width=True)

        plot_duration = time.time() - start
        st.text(f"(completed in {format_timespan(plot_duration)})")

    return