in lucid/scratch/rl_util/nmf.py [0:0]
def vis_dataset(self, feature, *, subdiv_mult=1, expand_mult=1, top_frac=0.1):
    """Assemble a grid image of observation patches for one reduced feature.

    The slice ``self.acts_reduced[..., feature]`` (a 3-axis array, as fixed
    by the padding spec below) is edge-padded by one cell on each side of
    its last two axes, linearly resampled with ``nd.zoom`` and trimmed by
    one row/column on every side. For every cell of the resulting grid an
    index along the leading axis is chosen by ``argmax_nd``, the matching
    patch is fetched through ``self.get_patch`` and an extra channel
    holding the cell's normalised activation is appended to it.

    Args:
        feature: Index into the last axis of ``self.acts_reduced``.
        subdiv_mult: Controls the zoom factor applied to the two spatial
            axes; ``1`` leaves the grid size unchanged.
        expand_mult: Forwarded unchanged to ``self.pad_obses`` and
            ``self.get_patch``.
        top_frac: Sets the ``max_rep`` value handed to ``argmax_nd`` as
            ``ceil(grid_cells / (leading_axis_len * top_frac))``. A value
            of ``0`` yields an infinite ``max_rep`` rather than an error.

    Returns:
        A 2-tuple ``(image, indices)``. ``image`` joins the patches of each
        grid row along axis 1 and the rows along axis 0, every patch
        carrying one additional trailing channel with its opacity.
        ``indices`` is ``obs_indices.tolist()``, the chosen leading-axis
        index for each grid cell.

    Raises:
        RuntimeError: If the resampled and trimmed activation map is empty.
    """
    grid_h, grid_w = self.acts_reduced.shape[1:3]

    # Zoom factors are computed against the padded extent (grid + 2).
    zoom_h = subdiv_mult - (subdiv_mult - 1) / (grid_h + 2)
    zoom_w = subdiv_mult - (subdiv_mult - 1) / (grid_w + 2)

    # Pad -> resample -> trim one row/column from each side.
    feature_acts = np.pad(
        self.acts_reduced[..., feature], [(0, 0), (1, 1), (1, 1)], mode="edge"
    )
    feature_acts = nd.zoom(
        feature_acts, [1, zoom_h, zoom_w], order=1, mode="nearest"
    )
    feature_acts = feature_acts[:, 1:-1, 1:-1]
    if not feature_acts.size:
        raise RuntimeError(
            f"subdiv_mult of {subdiv_mult} too small for "
            f"{self.acts_reduced.shape[1]}x{self.acts_reduced.shape[2]} "
            "activations"
        )

    # Resample the padded (row, col) index grid the same way so each output
    # cell knows where it sits. After the shift, padded index 1 maps to 0.5
    # -- presumably the centre of the first unpadded cell; confirm against
    # get_patch.
    coords = np.indices((grid_h + 2, grid_w + 2)).transpose((1, 2, 0))
    coords = nd.zoom(
        coords.astype(float), [zoom_h, zoom_w, 1], order=1, mode="nearest"
    )
    coords = coords[1:-1, 1:-1, :] - 0.5

    n_cells = feature_acts.shape[1] * feature_acts.shape[2]
    # np.divide (not "/") so that a zero denominator gives inf instead of
    # raising; the warning is silenced on purpose.
    with np.errstate(divide="ignore"):
        max_rep = np.ceil(np.divide(n_cells, feature_acts.shape[0] * top_frac))

    # NOTE(review): argmax_nd is defined elsewhere; it presumably picks the
    # strongest leading-axis entry per cell while capping repeats at max_rep.
    obs_indices = argmax_nd(
        feature_acts, axes=[0], max_rep=max_rep, max_rep_strict=False
    )[0]

    # Must run before any get_patch call, as in the original ordering.
    self.pad_obses(expand_mult=expand_mult)

    n_rows, n_cols = obs_indices.shape[:2]
    patch_rows = []
    cell_acts = np.zeros(obs_indices.shape)
    for row in range(n_rows):
        row_patches = []
        for col in range(n_cols):
            chosen = obs_indices[row, col]
            center_h, center_w = coords[row, col]
            row_patches.append(
                self.get_patch(chosen, center_h, center_w, expand_mult=expand_mult)
            )
            cell_acts[row, col] = feature_acts[chosen, row, col]
        patch_rows.append(row_patches)

    # Normalise by the largest activation; divide by 1 when it is zero.
    peak = cell_acts.max()
    opacities = cell_acts / (peak if peak != 0 else 1)

    # Append the opacity as a constant trailing channel to every patch, then
    # stitch patches into rows (axis 1) and rows into the image (axis 0).
    image_rows = []
    for row, row_patches in enumerate(patch_rows):
        with_alpha = []
        for col, patch in enumerate(row_patches):
            alpha = np.full((patch.shape[0], patch.shape[1], 1), opacities[row, col])
            with_alpha.append(np.concatenate([patch, alpha], axis=-1))
        image_rows.append(np.concatenate(with_alpha, axis=1))

    return np.concatenate(image_rows, axis=0), obs_indices.tolist()