in neuron_explainer/activations/derived_scalars/activations_and_metadata.py [0:0]
def topk(self, k: int, largest: bool) -> tuple[torch.Tensor, list[DerivedScalarIndex]]:
    """Return the k most extreme activations across all layers, with their locations.

    The selection is done in two stages: first up to k candidates are selected within
    each layer, then the overall top k are selected from the concatenated per-layer
    candidates. This avoids instantiating a second copy of all the data in
    self.activations_by_layer_index.

    Args:
        k: number of entries to return. If k exceeds self.numel(), it is clamped to
            self.numel(), i.e. everything is returned.
        largest: if True select the largest values, otherwise the smallest.

    Returns:
        A tuple (values, ds_indices). values is the 1-D tensor returned by torch.topk
        over the per-layer candidates. ds_indices is a list of DerivedScalarIndex, one
        per value and in the same order, giving the layer index and the index within
        that layer's activation tensor of each value.
    """
    # if k > numel is requested, return everything
    k = min(k, self.numel())

    def get_topk_indices(activations: torch.Tensor) -> torch.Tensor:
        flat_activations = activations.flatten()
        if k >= flat_activations.numel():
            # torch.topk rejects k > numel, so a layer with at most k elements
            # contributes all of its elements, sorted in the requested direction
            return torch.argsort(flat_activations, descending=largest)
        _, indices = torch.topk(flat_activations, k, largest=largest)
        return indices

    def get_topk_values(
        activations: torch.Tensor, indices: torch.Tensor, layer_index: LayerIndex
    ) -> torch.Tensor:
        # layer_index is unused, but required as a keyword argument
        return torch.gather(activations.flatten(), 0, indices)

    topk_indices = self.apply_transform_fn_to_activations(
        get_topk_indices, output_dst=self.dst, output_pass_type=self.pass_type
    )
    topk_values = self.apply_layerwise_transform_fn_to_multiple_activations(
        get_topk_values, (topk_indices,), output_dst=self.dst, output_pass_type=self.pass_type
    )

    layer_indices = list(self.layer_indices)
    per_layer_values = [
        topk_values.activations_by_layer_index[layer_index] for layer_index in layer_indices
    ]

    # Layers can contribute different numbers of candidates (a layer with fewer than k
    # elements contributes only numel of them), so the candidates are concatenated
    # rather than stacked; torch.stack would require equal lengths. When all layers
    # contribute the same number of candidates this is the same flat tensor that
    # stacking and flattening would produce.
    candidate_values = torch.cat(per_layer_values)
    overall_topk_values, overall_topk_indices = torch.topk(
        candidate_values, k, largest=largest
    )

    # Map each position in the concatenated candidates back to
    # (position in layer_indices, position within that layer's candidates).
    candidate_counts = np.array(
        [values.numel() for values in per_layer_values], dtype=np.int64
    )
    candidate_ends = np.cumsum(candidate_counts)
    candidate_starts = candidate_ends - candidate_counts
    flat_positions = overall_topk_indices.cpu().numpy()
    # side="right" assigns a position equal to a layer's end offset to the next
    # non-empty layer, which also skips layers that contributed no candidates
    layer_positions = np.searchsorted(candidate_ends, flat_positions, side="right")
    within_layer_positions = flat_positions - candidate_starts[layer_positions]

    overall_topk_ds_indices = []
    for layer_position, within_layer_position in zip(
        layer_positions, within_layer_positions
    ):
        layer_index = layer_indices[int(layer_position)]
        # index of the selected element within the flattened activations of its layer
        flat_activation_index = int(
            topk_indices.activations_by_layer_index[layer_index][
                int(within_layer_position)
            ].item()
        )
        unraveled_index = np.unravel_index(
            flat_activation_index,
            self.activations_by_layer_index[layer_index].shape,
        )
        overall_topk_ds_indices.append(
            DerivedScalarIndex(
                dst=self.dst,
                pass_type=self.pass_type,
                layer_index=layer_index,
                tensor_indices=tuple(int(x) for x in unraveled_index),  # cast from np.int64 to int
            )
        )
    return overall_topk_values, overall_topk_ds_indices