in neuron_explainer/activations/derived_scalars/indexing.py [0:0]
def tensor_index_by_dim(self) -> dict[Dimension, AllOrOneIndex]:
    """Pair every dimension of the per-token-sequence shape spec with an index.

    The entries of ``self.tensor_indices`` are matched positionally against
    ``self.activation_location_type.shape_spec_per_token_sequence``. Any
    trailing dimension that has no explicit entry is mapped to ``"All"``.

    Returns:
        A dict keyed by the shape-spec entries, in shape-spec order.

    Raises:
        AssertionError: if more tensor indices are given than the shape spec
            has dimensions.
    """
    # NOTE: this logic was copied from DerivedScalarIndex.
    # TODO: have ActivationIndex and DerivedScalarIndex inherit from a shared base
    # class (and perhaps do likewise for DerivedScalarType and ActivationLocationType?).
    shape_spec = self.activation_location_type.shape_spec_per_token_sequence
    explicit_indices = list(self.tensor_indices)

    # Number of trailing dimensions that were not given an explicit index.
    num_unspecified = len(shape_spec) - len(explicit_indices)
    assert num_unspecified >= 0, (
        f"Too many tensor indices {explicit_indices} for "
        f"{self.activation_location_type.shape_spec_per_token_sequence=}"
    )

    # Multiplying a list by a non-positive count yields [], so this is a no-op
    # when every dimension already has an index.
    padded_indices = explicit_indices + ["All"] * num_unspecified
    assert len(padded_indices) == len(shape_spec)

    return dict(zip(shape_spec, padded_indices))