in eval-vis-model.py [0:0]
def plot_ctrl(self, plan_us, fname):
    """Plot every action dimension of a control sequence and save the figure.

    One subplot is drawn per action dimension. The grid layout is chosen from
    ``self.args.vid_mode`` and ``self.domain_name``. If the grid has fewer
    cells than there are action dimensions, the extra dimensions are not
    drawn. Grid cells left without an action dimension are hidden.

    Args:
        plan_us: 2-D tensor unpacked as ``(T, nctrl)`` via ``size()``. Column
            ``i`` is plotted in subplot ``i`` on a fixed ``[-1, 1]`` y-range.
            Presumably rows are time steps and columns are action dimensions
            (a torch tensor, given ``ndimension()``) — confirm with callers.
        fname: Output path passed to ``fig.savefig``.
    """
    assert plan_us.ndimension() == 2
    T, nctrl = plan_us.size()
    domain_name = self.domain_name
    highlight = self.args.vid_mode == 'highlight'

    # Pick the subplot grid for this mode/domain.
    if highlight:
        nrows, ncols, figsize = 1, 8, (8, 1)
    elif domain_name in ('Humanoid-v2', 'mbpo_humanoid'):
        nrows, ncols, figsize = 3, 6, (16, 4)
    elif 'humanoid' in domain_name:
        nrows, ncols, figsize = 3, 7, (16, 4)
    elif 'pendulum' in domain_name:
        nrows, ncols, figsize = 1, 1, (6, 2.5)
    else:
        nrows, ncols, figsize = 1, nctrl, (16, 2)

    # squeeze=False makes plt.subplots always return a 2-D array of Axes.
    # Flattening therefore works for every grid shape, independent of nctrl.
    # For example, a 1x8 highlight grid with a single action, or a 1x1 grid.
    fig, axs = plt.subplots(
        nrows, ncols, figsize=figsize,
        gridspec_kw={'wspace': 0, 'hspace': 0}, squeeze=False,
    )
    axs = axs.ravel()

    try:
        color = plt.rcParams['axes.prop_cycle'].by_key()['color'][2]

        # Slicing caps the loop at the grid's capacity, so action dimensions
        # that do not fit (e.g. beyond 8 in highlight mode) are skipped.
        for i, ax in enumerate(axs[:nctrl]):
            ax.plot(utils.to_np(plan_us[:, i]), color=color)
            ax.set_ylim(-1., 1.)
            ax.set_xlim(0, T - 1)
            ax.get_xaxis().set_ticklabels([])
            ax.get_yaxis().set_ticklabels([])
            ax.axhline(color='k', linestyle='--', alpha=0.4)

        # Hide grid cells that have no action dimension to show.
        for ax in axs[nctrl:]:
            ax.set_axis_off()

        # Highlight mode is drawn without the label.
        if not highlight:
            fontsize = 20 if 'pendulum' in domain_name else 14
            add_label(axs[0], 'Actions', fontsize)

        fig.tight_layout()
        fig.savefig(fname)
    finally:
        # Always release the figure, even if plotting or saving fails.
        # Otherwise repeated calls accumulate open pyplot figures.
        plt.close(fig)