in neuron_explainer/activations/derived_scalars/config.py [0:0]
def __post_init__(self) -> None:
    """Validate that the configured contexts and settings are mutually consistent.

    Checks performed:
    - If ``model_context`` is set, its ``model_name`` must equal ``model_name``,
      unless ``model_name`` is None.
    - If ``autoencoder_context`` is set, its ``autoencoder_config`` must equal
      ``autoencoder_config``, unless ``autoencoder_config`` is None.
    - Every device that is specified must be the same device. These are the model
      context, the single autoencoder context, each per-node-type context in
      ``multi_autoencoder_context`` and ``device_for_raw_activations``.
    - If ``node_index_for_attention_write`` is set, it must be an ATTENTION_HEAD node
      with a non-None ``layer_index``.
    - If ``trace_config`` is set, it must have a non-None ``layer_index`` unless its
      node type is VOCAB_TOKEN, and its ``pass_type`` must be FORWARD.

    Raises:
        AssertionError: if any of the checks above fails.
    """
    # Every explicitly-specified device, keyed by a label that appears in the error
    # message. Keys taken from multi_autoencoder_context are its node types
    # (presumably NodeType members rather than plain str despite the annotation;
    # NOTE(review): confirm).
    config_setting_to_device: dict[str, torch.device] = {}

    if self.model_context is not None:
        assert self.model_context.model_name == self.model_name or self.model_name is None, (
            "model_context does not match model_name: "
            f"{self.model_context.model_name=}, {self.model_name=}"
        )
        config_setting_to_device["model"] = self.model_context.device

    if self.autoencoder_context is not None:
        assert (self.autoencoder_context.autoencoder_config == self.autoencoder_config) or (
            self.autoencoder_config is None
        ), (
            "autoencoder_context does not match autoencoder_config: "
            f"{self.autoencoder_context.autoencoder_config=}, {self.autoencoder_config=}"
        )
        config_setting_to_device["autoencoder"] = self.autoencoder_context.device

    if self.multi_autoencoder_context is not None:
        contexts_by_node_type = self.multi_autoencoder_context.autoencoder_context_by_node_type
        for node_type, autoencoder_context in contexts_by_node_type.items():
            config_setting_to_device[node_type] = autoencoder_context.device

    if self.device_for_raw_activations is not None:
        config_setting_to_device["raw activations"] = self.device_for_raw_activations

    # With zero or one device specified there is nothing to compare.
    if len(config_setting_to_device) > 1:
        assert (
            len(set(config_setting_to_device.values())) == 1
        ), f"All devices provided must match, but {config_setting_to_device=}"

    attention_write = self.node_index_for_attention_write
    if attention_write is not None:
        assert attention_write.node_type == NodeType.ATTENTION_HEAD, (
            "node_index_for_attention_write must be an attention head node, "
            f"but {attention_write.node_type=}"
        )
        assert attention_write.layer_index is not None, (
            "node_index_for_attention_write must have a layer_index, "
            f"but {attention_write=}"
        )

    trace_config = self.trace_config
    if trace_config is not None:
        # Only VOCAB_TOKEN nodes are exempt from needing a layer_index.
        if trace_config.node_type is not NodeType.VOCAB_TOKEN:
            assert trace_config.layer_index is not None, (
                f"trace_config must have a layer_index for non-vocab-token nodes, "
                f"but {trace_config=}"
            )
        # backward pass from a backward pass activation is not supported
        assert trace_config.pass_type == PassType.FORWARD, (
            f"trace_config must refer to a forward-pass activation, but {trace_config.pass_type=}"
        )