def _get_output_weight()

in neuron_explainer/activations/derived_scalars/postprocessing.py [0:0]


    def _get_output_weight(self, ds_index: DerivedScalarIndex) -> torch.Tensor:
        """Return the output-weight slice associated with ``ds_index``.

        Resolution of the derived scalar type used for the weight lookup:
          * autoencoder latents use the ``dst`` of the autoencoder context
            registered for their node type;
          * attention-head scalars use ``ATTN_WEIGHTED_VALUE``;
          * everything else uses ``ds_index.dst`` unchanged.

        That type is mapped to a ``WeightLocationType``. ``MLP_POST_ACT`` maps to
        ``MLP_TO_RESIDUAL`` and ``ATTN_WEIGHTED_VALUE`` maps to ``ATTN_TO_RESIDUAL``;
        judging by the names, this is presumably the weight that writes into the
        residual stream (TODO confirm).

        The weight for ``ds_index.layer_index`` is fetched from the model context.
        It is then indexed along each dimension of the location's ``shape_spec``:
          * a dimension fixed in ``ds_index.tensor_index_by_dim`` is selected at
            that index;
          * any other dimension is kept whole.

        Raises:
            AssertionError: if an autoencoder latent is requested without a
                multi-autoencoder context, or no autoencoder context exists for
                its node type, or the resolved type has no output-weight mapping.
        """
        dst = ds_index.dst
        if dst.is_autoencoder_latent:
            assert self._multi_autoencoder_context is not None
            ae_context = self._multi_autoencoder_context.get_autoencoder_context(dst.node_type)
            assert ae_context is not None
            # Autoencoder latents reuse the output weight of the activation the
            # autoencoder was trained on.
            resolved_dst = ae_context.dst
        elif dst.node_type == NodeType.ATTENTION_HEAD:
            resolved_dst = DerivedScalarType.ATTN_WEIGHTED_VALUE
        else:
            resolved_dst = dst

        weight_location_by_dst: dict[DerivedScalarType, WeightLocationType] = {
            DerivedScalarType.MLP_POST_ACT: WeightLocationType.MLP_TO_RESIDUAL,
            DerivedScalarType.ATTN_WEIGHTED_VALUE: WeightLocationType.ATTN_TO_RESIDUAL,
        }
        assert resolved_dst in weight_location_by_dst, f"{resolved_dst} must be in output weight dict"
        location_type = weight_location_by_dst[resolved_dst]

        # Pick a single index on dimensions the ds_index pins down; keep the
        # full extent (slice(None)) on every other dimension of the weight.
        index_by_dim = ds_index.tensor_index_by_dim
        indexer: list[slice | int] = []
        for dim in location_type.shape_spec:
            fixed_index = index_by_dim.get(dim, None)
            if fixed_index is None:
                indexer.append(slice(None))
            else:
                indexer.append(fixed_index)

        weight = self._model_context.get_weight(location_type, ds_index.layer_index)
        return weight[tuple(indexer)]