def vis_dataset()

in lucid/scratch/rl_util/nmf.py [0:0]


    def vis_dataset(self, feature, *, subdiv_mult=1, expand_mult=1, top_frac=0.1):
        """Build an RGBA-style mosaic of dataset patches for one reduced feature.

        The feature's slice of ``self.acts_reduced`` is processed in three steps:
        it is edge-padded by one cell on both spatial axes, resampled with
        linear ``nd.zoom`` according to ``subdiv_mult``, and then has the
        padded ring cropped off.

        For every cell of the resampled grid:

        * ``argmax_nd`` chooses an observation index along axis 0.
        * ``self.get_patch`` fetches the corresponding patch.
        * An extra channel is appended holding that cell's activation divided
          by the largest chosen activation.

        The patches are then tiled into a single array.

        Args:
            feature: Index into the last axis of ``self.acts_reduced``.
            subdiv_mult: Resampling multiplier for the activation grid.
            expand_mult: Forwarded to ``self.pad_obses`` and ``self.get_patch``.
            top_frac: Used to compute
                ``max_rep = ceil(num_cells / (len(axis 0) * top_frac))``.
                ``max_rep`` is handed to ``argmax_nd``. Presumably it caps how
                often one observation may be selected; confirm against
                ``argmax_nd``.

        Returns:
            A tuple ``(mosaic, obs_indices)``:

            * ``mosaic``: the patches with the extra channel concatenated on the
              last axis. Patches within a grid row are joined along axis 1, and
              rows are stacked along axis 0.
            * ``obs_indices``: the chosen observation indices as nested lists.

        Raises:
            RuntimeError: If the resampled grid is empty. The message attributes
                this to ``subdiv_mult`` being too small.
        """
        grid_h, grid_w = self.acts_reduced.shape[1:3]
        zoom_h, zoom_w = (
            subdiv_mult - (subdiv_mult - 1) / (extent + 2) for extent in (grid_h, grid_w)
        )

        # Edge-pad by one cell before zooming, then crop that ring off again.
        padded = np.pad(
            self.acts_reduced[..., feature], [(0, 0), (1, 1), (1, 1)], mode="edge"
        )
        acts_subdiv = nd.zoom(padded, [1, zoom_h, zoom_w], order=1, mode="nearest")[
            :, 1:-1, 1:-1
        ]
        if acts_subdiv.size == 0:
            raise RuntimeError(
                f"subdiv_mult of {subdiv_mult} too small for "
                f"{self.acts_reduced.shape[1]}x{self.acts_reduced.shape[2]} "
                "activations"
            )

        # Resample the padded grid's (row, col) index coordinates the same way.
        # Each subdivided cell then gets a fractional position for get_patch.
        coords = np.moveaxis(np.indices((grid_h + 2, grid_w + 2)), 0, -1).astype(float)
        centers = (
            nd.zoom(coords, [zoom_h, zoom_w, 1], order=1, mode="nearest")[1:-1, 1:-1, :]
            - 0.5
        )

        num_obs, sub_h, sub_w = acts_subdiv.shape
        # With top_frac == 0 the division yields inf; the warning is silenced on purpose.
        with np.errstate(divide="ignore"):
            max_rep = np.ceil(np.divide(sub_h * sub_w, num_obs * top_frac))
        obs_indices = argmax_nd(
            acts_subdiv, axes=[0], max_rep=max_rep, max_rep_strict=False
        )[0]

        # pad_obses must run before any get_patch call.
        self.pad_obses(expand_mult=expand_mult)
        n_rows, n_cols = obs_indices.shape[0], obs_indices.shape[1]
        chosen_acts = np.zeros(obs_indices.shape)
        tiles = []
        for row in range(n_rows):
            tile_row = []
            for col in range(n_cols):
                obs_index = obs_indices[row, col]
                pos_h, pos_w = centers[row, col]
                tile_row.append(
                    self.get_patch(obs_index, pos_h, pos_w, expand_mult=expand_mult)
                )
                chosen_acts[row, col] = acts_subdiv[obs_index, row, col]
            tiles.append(tile_row)

        peak = chosen_acts.max()
        # Avoid dividing by zero when every chosen activation is zero.
        alpha = chosen_acts / (peak if peak != 0 else 1)

        def with_alpha(patch, value):
            # Append a constant-valued channel matching the patch's height and width.
            channel = np.full((patch.shape[0], patch.shape[1], 1), value)
            return np.concatenate([patch, channel], axis=-1)

        mosaic = np.concatenate(
            [
                np.concatenate(
                    [
                        with_alpha(patch, alpha[row, col])
                        for col, patch in enumerate(tile_row)
                    ],
                    axis=1,
                )
                for row, tile_row in enumerate(tiles)
            ],
            axis=0,
        )
        return mosaic, obs_indices.tolist()