in MTRF/algorithms/softlearning/algorithms/multi_sac.py [0:0]
def get_diagnostics(self,
                    iteration,
                    batches,
                    training_paths_per_policy,
                    evaluation_paths_per_policy):
    """Return diagnostic information as a dictionary.

    Also calls the `draw` method of the plotter, if plotter defined.

    Args:
        iteration: Current training iteration. Used to build the feed
            dicts, to decide whether reconstruction images are written
            (every `self._save_reconstruction_frequency` iterations) and
            to name the output files.
        batches: Sequence indexed by goal; `batches[i]` is passed to
            `self._get_feed_dict` and its `'observations'` entry is fed
            to policy `i`'s own `get_diagnostics`.
        training_paths_per_policy: Not used by this method.
        evaluation_paths_per_policy: Pickled to
            `eval_paths_<iteration // epoch_length>.pkl` in the current
            working directory when `self._save_eval_paths` is truthy.

    Returns:
        Dict mapping diagnostic names to values: the evaluated per-goal
        diagnostics ops, `policy_<i>/...` entries from each policy and,
        when the RAE branch runs, `rae/<name>/...` latent-drift entries.

    Raises:
        AssertionError: If a VAE/RAE evaluation is requested but there is
            no `'pixels'` observation placeholder to sample images from.
    """
    previous_goal_index = self._goal_index
    diagnostics = {}

    try:
        for i in range(self._num_goals):
            # NOTE(review): presumably `_get_feed_dict` reads
            # `self._goal_index`, which is why it is switched per goal --
            # confirm against its definition.
            self._goal_index = i
            feed_dict = self._get_feed_dict(iteration, batches[i])
            diagnostics.update(
                self._session.run(
                    {**self._diagnostics_ops_per_goal[i]}, feed_dict))

            policy = self._policies[i]
            policy_diagnostics = policy.get_diagnostics(
                flatten_input_structure({
                    name: batches[i]['observations'][name]
                    for name in policy.observation_keys
                }))
            diagnostics.update({
                f'policy_{i}/{key}': value
                for key, value in policy_diagnostics.items()
            })
    finally:
        # Restore the active goal even if a diagnostics op raised, so a
        # failure here cannot leave the algorithm on the wrong goal.
        self._goal_index = previous_goal_index

    should_save = iteration % self._save_reconstruction_frequency == 0

    # Generate random pixels to evaluate the preprocessors. `feed_dict` is
    # the one built for the last goal in the loop above.
    if 'pixels' in self._placeholders['observations']:
        pixels_placeholder = self._placeholders['observations']['pixels']
        random_idxs = np.random.choice(
            feed_dict[pixels_placeholder].shape[0],
            size=self._n_preprocessor_evals_per_epoch)
        eval_pixels = feed_dict[pixels_placeholder][random_idxs]
    else:
        eval_pixels = None

    if self._uses_vae and should_save:
        # `assert eval_pixels` would raise ValueError on a multi-element
        # array (ambiguous truth value); the intent is a None check.
        assert eval_pixels is not None, (
            "VAE evaluation requires a 'pixels' observation placeholder.")
        save_path = os.path.join(os.getcwd(), 'vae')
        # No-op when the directory already exists.
        os.makedirs(save_path, exist_ok=True)

        for i, preprocessors in enumerate(self._preprocessors_per_policy):
            if self._ext_reward_coeffs[i] == 0:
                continue
            for name, vae in preprocessors.items():
                # Only the sampled latent is needed for the reconstruction.
                _, _, z = self._session.run(vae.encoder(eval_pixels))
                reconstructions = self._session.run(
                    tf.math.sigmoid(vae.decoder(z)))
                # Originals and reconstructions side by side (axis 2).
                concat = np.concatenate([
                    eval_pixels,
                    skimage.util.img_as_ubyte(reconstructions)
                ], axis=2)

                # Decode latents drawn from a standard normal.
                sampled_z = np.random.normal(
                    size=(eval_pixels.shape[0], vae.latent_dim))
                decoded_samples = self._session.run(
                    tf.math.sigmoid(vae.decoder.output),
                    feed_dict={vae.decoder.input: sampled_z}
                )

                recon_concat = np.vstack(concat)
                skimage.io.imsave(
                    os.path.join(
                        save_path,
                        f'{name}_reconstruction_{iteration}.png'),
                    recon_concat)
                samples_concat = np.vstack(decoded_samples)
                skimage.io.imsave(
                    os.path.join(
                        save_path,
                        f'{name}_sample_{iteration}.png'),
                    samples_concat)
    elif self._uses_rae:
        assert eval_pixels is not None, (
            "RAE evaluation requires a 'pixels' observation placeholder.")
        if should_save:
            save_path = os.path.join(os.getcwd(), 'rae')
            # No-op when the directory already exists.
            os.makedirs(save_path, exist_ok=True)

        for i, preprocessors in enumerate(self._preprocessors_per_policy):
            if self._ext_reward_coeffs[i] == 0:
                continue
            for name, rae in preprocessors.items():
                z, reconstructions = self._session.run(
                    rae(eval_pixels, include_reconstructions=True))
                if should_save:
                    concat = np.concatenate([
                        eval_pixels,
                        skimage.util.img_as_ubyte(reconstructions)
                    ], axis=2)
                    recon_concat = np.vstack(concat)
                    skimage.io.imsave(
                        os.path.join(
                            save_path,
                            f'{name}_reconstruction_{iteration}.png'),
                        recon_concat)

                # Track latents: the first sampled images are kept as a
                # fixed evaluation set so their encoding can be compared
                # from one call to the next.
                if self._fixed_eval_pixels is None:
                    self._fixed_eval_pixels = eval_pixels
                    self._fixed_eval_latents = np.zeros(z.shape)

                z_fixed, reconstructions_fixed = self._session.run(
                    rae(self._fixed_eval_pixels,
                        include_reconstructions=True))
                if should_save:
                    concat_fixed = np.concatenate([
                        self._fixed_eval_pixels,
                        skimage.util.img_as_ubyte(reconstructions_fixed)
                    ], axis=2)
                    recon_concat_fixed = np.vstack(concat_fixed)
                    skimage.io.imsave(
                        os.path.join(
                            save_path,
                            f'{name}_fixed_reconstruction_{iteration}.png'),
                        recon_concat_fixed)

                # Per-image L2 distance to the previously stored latents.
                z_diff = np.linalg.norm(
                    z_fixed - self._fixed_eval_latents, axis=1)
                diagnostics.update({
                    f'rae/{name}/tracked-latent-l2-difference-with-prev-mean': np.mean(z_diff),
                    f'rae/{name}/tracked-latent-l2-difference-with-prev-std': np.std(z_diff),
                })
                # Save the previous latents to compare in the next epoch
                self._fixed_eval_latents = z_fixed

    if self._save_eval_paths:
        import pickle
        file_name = f'eval_paths_{iteration // self.epoch_length}.pkl'
        # pickle.dump needs a binary, writable file; the default text
        # read mode fails before anything is written.
        with open(os.path.join(os.getcwd(), file_name), 'wb') as f:
            pickle.dump(evaluation_paths_per_policy, f)

    if self._plotter:
        self._plotter.draw()

    return diagnostics