import os

from neuron_explainer.activation_server.load_neurons import convert_dataset_path_to_short_name

# Registry of public Azure datasets: each neuron dataset path (a "collated-activations" folder)
# maps to the explanation dataset path published alongside it.
AZURE_EXPLANATION_DATASET_REGISTRY = dict(
    [
        (
            "https://openaipublic.blob.core.windows.net/neuron-explainer/data/collated-activations/",
            "https://openaipublic.blob.core.windows.net/neuron-explainer/data/explanations/",
        ),
        (
            "https://openaipublic.blob.core.windows.net/neuron-explainer/gpt2_small_data/collated-activations/",
            "https://openaipublic.blob.core.windows.net/neuron-explainer/gpt2_small_data/explanations/",
        ),
    ]
)


def get_local_cached_explanation_directory(dataset_path: str) -> str:
    """
    Return the local cache directory for explanations of the given neuron dataset.

    The result is `<project root>/cached_explanations/<short name>`, where the project root is the
    directory three levels above this source file and the short name is produced by
    `convert_dataset_path_to_short_name(dataset_path)`.
    """
    project_root = os.path.abspath(__file__)
    # Climb three directory levels from this file to reach the project root.
    for _ in range(3):
        project_root = os.path.dirname(project_root)
    short_name = convert_dataset_path_to_short_name(dataset_path)
    return f"{project_root}/cached_explanations/{short_name}"


async def get_all_explanation_datasets(neuron_dataset: str) -> list[str]:
    """
    Get all explanation datasets for a given neuron dataset. Search the public azure bucket and also
    the local filesystem cache. Returns a list of paths to the explanation datasets.
    Path can be an azure path (beginning with `https://`) or a local path.

    Ordering: the registered Azure explanation dataset (if any) comes first, followed by every
    subdirectory of the local cache directory in sorted name order. Sorting makes the result
    deterministic, since `os.listdir` returns entries in arbitrary, filesystem-dependent order.
    If the local cache directory does not exist, only the Azure entry (if any) is returned.

    Note: the filesystem calls below are synchronous and run on the calling event loop.
    """
    datasets: list[str] = []
    azure_dataset = AZURE_EXPLANATION_DATASET_REGISTRY.get(neuron_dataset)
    if azure_dataset is not None:
        datasets.append(azure_dataset)
    local_cache_dir = get_local_cached_explanation_directory(neuron_dataset)
    # Iterate through folders to get a list of dirs.
    # There will be different local cache directories if the user generates scored explanations for
    # the same neuron dataset using different neuron/attention explainer registry entries (i.e. so
    # that AttentionExplainAndScoreMethodId or NeuronExplainAndScoreMethodId differ).
    # os.path.isdir returns False for missing paths, so no separate existence check is needed.
    if os.path.isdir(local_cache_dir):
        for entry in sorted(os.listdir(local_cache_dir)):
            candidate_path = os.path.join(local_cache_dir, entry)
            # Skip stray files; only subdirectories are explanation datasets.
            if os.path.isdir(candidate_path):
                datasets.append(candidate_path)
    return datasets
