def __post_init__()

in neuron_explainer/activations/derived_scalars/config.py [0:0]


    def __post_init__(self) -> None:
        """Validate cross-field consistency after dataclass construction.

        Ensures that any explicitly supplied name/config agrees with the
        corresponding context, that every device-bearing setting names the
        same torch device, and that attention-write / trace settings are
        well-formed.
        """
        # Map each device-bearing setting to the device it implies, so the
        # all-devices-match check below can compare them in one place.
        config_setting_to_device: dict[str, torch.device] = {}

        model_context = self.model_context
        if model_context is not None:
            # An explicit model_name, when given, must match the context's.
            assert self.model_name is None or model_context.model_name == self.model_name
            config_setting_to_device["model"] = model_context.device

        autoencoder_context = self.autoencoder_context
        if autoencoder_context is not None:
            # An explicit autoencoder_config, when given, must match the context's.
            assert (
                self.autoencoder_config is None
                or autoencoder_context.autoencoder_config == self.autoencoder_config
            )
            config_setting_to_device["autoencoder"] = autoencoder_context.device

        if self.multi_autoencoder_context is not None:
            # One entry per node type, each contributing its own device.
            by_node_type = self.multi_autoencoder_context.autoencoder_context_by_node_type
            config_setting_to_device.update(
                {node_type: ctx.device for node_type, ctx in by_node_type.items()}
            )

        if self.device_for_raw_activations is not None:
            config_setting_to_device["raw activations"] = self.device_for_raw_activations

        # When more than one setting specifies a device, they must all agree.
        if len(config_setting_to_device) > 1:
            assert (
                len(set(config_setting_to_device.values())) == 1
            ), f"All devices provided must match, but {config_setting_to_device=}"

        node_index = self.node_index_for_attention_write
        if node_index is not None:
            # The attention-write node must be an attention head with a
            # concrete layer index.
            assert node_index.node_type == NodeType.ATTENTION_HEAD
            assert node_index.layer_index is not None
            trace_config = self.trace_config
            if trace_config is not None and trace_config.node_type is not NodeType.VOCAB_TOKEN:
                assert trace_config.layer_index is not None

        if self.trace_config is not None:
            # backward pass from a backward pass activation is not supported
            assert self.trace_config.pass_type == PassType.FORWARD