def topk()

in neuron_explainer/activations/derived_scalars/activations_and_metadata.py [0:0]


    def topk(self, k: int, largest: bool) -> tuple[torch.Tensor, list[DerivedScalarIndex]]:
        """Return the k largest (or smallest) activations across all layers, with their locations.

        Top-k is first computed independently for each layer. Only those per-layer candidates are
        then concatenated and reduced to the overall top-k. This avoids instantiating a second copy
        of all the data in self.activations_by_layer_index.

        Args:
            k: Number of entries to return. If k exceeds self.numel(), it is clipped so that
                everything is returned.
            largest: If True, select the k largest values; otherwise select the k smallest.

        Returns:
            A tuple (values, ds_indices).
            - `values` is a 1-D tensor of the selected activations, in the order produced by
              torch.topk (sorted, most extreme first).
            - `ds_indices[i]` is the DerivedScalarIndex locating `values[i]`: its layer index and
              its multi-dimensional index into that layer's activation tensor.
        """
        # NOTE(review): assumes self.numel() is the total element count over all layers — confirm.
        if k > self.numel():
            k = self.numel()  # if k > numel is requested, return everything

        def get_topk_indices(activations: torch.Tensor) -> torch.Tensor:
            # Flat indices of this layer's top-min(k, numel) entries. torch.topk rejects
            # k > numel, so a layer with at most k elements is fully sorted instead.
            flat_activations = activations.flatten()
            if k >= flat_activations.numel():
                return torch.argsort(flat_activations, descending=largest)
            _, indices = torch.topk(flat_activations, k, largest=largest)
            return indices

        def get_topk_values(
            activations: torch.Tensor, indices: torch.Tensor, layer_index: LayerIndex
        ) -> torch.Tensor:
            # layer_index is unused, but required as a keyword argument
            return torch.gather(activations.flatten(), 0, indices)

        topk_indices = self.apply_transform_fn_to_activations(
            get_topk_indices, output_dst=self.dst, output_pass_type=self.pass_type
        )
        topk_values = self.apply_layerwise_transform_fn_to_multiple_activations(
            get_topk_values, (topk_indices,), output_dst=self.dst, output_pass_type=self.pass_type
        )

        # A layer with fewer than k elements contributes all of them. Per-layer candidate lists
        # therefore differ in length when layers differ in size, and torch.stack would fail.
        # Concatenate instead, and record where each layer's segment lives in the result.
        per_layer_topk_values = [
            topk_values.activations_by_layer_index[layer_index]
            for layer_index in self.layer_indices
        ]
        segment_lengths = np.array(
            [values.numel() for values in per_layer_topk_values], dtype=np.int64
        )
        segment_ends = np.cumsum(segment_lengths)
        segment_starts = segment_ends - segment_lengths
        concatenated_topk_values = torch.cat(per_layer_topk_values)

        overall_topk_values, overall_topk_indices = torch.topk(
            concatenated_topk_values, k, largest=largest
        )

        flat_positions = overall_topk_indices.cpu().numpy()
        # side="right" maps position p to the first segment whose end exceeds p,
        # which also skips over any empty segments.
        segment_numbers = np.searchsorted(segment_ends, flat_positions, side="right")
        positions_within_segment = flat_positions - segment_starts[segment_numbers]

        overall_topk_ds_indices = []
        for segment_number, position in zip(segment_numbers, positions_within_segment):
            layer_index = self.layer_indices[int(segment_number)]
            # Flat index into this layer's activations, selected by the per-layer top-k pass.
            flat_index = int(
                topk_indices.activations_by_layer_index[layer_index][int(position)].item()
            )
            tensor_indices = tuple(
                int(x)  # cast from np.int64 to int
                for x in np.unravel_index(
                    flat_index, self.activations_by_layer_index[layer_index].shape
                )
            )
            overall_topk_ds_indices.append(
                DerivedScalarIndex(
                    dst=self.dst,
                    pass_type=self.pass_type,
                    layer_index=layer_index,
                    tensor_indices=tensor_indices,
                )
            )

        return overall_topk_values, overall_topk_ds_indices