in mbrl/diagnostics/visualize_model_preds.py [0:0]
def create_axes(self):
    """Create the figure, its grid of subplots, and the empty line artists.

    One subplot is used per observation dimension plus one extra
    (``observation_space.shape[0] + 1`` in total); the extra, last subplot
    is drawn in red and the others in blue. The grid is made as close to
    square as possible, so it may contain more axes than are needed; the
    surplus axes get the same axis settings but no lines.

    Each used subplot gets four initially empty lines: a black one (the
    ground truth, per the figure label), a model line, and two thinner
    lines in the model colour (presumably the upper and lower bounds of
    the prediction -- confirm against the code that fills them in).

    Sets:
        self.fig: the created figure.
        self.axs: flat (1-D) array of every axes in the grid.
        self.lines: flat list with four entries per used subplot, in the
            order ``[real, model_mean, model_lb, model_ub]``.
    """
    num_plots = self.env.observation_space.shape[0] + 1
    num_cols = int(np.ceil(np.sqrt(num_plots)))
    num_rows = int(np.ceil(num_plots / num_cols))

    # squeeze=False makes subplots() always return a 2-D ndarray. With the
    # default squeeze=True a 1x1 grid comes back as a bare Axes object and
    # the reshape below would fail.
    fig, axs = plt.subplots(num_rows, num_cols, squeeze=False)
    fig.text(
        0.5, 0.04, f"Time step (lookahead of {self.lookahead} steps)", ha="center"
    )
    fig.text(
        0.04,
        0.17,
        "Predictions (blue/red) and ground truth (black).",
        ha="center",
        rotation="vertical",
    )

    # Flatten the grid so subplots can be addressed by a single index.
    axs = axs.reshape(-1)
    lines = []
    for i, ax in enumerate(axs):
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        ax.set_xlim(0, self.lookahead)
        if i >= num_plots:
            # Surplus cell of the grid: configured like the rest, no lines.
            continue

        # The last used subplot is red; every other one is blue.
        color = "r" if i == num_plots - 1 else "b"
        (real_line,) = ax.plot([], [], "k")
        (model_mean_line,) = ax.plot([], [], color)
        (model_ub_line,) = ax.plot([], [], color, linewidth=0.5)
        (model_lb_line,) = ax.plot([], [], color, linewidth=0.5)

        # NOTE: the stored order (lower bound before upper bound) differs
        # from the creation order above; it is kept as-is because other
        # code may index into self.lines by position.
        lines.extend((real_line, model_mean_line, model_lb_line, model_ub_line))

    self.fig = fig
    self.axs = axs
    self.lines = lines