in eval-vis-model.py [0:0]
def plot_obs_rew(self, true_xs, pred_xs, true_rews, pred_rews, pred_dones, fname):
    """Plot state and reward rollouts on a grid of axes and save the figure.

    Each state dimension gets its own axis (in order, starting at the first
    axis of the flattened grid); the last axis of the grid is reserved for
    rewards. State dimensions that do not fit in the remaining axes are
    silently omitted. The grid shape and figure size are chosen from
    ``self.args.vid_mode`` and ``self.domain_name``.

    Args:
        true_xs: Ground-truth states, indexed as ``true_xs[:, i]`` for state
            dimension ``i``, or ``None`` to skip the ground-truth curves.
        pred_xs: Non-empty sequence of predicted state trajectories. The
            first entry must expose a 2-tuple ``.shape`` read as
            ``(horizon + 1, state_dim)``; every entry is indexed as
            ``[:, i]`` and passed through ``utils.to_np`` before plotting.
        true_rews: Ground-truth rewards, plotted as-is, or ``None`` to skip.
        pred_rews: Sequence of predicted reward trajectories, each passed
            through ``utils.to_np`` before plotting.
        pred_dones: Sequence of predicted done signals, each passed through
            ``utils.to_np``. Only drawn when ``self.args.show_dones`` is
            truthy; ``None`` skips the overlay.
        fname: Destination handed to ``fig.savefig``.

    Returns:
        None. The figure is written to ``fname`` and then closed.
    """
    domain_name = self.domain_name
    highlight = self.args.vid_mode == 'highlight'
    # Per-dimension y-limits are read as bounds[0][i] / bounds[1][i] below.
    bounds = (-self.max_obs, self.max_obs)
    reward_bounds = self.reward_bounds

    # Choose the grid layout. Order matters: the exact 'mbpo_*' names must be
    # tested before the generic substring checks, and only the 'cheetah' and
    # 'humanoid' checks are case-insensitive.
    if highlight:
        nrows, ncols, figsize = 3, 4, (4, 3)
    elif 'cheetah' in domain_name.lower():
        nrows, ncols, figsize = 3, 6, (14, 10)
    elif 'walker' in domain_name:
        nrows, ncols, figsize = 5, 5, (14, 10)
    elif domain_name == 'mbpo_humanoid':
        nrows, ncols, figsize = 6, 8, (16, 10)
    elif domain_name == 'mbpo_ant':
        nrows, ncols, figsize = 5, 6, (14, 10)
    elif 'humanoid' in domain_name.lower():
        nrows, ncols, figsize = 8, 9, (16, 10)
    elif 'pendulum' in domain_name:
        nrows, ncols, figsize = 4, 1, (6, 10)
    elif 'hopper' in domain_name:
        nrows, ncols, figsize = 4, 3, (14, 10)
    elif 'swimmer' in domain_name:
        nrows, ncols, figsize = 3, 3, (10, 10)
    else:
        nrows, ncols, figsize = 5, 5, (14, 10)

    fig, axs = plt.subplots(
        nrows, ncols, figsize=figsize,
        gridspec_kw={'wspace': 0, 'hspace': 0},
    )
    # pyplot keeps every figure alive until it is explicitly closed, so make
    # sure the figure is released even if plotting or saving fails.
    try:
        axs = axs.ravel()
        if not highlight:
            add_label(axs[0], 'States', fontsize=20)

        for ax in axs:
            ax.get_xaxis().set_ticklabels([])
            ax.get_yaxis().set_ticklabels([])
            ax.patch.set_edgecolor('black')

        horizon_p1, state_dim = pred_xs[0].shape
        horizon = horizon_p1 - 1

        # The last axis is reserved for rewards, so at most len(axs) - 1
        # state dimensions are drawn; any remaining ones are skipped.
        for i in range(min(state_dim, len(axs) - 1)):
            ax = axs[i]
            if true_xs is not None:
                ax.plot(true_xs[:, i], color='k', label='Ground Truth')
            # The first prediction takes the next color from the cycle; the
            # others reuse it so all predictions on this axis share a color.
            color = None
            for pred_x in pred_xs:
                line, = ax.plot(utils.to_np(pred_x[:, i]), alpha=1., color=color)
                color = line.get_color()
            ax.set_ylim(bounds[0][i], bounds[1][i])
            ax.set_xlim(0, horizon)
            if not highlight:
                # axhline defaults to y=0: a dashed zero reference line.
                ax.axhline(color='k', linestyle='--', alpha=0.4)

        rew_ax = axs[-1]
        if true_rews is not None:
            rew_ax.plot(true_rews, alpha=0.5, color='k')
        # All predicted reward curves use the second color of the cycle.
        color = plt.rcParams['axes.prop_cycle'].by_key()['color'][1]
        for pred_rew in pred_rews:
            rew_ax.plot(utils.to_np(pred_rew), alpha=1., color=color)
        rew_ax.set_ylim(*reward_bounds)
        rew_ax.set_xlim(0, horizon)
        rew_ax.get_xaxis().set_ticklabels([])
        rew_ax.get_yaxis().set_ticklabels([])
        if not highlight:
            add_label(rew_ax, 'Rewards', fontsize=20)

        # Done signals are overlaid on the reward axis with their own y-scale.
        # A missing pred_dones is skipped, like true_xs / true_rews above.
        if self.args.show_dones and pred_dones is not None:
            done_ax = rew_ax.twinx()
            for pred_done in pred_dones:
                done_ax.plot(utils.to_np(pred_done), alpha=1.)
            done_ax.set_ylim(-0.1, 1.1)
            done_ax.set_xlim(0, horizon)
            done_ax.get_xaxis().set_ticklabels([])
            done_ax.get_yaxis().set_ticklabels([])

        fig.tight_layout()
        fig.savefig(fname)
    finally:
        plt.close(fig)