"""
Library for looking up neuron datasets and their associated metadata by name.
"""

from neuron_explainer.activations.derived_scalars import DerivedScalarType
from neuron_explainer.pydantic import CamelCaseBaseModel, immutable


# Metadata describing one collated neuron-activation dataset: which model/autoencoder it comes
# from (short_name), which derived scalar its records contain, a human-readable name, and where
# the collated activations live. Instances are built by register_neuron_dataset_metadata in this
# module and, per the note at the bottom of the class, are also read from metadata files.
@immutable
class NeuronDatasetMetadata(CamelCaseBaseModel):
    short_name: str
    """Short name for the dataset, like "gpt2-small"."""
    # Registrations in this module store the raw string value here (e.g. "mlp_post_act"), not a
    # DerivedScalarType member; convert with DerivedScalarType(...) when an enum is needed.
    derived_scalar_type: str
    """The type of scalar that the neuron records in this dataset contain, e.g. DerivedScalarType.MLP_POST_ACT."""
    user_visible_name: str
    """Name for humans to read, like "GPT-2 small"."""
    # Registrations in this module use https:// URLs here.
    neuron_dataset_path: str
    """Path to the neuron dataset generated by collate_activations."""
    # Take care when adding new fields to this class. If they aren't optional, existing metadata
    # files will cause errors when they're read. You can make them required after ensuring that all
    # metadata files have been updated.


# Maps (short_name, derived scalar type) to the metadata registered under that key. Populated by
# register_neuron_dataset_metadata. The explicit annotation is needed because all writes happen
# inside a function, so type checkers cannot infer the key/value types from a bare `{}`.
NEURON_DATASET_METADATA_REGISTRY: dict[tuple[str, DerivedScalarType], NeuronDatasetMetadata] = {}


def register_neuron_dataset_metadata(
    short_name: str, derived_scalar_type: str, user_visible_name: str, neuron_dataset_path: str
) -> None:
    """Register a neuron dataset so it can be looked up by short name and derived scalar type.

    Args:
        short_name: Short name for the dataset, e.g. "gpt2-small".
        derived_scalar_type: String value of a DerivedScalarType, e.g. "mlp_post_act". It is
            converted with DerivedScalarType(...) to build the registry key and stored unchanged
            (as a string) on the metadata object.
        user_visible_name: Human-readable name for the dataset.
        neuron_dataset_path: Location of the collated activations for this dataset.

    Raises:
        ValueError: If a dataset is already registered under the same
            (short_name, derived scalar type) key.
    """
    key = (short_name, DerivedScalarType(derived_scalar_type))
    # Refuse to overwrite: a silently replaced entry (e.g. from a copy-pasted registration whose
    # short_name was not updated) would make the earlier dataset unreachable with no error.
    existing = NEURON_DATASET_METADATA_REGISTRY.get(key)
    if existing is not None:
        raise ValueError(
            f"A neuron dataset is already registered for {key}: {existing.user_visible_name!r} "
            f"at {existing.neuron_dataset_path!r}"
        )
    NEURON_DATASET_METADATA_REGISTRY[key] = NeuronDatasetMetadata(
        short_name=short_name,
        derived_scalar_type=derived_scalar_type,
        user_visible_name=user_visible_name,
        neuron_dataset_path=neuron_dataset_path,
    )


def _register_builtin_neuron_datasets() -> None:
    """Register the neuron datasets that are available out of the box, in table order."""
    # Columns: (short_name, derived_scalar_type, user_visible_name, neuron_dataset_path)
    builtin_datasets = (
        (
            "gpt2-xl",
            "mlp_post_act",
            "GPT-2 XL - MLP neurons",
            "https://openaipublic.blob.core.windows.net/neuron-explainer/data/collated-activations/",
        ),
        (
            "gpt2-small",
            "mlp_post_act",
            "GPT-2 small - MLP neurons",
            "https://openaipublic.blob.core.windows.net/neuron-explainer/gpt2_small_data/collated-activations/",
        ),
        (
            "gpt2-small",
            "attn_write_norm",
            "GPT-2 small - Attention write by token pair",
            "https://openaipublic.blob.core.windows.net/neuron-explainer/gpt2_small/attn_write_norm/collated-activations-by-token-pair",
        ),
        (
            "gpt2-small_ae-mlp-post-act-v1",
            "mlp_autoencoder_latent",
            "GPT-2 small - MLP autoencoder (post-act) v1",
            "https://openaipublic.blob.core.windows.net/neuron-explainer/gpt2_small/autoencoder_latent/mlp_post_act_v1/collated-activations",
        ),
        (
            "gpt2-small_ae-resid-delta-mlp-v1",
            "mlp_autoencoder_latent",
            "GPT-2 small - MLP autoencoder (write) v1",
            "https://openaipublic.blob.core.windows.net/neuron-explainer/gpt2_small/autoencoder_latent/resid_delta_mlp_v1/collated-activations",
        ),
        (
            "gpt2-small_ae-mlp-post-act-v4",
            "mlp_autoencoder_latent",
            "GPT-2 small - MLP autoencoder (post-act) v4",
            "https://openaipublic.blob.core.windows.net/neuron-explainer/gpt2_small/autoencoder_latent/mlp_post_act_v4/collated-activations",
        ),
        (
            "gpt2-small_ae-resid-delta-mlp-v4",
            "mlp_autoencoder_latent",
            "GPT-2 small - MLP autoencoder (write) v4",
            "https://openaipublic.blob.core.windows.net/neuron-explainer/gpt2_small/autoencoder_latent/resid_delta_mlp_v4/collated-activations",
        ),
        (
            "gpt2-small_ae-resid-delta-attn-v4",
            "attention_autoencoder_latent",
            "GPT-2 small - Attention autoencoder (write) v4",
            "https://openaipublic.blob.core.windows.net/neuron-explainer/gpt2_small/autoencoder_latent/resid_delta_attn_v4/collated-activations",
        ),
        (
            "gpt2-small_ae-resid-delta-attn-v4",
            "flattened_attn_write_to_latent_summed_over_heads",
            "GPT-2 small - Attention autoencoder (write) v4 by token pair",
            "https://openaipublic.blob.core.windows.net/neuron-explainer/gpt2_small/autoencoder_latent/resid_delta_attn_v4/collated-activations-by-token-pair",
        ),
    )
    # Registration order is preserved so the registry's iteration order is unchanged.
    for name, dst_value, display_name, dataset_path in builtin_datasets:
        register_neuron_dataset_metadata(
            short_name=name,
            derived_scalar_type=dst_value,
            user_visible_name=display_name,
            neuron_dataset_path=dataset_path,
        )


_register_builtin_neuron_datasets()


def get_all_neuron_dataset_metadata() -> list[NeuronDatasetMetadata]:
    """Return the metadata of every registered neuron dataset as a new list.

    Entries appear in the order their keys were first inserted into
    NEURON_DATASET_METADATA_REGISTRY.
    """
    registered = NEURON_DATASET_METADATA_REGISTRY.values()
    return [*registered]


def get_neuron_dataset_metadata_by_short_name_and_dst(
    short_name: str, dst: DerivedScalarType
) -> NeuronDatasetMetadata:
    """Look up a registered neuron dataset by its short name and derived scalar type.

    Args:
        short_name: Short name the dataset was registered under, e.g. "gpt2-small".
        dst: Derived scalar type the dataset was registered under.

    Returns:
        The NeuronDatasetMetadata registered for (short_name, dst).

    Raises:
        LookupError: If nothing is registered for (short_name, dst). The message lists every
            registered (short_name, derived scalar type) pair. If short_name ends in
            "_undefined", it also hints that the autoencoder name may be missing from the URL.
    """
    name_and_type = (short_name, dst)
    metadata = NEURON_DATASET_METADATA_REGISTRY.get(name_and_type)
    if metadata is not None:
        return metadata

    # Registry keys are always built with DerivedScalarType(...), so .value is available. Using
    # it keeps the listing stable: how an Enum member renders in an f-string depends on the
    # enum's base classes and changed in Python 3.11 (value vs. "DerivedScalarType.NAME").
    available = ", ".join(
        f'("{registered_name}", "{registered_dst.value}")'
        for registered_name, registered_dst in NEURON_DATASET_METADATA_REGISTRY
    )
    error_message = (
        f"Could not find collated activation dataset for {name_and_type}. "
        f"Available datasets are: {available}"
    )

    if short_name.endswith("_undefined"):
        # This is likely due to the URL not providing the correct autoencoder name.
        if dst in (
            DerivedScalarType.FLATTENED_ATTN_WRITE_TO_LATENT_SUMMED_OVER_HEADS,
            DerivedScalarType.ATTENTION_AUTOENCODER_LATENT,
        ):
            autoencoder_type = "attention "
        elif dst == DerivedScalarType.MLP_AUTOENCODER_LATENT:
            autoencoder_type = "mlp "
        else:
            autoencoder_type = ""
        error_message += (
            f"\nMay need to specify the {autoencoder_type}autoencoder name in the URL."
        )

    # LookupError subclasses Exception, so callers that caught the previous bare Exception
    # still catch this.
    raise LookupError(error_message)
