def define_read_routes()

in neuron_explainer/activation_server/read_routes.py [0:0]


def define_read_routes(app: FastAPI) -> None:
    """Register the "read"-tagged POST endpoints on ``app``.

    Routes registered:
      - ``/existing_explanations``: scored explanations for one neuron, gathered from an explicit
        list of explanation datasets or from every explanation dataset of a neuron dataset.
      - ``/neuron_record``: most-positive and random-sample activation records for a node.
      - ``/attention_head_record``: the same records, converted to attention-activation form.
      - ``/neuron_datasets_metadata``: metadata for all neuron datasets.
    """

    @app.post(
        "/existing_explanations", response_model=list[AttributedScoredExplanation], tags=["read"]
    )
    async def existing_explanations(
        request: ExistingExplanationsRequest,
    ) -> list[AttributedScoredExplanation]:
        # Exactly one of `explanation_datasets` / `neuron_dataset` must be given; otherwise a 400 is
        # returned. Explanations from all selected datasets are concatenated in dataset order.

        def convert_scored_explanation(
            scored_explanation: ScoredExplanation, explanation_dataset: str
        ) -> AttributedScoredExplanation:
            """Wrap one scored explanation with its preferred score and a short dataset name."""
            # The display name is the last path component. Trailing slashes are stripped first so
            # that a directory-style path like "a/b/" yields "b" rather than "".
            return AttributedScoredExplanation(
                explanation=scored_explanation.explanation,
                score=scored_explanation.get_preferred_score(),
                dataset_name=explanation_dataset.rstrip("/").split("/")[-1],
            )

        async def load_and_convert_explanations(
            explanation_dataset: str,
        ) -> list[AttributedScoredExplanation]:
            """Load this neuron's explanations from one dataset; [] if the loader returns None."""
            neuron_simulation_results = await load_neuron_explanations_async(
                explanation_dataset, request.layer_index, request.activation_index
            )
            if neuron_simulation_results is None:
                return []
            # Entries with no explanation text are skipped.
            return [
                convert_scored_explanation(scored_explanation, explanation_dataset)
                for scored_explanation in neuron_simulation_results.scored_explanations
                if scored_explanation.explanation is not None
            ]

        has_explanation_datasets = len(request.explanation_datasets) > 0
        has_neuron_dataset = request.neuron_dataset is not None
        # Both or neither specified is invalid.
        if has_explanation_datasets == has_neuron_dataset:
            raise HTTPException(
                status_code=400,
                detail="Exactly one of explanation_datasets and neuron_dataset must be specified.",
            )

        if has_explanation_datasets:
            explanation_datasets = request.explanation_datasets
        else:
            # Guaranteed by the check above; the assert only narrows the Optional type for mypy.
            assert request.neuron_dataset is not None
            neuron_dataset = resolve_neuron_dataset(request.neuron_dataset, request.dst)
            explanation_datasets = await get_all_explanation_datasets(neuron_dataset)

        # Load all datasets concurrently; gather returns results in the order of its arguments.
        scored_explanation_lists = await asyncio.gather(
            *(load_and_convert_explanations(dataset) for dataset in explanation_datasets)
        )
        # Flatten the per-dataset lists into one list.
        return [
            explanation
            for explanation_list in scored_explanation_lists
            for explanation in explanation_list
        ]

    @app.post("/neuron_record", response_model=NeuronRecordResponse, tags=["read"])
    async def neuron_record(request: NodeIdAndDatasets) -> NeuronRecordResponse:
        # Return the record's most-positive and random-sample activation records (as token and
        # activation lists), its max activation, and the dataset path reported by the loader.
        dataset_path, record = await load_neuron_from_datasets(request)
        top_activations, random_sample = convert_activation_records_to_token_and_activation_lists(
            [
                record.most_positive_activation_records,
                record.random_sample,
            ]
        )
        return NeuronRecordResponse(
            dataset=dataset_path,
            max_activation=record.max_activation,
            top_activations=top_activations,
            random_sample=random_sample,
        )

    @app.post("/attention_head_record", response_model=AttentionHeadRecordResponse, tags=["read"])
    async def attention_head_record(request: NodeIdAndDatasets) -> AttentionHeadRecordResponse:
        # Same loader as /neuron_record, but the records are converted to token sequences with
        # attention activations, and max_activation is reported as max_attention_activation.
        dataset_path, record = await load_neuron_from_datasets(request)
        (
            most_positive_token_sequences,
            random_token_sequences,
        ) = convert_activation_records_to_token_and_attention_activations_lists(
            [
                record.most_positive_activation_records,
                record.random_sample,
            ]
        )

        return AttentionHeadRecordResponse(
            dataset=dataset_path,
            max_attention_activation=record.max_activation,
            most_positive_token_sequences=most_positive_token_sequences,
            random_token_sequences=random_token_sequences,
        )

    @app.post(
        "/neuron_datasets_metadata", response_model=list[NeuronDatasetMetadata], tags=["read"]
    )
    def neuron_datasets_metadata() -> list[NeuronDatasetMetadata]:
        # Takes no request body; simply returns the metadata list from the helper.
        return get_all_neuron_dataset_metadata()