in mbrl/diagnostics/eval_model_on_dataset.py [0:0]
def plot_dataset_results(self, dataset: mbrl.util.TransitionIterator):
    """Save one prediction-vs-target plot per target dimension.

    Every batch of ``dataset`` is passed through
    ``self.dynamics_model.get_output_and_targets``. The first element of
    the returned outputs (collected here as the predicted means) and the
    targets are moved to CPU, converted to NumPy and concatenated. For each
    target dimension ``d`` a figure is written to
    ``self.output_path / "pred_dim{d}.png"`` containing:

    * one dot series per ensemble member (prediction against target),
    * a thin red line with the mean prediction across ensemble members,
    * a thick black identity line spanning the sampled target range.

    Only ``num_points // 20 + 1`` randomly chosen points are drawn per
    dimension. They are picked with ``np.random.choice`` using its default
    settings, so the same point may be selected more than once and the
    result depends on NumPy's global random state.

    Args:
        dataset: iterable of batches accepted by
            ``self.dynamics_model.get_output_and_targets``.

    Raises:
        ValueError: if ``dataset`` yields no batches.
    """
    all_means: List[np.ndarray] = []
    all_targets = []

    # Run the model over the whole dataset and collect predictions/targets.
    for batch in dataset:
        outputs, target = self.dynamics_model.get_output_and_targets(batch)
        all_means.append(outputs[0].cpu().numpy())
        all_targets.append(target.cpu().numpy())

    # np.concatenate would fail on an empty list with an opaque message;
    # raise the same exception type with an explanation instead.
    if not all_means:
        raise ValueError(
            "Cannot plot dataset results: the dataset yielded no batches."
        )

    # Predictions are joined along the batch axis (second to last), which
    # works both with and without a leading ensemble axis.
    all_means_np = np.concatenate(all_means, axis=-2)
    targets_np = np.concatenate(all_targets, axis=0)

    # Non-ensemble outputs get a singleton ensemble axis so the plotting
    # code below can treat both cases uniformly.
    if all_means_np.ndim == 2:
        all_means_np = all_means_np[np.newaxis, :]
    assert all_means_np.ndim == 3  # ensemble, batch, target_dim

    # The subsample size only depends on the number of points, so compute
    # it once rather than for every dimension.
    num_points = targets_np.shape[0]
    subsample_size = num_points // 20 + 1

    for dim in range(targets_np.shape[1]):
        sort_idx = targets_np[:, dim].argsort()
        subsample = np.random.choice(num_points, size=(subsample_size,))
        # Compose "sort, then subsample" into a single index vector; this
        # selects exactly the same points with one less intermediate copy.
        picked = sort_idx[subsample]
        means = all_means_np[:, picked, dim]  # (ensemble, subsample)
        sampled_targets = targets_np[picked, dim]

        fig = plt.figure(figsize=(8, 8))
        try:
            # One dot series per ensemble member.
            for member_means in means:
                plt.plot(sampled_targets, member_means, ".", markersize=2)

            # Ensemble-mean prediction, ordered by target so that the line
            # is drawn left to right.
            mean_of_means = means.mean(0)
            order = sampled_targets.argsort()
            plt.plot(
                sampled_targets[order],
                mean_of_means[order],
                color="r",
                linewidth=0.5,
            )

            # Identity line: where a perfect prediction would lie.
            lo, hi = sampled_targets.min(), sampled_targets.max()
            plt.plot([lo, hi], [lo, hi], linewidth=2, color="k")

            plt.xlabel("Target")
            plt.ylabel("Prediction")
            plt.savefig(self.output_path / f"pred_dim{dim}.png")
        finally:
            # Always release the figure, even if plotting or saving fails,
            # so figures do not pile up in pyplot's global registry.
            plt.close(fig)