in neuron_explainer/activations/derived_scalars/postprocessing.py [0:0]
def _get_output_weight(self, ds_index: DerivedScalarIndex) -> torch.Tensor:
    """Return the output weight for ``ds_index``, indexed by its per-dimension indices.

    The derived scalar type of ``ds_index`` is first resolved to an "output" dst:

    * autoencoder latents use the ``dst`` of the autoencoder context registered for
      ``ds_index.dst.node_type``;
    * attention-head nodes use ``DerivedScalarType.ATTN_WEIGHTED_VALUE``;
    * everything else uses ``ds_index.dst`` unchanged.

    Only ``MLP_POST_ACT`` (-> ``WeightLocationType.MLP_TO_RESIDUAL``) and
    ``ATTN_WEIGHTED_VALUE`` (-> ``WeightLocationType.ATTN_TO_RESIDUAL``) are supported.

    The weight at ``ds_index.layer_index`` is fetched from the model context. It is
    then indexed along each dimension of the weight location's ``shape_spec``:

    * dimensions for which ``ds_index.tensor_index_by_dim`` provides a non-None value
      use that value;
    * all other dimensions are kept whole with ``slice(None)``.

    Args:
        ds_index: Derived scalar index supplying the dst, layer index, and per-dimension
            tensor indices.

    Returns:
        The (possibly sliced) output weight tensor.

    Raises:
        AssertionError: if an autoencoder latent is requested but no autoencoder context
            is available for its node type, or if the resolved output dst has no known
            output weight location.
    """
    if ds_index.dst.is_autoencoder_latent:
        assert (
            self._multi_autoencoder_context is not None
        ), "autoencoder latent requested but no multi-autoencoder context is set"
        autoencoder_context = self._multi_autoencoder_context.get_autoencoder_context(
            ds_index.dst.node_type
        )
        assert (
            autoencoder_context is not None
        ), f"no autoencoder context for node type {ds_index.dst.node_type}"
        # The latent's weights are those of the activation the autoencoder was trained on.
        output_dst = autoencoder_context.dst
    elif ds_index.dst.node_type == NodeType.ATTENTION_HEAD:
        output_dst = DerivedScalarType.ATTN_WEIGHTED_VALUE
    else:
        output_dst = ds_index.dst

    output_weight_by_dst: dict[DerivedScalarType, WeightLocationType] = {
        DerivedScalarType.MLP_POST_ACT: WeightLocationType.MLP_TO_RESIDUAL,
        DerivedScalarType.ATTN_WEIGHTED_VALUE: WeightLocationType.ATTN_TO_RESIDUAL,
    }
    assert output_dst in output_weight_by_dst, f"{output_dst} must be in output weight dict"
    output_weight_location_type = output_weight_by_dst[output_dst]

    # One entry per weight dimension: the requested index when ds_index pins that
    # dimension (missing or None means "not pinned"), otherwise the full range.
    index_by_dim = ds_index.tensor_index_by_dim
    weight_tensor_slices: tuple[slice | int, ...] = tuple(
        slice(None) if index is None else index
        for index in (index_by_dim.get(dim) for dim in output_weight_location_type.shape_spec)
    )
    weight = self._model_context.get_weight(output_weight_location_type, ds_index.layer_index)
    return weight[weight_tensor_slices]