# Few-shot examples for generating and simulating neuron explanations.

from __future__ import annotations

from dataclasses import dataclass
from enum import Enum

from neuron_explainer.activations.activations import ActivationRecord
from neuron_explainer.fast_dataclasses import FastDataclass


@dataclass
class Example(FastDataclass):
    """
    One few-shot example: a set of activation records for a single neuron together with the
    explanation those records illustrate.
    """

    # Token sequences and their per-token activations that appear in the prompt.
    activation_records: list[ActivationRecord]
    # The explanation paired with these activation records.
    explanation: str
    # For each activation record, the index of the first token whose activation value is rendered
    # in the prompt as an actual number rather than "unknown". (The single-token scoring examples
    # in this file pass an empty list.)
    #
    # Examples all begin with activations rendered as "unknown" and then switch to revealing
    # specific normalized activation values. This leads the model to predict that an activation
    # sequence will eventually move from "unknown" to specific values. We exploit that to get a
    # predicted activation for every token in a single round of inference: the sequence being
    # predicted is always "unknown" in the prompt, so the model keeps expecting that the next token
    # might carry a real activation.
    first_revealed_activation_indices: list[int]
    # When the prompt is used as an example for one-token-at-a-time scoring, the index of the token
    # to score. Defaults to None.
    token_index_to_score: int | None = None


class FewShotExampleSet(Enum):
    """
    Selects which collection of few-shot examples to put into prompts, both when sampling
    explanations and when simulating activations from an explanation.

    Each member's value is the string name accepted by `from_string`.
    """

    ORIGINAL = "original"
    COLANGV2 = "colangv2"
    TEST = "test"

    @classmethod
    def from_string(cls, string: str) -> FewShotExampleSet:
        """
        Return the member whose value equals `string` (e.g. "colangv2").

        Raises:
            ValueError: if no member has that value.
        """
        found = next(
            (candidate for candidate in FewShotExampleSet if candidate.value == string),
            None,
        )
        if found is None:
            raise ValueError(f"Unrecognized example set: {string}")
        return found

    def get_examples(self) -> list[Example]:
        """
        Return the regular examples for this set, for use in a few-shot prompt.

        Raises:
            ValueError: if this member has no associated example list.
        """
        if self is FewShotExampleSet.TEST:
            return TEST_EXAMPLES
        if self is FewShotExampleSet.COLANGV2:
            return COLANGV2_EXAMPLES
        if self is FewShotExampleSet.ORIGINAL:
            return ORIGINAL_EXAMPLES
        raise ValueError(f"Unhandled example set: {self}")

    def get_single_token_prediction_example(self) -> Example:
        """
        Return an example for a subprompt that predicts one token's normalized activation, as used
        by the "one token at a time" scoring approach.

        Raises:
            ValueError: for sets without such an example (currently ORIGINAL).
        """
        if self is FewShotExampleSet.TEST:
            return TEST_SINGLE_TOKEN_EXAMPLE
        if self is FewShotExampleSet.COLANGV2:
            return COLANGV2_SINGLE_TOKEN_EXAMPLE
        raise ValueError(f"Unhandled example set: {self}")


# Tiny example set used by FewShotExampleSet.TEST; each record activates on one letter.
TEST_EXAMPLES = [
    Example(
        explanation="vowels",
        first_revealed_activation_indices=[0, 1],
        activation_records=[
            ActivationRecord(tokens=["a", "b", "c"], activations=[1.0, 0.0, 0.0]),
            ActivationRecord(tokens=["d", "e", "f"], activations=[0.0, 1.0, 0.0]),
        ],
    ),
]

# Single-token scoring example for FewShotExampleSet.TEST: the token to score is index 2 ("i").
TEST_SINGLE_TOKEN_EXAMPLE = Example(
    explanation="test explanation",
    token_index_to_score=2,
    first_revealed_activation_indices=[],
    activation_records=[
        ActivationRecord(tokens=["g", "h", "i"], activations=[0.0, 0.0, 1.0]),
    ],
)


# Few-shot examples returned by `FewShotExampleSet.ORIGINAL.get_examples()`. In every record the
# `tokens` and `activations` lists are parallel: one activation value per token.
ORIGINAL_EXAMPLES = [
    Example(
        activation_records=[
            # Peaks on " disappearing" (2.83) and "earing" (4.24).
            ActivationRecord(
                tokens=[
                    "t",
                    "urt",
                    "ur",
                    "ro",
                    " is",
                    " fab",
                    "ulously",
                    " funny",
                    " and",
                    " over",
                    " the",
                    " top",
                    " as",
                    " a",
                    " '",
                    "very",
                    " sneaky",
                    "'",
                    " but",
                    "ler",
                    " who",
                    " excel",
                    "s",
                    " in",
                    " the",
                    " art",
                    " of",
                    " impossible",
                    " disappearing",
                    "/",
                    "re",
                    "app",
                    "earing",
                    " acts",
                ],
                activations=[
                    -0.71,
                    -1.85,
                    -2.39,
                    -2.58,
                    -1.34,
                    -1.92,
                    -1.69,
                    -0.84,
                    -1.25,
                    -1.75,
                    -1.42,
                    -1.47,
                    -1.51,
                    -0.8,
                    -1.89,
                    -1.56,
                    -1.63,
                    0.44,
                    -1.87,
                    -2.55,
                    -2.09,
                    -1.76,
                    -1.33,
                    -0.88,
                    -1.63,
                    -2.39,
                    -2.63,
                    -0.99,
                    2.83,
                    -1.11,
                    -1.19,
                    -1.33,
                    4.24,
                    -1.51,
                ],
            ),
            # Peaks on "aping" (4.12), the second token of "escaping".
            ActivationRecord(
                tokens=[
                    "esc",
                    "aping",
                    " the",
                    " studio",
                    " ,",
                    " pic",
                    "col",
                    "i",
                    " is",
                    " warm",
                    "ly",
                    " affecting",
                    " and",
                    " so",
                    " is",
                    " this",
                    " ad",
                    "roit",
                    "ly",
                    " minimalist",
                    " movie",
                    " .",
                ],
                activations=[
                    -0.69,
                    4.12,
                    1.83,
                    -2.28,
                    -0.28,
                    -0.79,
                    -2.2,
                    -2.03,
                    -1.77,
                    -1.71,
                    -2.44,
                    1.6,
                    -1,
                    -0.38,
                    -1.93,
                    -2.09,
                    -1.63,
                    -1.94,
                    -1.82,
                    -1.64,
                    -1.32,
                    -1.92,
                ],
            ),
        ],
        # One index per activation record (see Example.first_revealed_activation_indices).
        first_revealed_activation_indices=[10, 3],
        explanation="present tense verbs ending in 'ing'",
    ),
    Example(
        activation_records=[
            # Peaks on " arrest" (4.46).
            ActivationRecord(
                tokens=[
                    "as",
                    " sac",
                    "char",
                    "ine",
                    " movies",
                    " go",
                    " ,",
                    " this",
                    " is",
                    " likely",
                    " to",
                    " cause",
                    " massive",
                    " cardiac",
                    " arrest",
                    " if",
                    " taken",
                    " in",
                    " large",
                    " doses",
                    " .",
                ],
                activations=[
                    -0.14,
                    -1.37,
                    -0.68,
                    -2.27,
                    -1.46,
                    -1.11,
                    -0.9,
                    -2.48,
                    -2.07,
                    -3.49,
                    -2.16,
                    -1.79,
                    -0.23,
                    -0.04,
                    4.46,
                    -1.02,
                    -2.26,
                    -2.95,
                    -1.49,
                    -1.46,
                    -0.6,
                ],
            ),
            # Peaks on "igo" (4.37), the end of "vertigo".
            ActivationRecord(
                tokens=[
                    "shot",
                    " perhaps",
                    " '",
                    "art",
                    "istically",
                    "'",
                    " with",
                    " handheld",
                    " cameras",
                    " and",
                    " apparently",
                    " no",
                    " movie",
                    " lights",
                    " by",
                    " jo",
                    "aquin",
                    " b",
                    "aca",
                    "-",
                    "as",
                    "ay",
                    " ,",
                    " the",
                    " low",
                    "-",
                    "budget",
                    " production",
                    " swings",
                    " annoy",
                    "ingly",
                    " between",
                    " vert",
                    "igo",
                    " and",
                    " opacity",
                    " .",
                ],
                activations=[
                    -0.09,
                    -3.53,
                    -0.72,
                    -2.36,
                    -1.05,
                    -1.12,
                    -2.49,
                    -2.14,
                    -1.98,
                    -1.59,
                    -2.62,
                    -2,
                    -2.73,
                    -2.87,
                    -3.23,
                    -1.11,
                    -2.23,
                    -0.97,
                    -2.28,
                    -2.37,
                    -1.5,
                    -2.81,
                    -1.73,
                    -3.14,
                    -2.61,
                    -1.7,
                    -3.08,
                    -4,
                    -0.71,
                    -2.48,
                    -1.39,
                    -1.96,
                    -1.09,
                    4.37,
                    -0.74,
                    -0.5,
                    -0.62,
                ],
            ),
        ],
        first_revealed_activation_indices=[5, 20],
        explanation="words related to physical medical conditions",
    ),
    Example(
        activation_records=[
            # The sense of togetherness in our town is strong.
            # Peaks on "ness" (2), the end of "togetherness".
            ActivationRecord(
                tokens=[
                    "the",
                    " sense",
                    " of",
                    " together",
                    "ness",
                    " in",
                    " our",
                    " town",
                    " is",
                    " strong",
                    " .",
                ],
                activations=[
                    0,
                    0,
                    0,
                    1,
                    2,
                    0,
                    0.23,
                    0.5,
                    0,
                    0,
                    0,
                ],
            ),
            # Highest across "we're all in this together" (peak 2.84 on " this").
            ActivationRecord(
                tokens=[
                    "a",
                    " buoy",
                    "ant",
                    " romantic",
                    " comedy",
                    " about",
                    " friendship",
                    " ,",
                    " love",
                    " ,",
                    " and",
                    " the",
                    " truth",
                    " that",
                    " we",
                    "'re",
                    " all",
                    " in",
                    " this",
                    " together",
                    " .",
                ],
                activations=[
                    -0.15,
                    -2.33,
                    -1.4,
                    -2.17,
                    -2.53,
                    -0.85,
                    0.23,
                    -1.89,
                    0.09,
                    -0.47,
                    -0.5,
                    -0.58,
                    -0.87,
                    0.22,
                    0.58,
                    1.34,
                    0.98,
                    2.21,
                    2.84,
                    1.7,
                    -0.89,
                ],
            ),
        ],
        first_revealed_activation_indices=[0, 10],
        explanation="phrases related to community",
    ),
]


# Few-shot examples returned by `FewShotExampleSet.COLANGV2.get_examples()`. In every record the
# `tokens` and `activations` lists are parallel: one activation value per token.
COLANGV2_EXAMPLES = [
    Example(
        activation_records=[
            # Peaks on " seminal" (3.39); nearly every other token is ~0.
            ActivationRecord(
                tokens=[
                    "The",
                    " editors",
                    " of",
                    " Bi",
                    "opol",
                    "ym",
                    "ers",
                    " are",
                    " delighted",
                    " to",
                    " present",
                    " the",
                    " ",
                    "201",
                    "8",
                    " Murray",
                    " Goodman",
                    " Memorial",
                    " Prize",
                    " to",
                    " Professor",
                    " David",
                    " N",
                    ".",
                    " Ber",
                    "atan",
                    " in",
                    " recognition",
                    " of",
                    " his",
                    " seminal",
                    " contributions",
                    " to",
                    " bi",
                    "oph",
                    "ysics",
                    " and",
                    " their",
                    " impact",
                    " on",
                    " our",
                    " understanding",
                    " of",
                    " charge",
                    " transport",
                    " in",
                    " biom",
                    "olecules",
                    ".\n\n",
                    "In",
                    "aug",
                    "ur",
                    "ated",
                    " in",
                    " ",
                    "200",
                    "7",
                    " in",
                    " honor",
                    " of",
                    " the",
                    " Bi",
                    "opol",
                    "ym",
                    "ers",
                    " Found",
                    "ing",
                    " Editor",
                    ",",
                    " the",
                    " prize",
                    " is",
                    " awarded",
                    " for",
                    " outstanding",
                    " accomplishments",
                ],
                activations=[
                    0,
                    0.01,
                    0.01,
                    0,
                    0,
                    0,
                    -0.01,
                    0,
                    -0.01,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0.04,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    3.39,
                    0.12,
                    0,
                    -0.01,
                    0,
                    0,
                    0,
                    0,
                    -0,
                    0,
                    -0,
                    0,
                    0,
                    -0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    -0,
                    0,
                    0,
                    -0.01,
                    0,
                    0.41,
                    0,
                    0,
                    0,
                    -0.01,
                    0,
                    0,
                    0,
                    0,
                    0,
                ],
            ),
            # We sometimes exceed the max context size when this is included :(
            # We can uncomment this if we start using an 8k context size.
            # ActivationRecord(
            #     tokens=[
            #         " We",
            #         " are",
            #         " proud",
            #         " of",
            #         " our",
            #         " national",
            #         " achievements",
            #         " in",
            #         " mastering",
            #         " all",
            #         " aspects",
            #         " of",
            #         " the",
            #         " fuel",
            #         " cycle",
            #         ".",
            #         " The",
            #         " current",
            #         " international",
            #         " interest",
            #         " in",
            #         " closing",
            #         " the",
            #         " fuel",
            #         " cycle",
            #         " is",
            #         " a",
            #         " vind",
            #         "ication",
            #         " of",
            #         " Dr",
            #         ".",
            #         " B",
            #         "hab",
            #         "ha",
            #         "’s",
            #         " pioneering",
            #         " vision",
            #         " and",
            #         " genius",
            #     ],
            #     activations=[
            #         -0,
            #         -0,
            #         0,
            #         -0,
            #         -0,
            #         0,
            #         0,
            #         0,
            #         -0,
            #         0,
            #         0,
            #         -0,
            #         0,
            #         -0.01,
            #         0,
            #         0,
            #         -0,
            #         -0,
            #         0,
            #         0,
            #         0,
            #         -0,
            #         -0,
            #         -0.01,
            #         0,
            #         0,
            #         -0,
            #         0,
            #         0,
            #         0,
            #         0,
            #         0,
            #         -0,
            #         0,
            #         0,
            #         0,
            #         2.15,
            #         0,
            #         0,
            #         0.03,
            #     ],
            # ),
        ],
        first_revealed_activation_indices=[7],  # Was [7, 19] with the commented-out record above.
        explanation="language related to something being groundbreaking",
    ),
    Example(
        activation_records=[
            # Peaks on "Variant" (4.2) inside "VariantMatrixWidget"; also 1.24 on "ariance"
            # ("priceVariance"). The later "Variant" tokens stay at 0.
            ActivationRecord(
                tokens=[
                    '{"',
                    "widget",
                    "Class",
                    '":"',
                    "Variant",
                    "Matrix",
                    "Widget",
                    '","',
                    "back",
                    "order",
                    "Message",
                    '":"',
                    "Back",
                    "ordered",
                    '","',
                    "back",
                    "order",
                    "Message",
                    "Single",
                    "Variant",
                    '":"',
                    "This",
                    " item",
                    " is",
                    " back",
                    "ordered",
                    '.","',
                    "ordered",
                    "Selection",
                    '":',
                    "true",
                    ',"',
                    "product",
                    "Variant",
                    "Id",
                    '":',
                    "0",
                    ',"',
                    "variant",
                    "Id",
                    "Field",
                    '":"',
                    "product",
                    "196",
                    "39",
                    "_V",
                    "ariant",
                    "Id",
                    '","',
                    "back",
                    "order",
                    "To",
                    "Message",
                    "Single",
                    "Variant",
                    '":"',
                    "This",
                    " item",
                    " is",
                    " back",
                    "ordered",
                    " and",
                    " is",
                    " expected",
                    " by",
                    " {",
                    "0",
                    "}.",
                    '","',
                    "low",
                    "Price",
                    '":',
                    "999",
                    "9",
                    ".",
                    "0",
                    ',"',
                    "attribute",
                    "Indexes",
                    '":[',
                    '],"',
                    "productId",
                    '":',
                    "196",
                    "39",
                    ',"',
                    "price",
                    "V",
                    "ariance",
                    '":',
                    "true",
                    ',"',
                ],
                activations=[
                    -0.03,
                    0,
                    0,
                    0,
                    4.2,
                    0,
                    0,
                    0,
                    0,
                    -0,
                    0,
                    -0,
                    0,
                    0,
                    0,
                    0,
                    -0.01,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    -0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    -0.03,
                    0,
                    0,
                    0,
                    0,
                    -0.02,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    -0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    -0,
                    -0,
                    0,
                    0,
                    0,
                    0.01,
                    -0.01,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    -0.02,
                    0,
                    0,
                    0,
                    0,
                    0,
                    1.24,
                    0,
                    0,
                    0,
                ],
            ),
            # Fires on both occurrences of " variant" (6.52 and 1.62).
            ActivationRecord(
                tokens=[
                    "A",
                    " regular",
                    " look",
                    " at",
                    " the",
                    " ups",
                    " and",
                    " downs",
                    " of",
                    " variant",
                    " covers",
                    " in",
                    " the",
                    " comics",
                    " industry",
                    "…\n\n",
                    "Here",
                    " are",
                    " the",
                    " Lego",
                    " variant",
                    " sketch",
                    " covers",
                    " by",
                    " Leon",
                    "el",
                    " Cast",
                    "ell",
                    "ani",
                    " for",
                    " a",
                    " variety",
                    " of",
                    " Marvel",
                    " titles",
                    ",",
                ],
                activations=[
                    0,
                    -0.01,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    6.52,
                    0,
                    0,
                    0,
                    -0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    1.62,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                    -0,
                    0,
                    0,
                    0,
                    -0,
                    0,
                ],
            ),
        ],
        first_revealed_activation_indices=[2, 8],
        # NOTE(review): the quotes around vari are mismatched (”vari” rather than “vari”). Left
        # unchanged here because editing this string alters the few-shot example text.
        explanation="the word “variant” and other words with the same ”vari” root",
    ),
]

# Subprompt example for one-token-at-a-time scoring with the COLANGV2 set. The token to score is
# index 18 ("ATE", completing " DONATE").
COLANGV2_SINGLE_TOKEN_EXAMPLE = Example(
    explanation="instances of the token 'ate' as part of another word",
    first_revealed_activation_indices=[],
    token_index_to_score=18,
    activation_records=[
        ActivationRecord(
            tokens=[
                "B", "10", " ", "111", " MON", "DAY", ",",
                " F", "EB", "RU", "ARY", " ", "11", ",",
                " ", "201", "9", " DON", "ATE",
                # Placeholder paired with the fake activation at the end of `activations`.
                "fake higher scoring token",
            ],
            # Tokens 0-17 all have zero activation.
            activations=[0] * 18
            + [
                0.37,  # "ATE" (index 18), the token being scored.
                # This fake activation makes the previous token's activation normalize to 8, which
                # might help address overconfidence in "10" activations for the one-token-at-a-time
                # scoring prompt. This value and the associated token don't actually appear
                # anywhere in the prompt.
                0.45,
            ],
        ),
    ],
)


@dataclass
class AttentionSimulationExample(FastDataclass):
    """
    A single labeled token pair. Instances are stored in
    `AttentionHeadFewShotExample.simulation_examples`.
    """

    # Presumably an index into the parent example's `token_pair_examples` list — TODO confirm.
    token_pair_example_index: int
    # Two integer coordinates identifying a token pair; presumably positions into that token pair
    # example's `tokens` (ordering convention not shown here) — confirm against the prompt code.
    token_pair_coordinates: tuple[int, int]
    # Integer label for this pair; its meaning is not established in this file — confirm.
    label: int


@dataclass
class AttentionTokenPairExample(FastDataclass):
    """A tokenized text excerpt together with a list of token-pair coordinates within it."""

    # The excerpt's tokens, in order.
    tokens: list[str]
    # Integer coordinate pairs; presumably positions into `tokens` for pairs the head links —
    # TODO confirm.
    token_pair_coordinates: list[tuple[int, int]]


@dataclass
class AttentionHeadFewShotExample(FastDataclass):
    """A few-shot example for explaining an attention head: token-pair excerpts plus explanation."""

    # Excerpts with their token pairs.
    token_pair_examples: list[AttentionTokenPairExample]
    # The explanation these excerpts illustrate.
    explanation: str
    # Optional labeled pairs for simulation; defaults to None.
    simulation_examples: list[AttentionSimulationExample] | None = None


ATTENTION_HEAD_FEW_SHOT_EXAMPLES: list[AttentionHeadFewShotExample] = [
    # gpt2-xl, layer 1, head 1
    AttentionHeadFewShotExample(
        token_pair_examples=[
            AttentionTokenPairExample(
                tokens=[
                    " dreams",
                    " of",
                    " a",
                    " future",
                    " like",
                    " her",
                    " biggest",
                    " idol",
                    ",",
                    " who",
                    " was",
                    " also",
                    " born",
                    " visually",
                    " impaired",
                    ".",
                    "\n",
                    "\n",
                    '"',
                    "My",
                    " ultimate",
                    " dream",
                    " would",
                    " be",
                    " to",
                    " sing",
                    " at",
                    " Carol",
                    "s",
                    " [",
                    "by",
                    " Candle",
                    "light",
                    "]",
                    " and",
                    " to",
                    " become",
                    " a",
                    " famous",
                    " musician",
                    " like",
                    " Andrea",
                    " Bo",
                    "cell",
                    "i",
                    " ...",
                    " and",
                    " to",
                    " show",
                    " people",
                    " that",
                    " if",
                    " you",
                    " have",
                    " a",
                    " disability",
                    " it",
                    " doesn",
                    "'t",
                    " matter",
                    ',"',
                    " she",
                    " said",
                    ".",
                ],
                # 45 = "attended from", 33 = "attended to"
                token_pair_coordinates=[(45, 33)],
            ),
            AttentionTokenPairExample(
                tokens=[
                    "omes",
                    " Ever",
                    " Sequ",
                    "enced",
                    "]",
                    "\n",
                    "\n",
                    "One",
                    " mystery",
                    " of",
                    " cat",
                    " development",
                    " is",
                    " how",
                    " cats",
                    " have",
                    " come",
                    " to",
                    " have",
                    " such",
                    " varied",
                    " coats",
                    ",",
                    " from",
                    " solid",
                    " colours",
                    " to",
                    ' "',
                    "mac",
                    "ke",
                    "rel",
                    '"',
                    " tab",
                    "by",
                    " patterns",
                    " of",
                    " thin",
                    " vertical",
                    " stripes",
                    ".",
                    " The",
                    " researchers",
                    " were",
                    " particularly",
                    " interested",
                    " in",
                    " what",
                    " turns",
                    " the",
                    " mac",
                    "ke",
                    "rel",
                    " pattern",
                    " into",
                    " a",
                    ' "',
                    "bl",
                    "ot",
                    "ched",
                    '"',
                    " tab",
                    "by",
                    " pattern",
                    ",",
                ],
                token_pair_coordinates=[(5, 4)],
            ),
            AttentionTokenPairExample(
                tokens=[
                    ",",
                    " 6",
                    ",",
                    " 8",
                    ",",
                    " 4",
                    "]",
                    "':",
                    "rb",
                    ".",
                    "sort",
                    ".",
                    "slice",
                    "(",
                    "1",
                    ",",
                    "2",
                    ");",
                    " #",
                    " More",
                    " advanced",
                    ",",
                    " this",
                    " is",
                    " Ruby",
                    "'s",
                    " map",
                    " and",
                    " each",
                    "_",
                    "with",
                    "_",
                    "index",
                    " #",
                    " This",
                    " shows",
                    " the",
                    " :",
                    "rb",
                    " post",
                    "fix",
                    "-",
                    "operator",
                    " sugar",
                    " instead",
                    " of",
                    " EV",
                    "AL",
                    ' "[',
                    "1",
                    ",",
                    "2",
                    ",",
                    "3",
                    ",",
                    "4",
                    "]",
                    '":',
                    "rb",
                    " .",
                    "map",
                    "(",
                    "->",
                    " $",
                ],
                token_pair_coordinates=[(7, 6)],
            ),
            AttentionTokenPairExample(
                tokens=[
                    " him",
                    " a",
                    " W",
                    "N",
                    " [",
                    "white",
                    " nationalist",
                    "]",
                    " until",
                    " there",
                    " is",
                    " an",
                    " indication",
                    " as",
                    " such",
                    "...",
                    " The",
                    " fact",
                    " that",
                    " he",
                    " targeted",
                    " a",
                    " church",
                    " gives",
                    " me",
                    " an",
                    " ink",
                    "ling",
                    " that",
                    " it",
                    " was",
                    " religion",
                    "-",
                    "related",
                    ',"',
                    " wrote",
                    " White",
                    "Virgin",
                    "ian",
                    ".",
                    "\n",
                    "\n",
                    '"',
                    "Yep",
                    ",",
                    " bad",
                    " news",
                    " for",
                    " gun",
                    " rights",
                    " advocates",
                    " as",
                    " well",
                    ',"',
                    " wrote",
                    " math",
                    "the",
                    "ory",
                    "l",
                    "over",
                    "2008",
                    ".",
                    ' "',
                    "Another",
                ],
                token_pair_coordinates=[(15, 7)],
            ),
            AttentionTokenPairExample(
                tokens=[
                    "23",
                    "]",
                    "\n",
                    "\n",
                    "While",
                    " preparing",
                    " to",
                    " take",
                    " the",
                    " fight",
                    " to",
                    " Prim",
                    "ord",
                    "us",
                    ",",
                    " B",
                    "alth",
                    "azar",
                    " learned",
                    " about",
                    " T",
                    "aim",
                    "i",
                    "'s",
                    " machine",
                    " and",
                    " how",
                    " it",
                    " could",
                    " supposedly",
                    " kill",
                    " two",
                    " Elder",
                    " Dragons",
                    " with",
                    " a",
                    " single",
                    " blow",
                    ",",
                    " which",
                    " p",
                    "iqu",
                    "ed",
                    " his",
                    " interest",
                    ".",
                    " This",
                    " piece",
                    " of",
                    " news",
                    ",",
                    " as",
                    " well",
                    " as",
                    " Mar",
                    "j",
                    "ory",
                    "'s",
                    " sudden",
                    " departure",
                    " from",
                    " his",
                    " side",
                    " which",
                ],
                token_pair_coordinates=[(3, 1)],
            ),
        ],
        explanation="attends to the latest closing square bracket from arbitrary subsequent tokens",
    ),
    # gpt2-xl, layer 2, head 8
    AttentionHeadFewShotExample(
        simulation_examples=[
            AttentionSimulationExample(
                token_pair_example_index=0,
                token_pair_coordinates=(63, 15),
                label=0,
            ),
            AttentionSimulationExample(
                token_pair_example_index=0,
                token_pair_coordinates=(50, 15),
                label=1,
            ),
        ],
        token_pair_examples=[
            AttentionTokenPairExample(
                tokens=[
                    " he",
                    " said",
                    ".",
                    ' "',
                    "Coming",
                    " off",
                    " winning",
                    " the",
                    " year",
                    " before",
                    ",",
                    " I",
                    " love",
                    " playing",
                    " links",
                    " golf",
                    ",",
                    " and",
                    " I",
                    " love",
                    " playing",
                    " the",
                    " week",
                    " before",
                    " a",
                    " major",
                    ".",
                    " It",
                    " was",
                    " tough",
                    " to",
                    " miss",
                    " it",
                    ".",
                    " I",
                    "'m",
                    " just",
                    " glad",
                    " to",
                    " be",
                    " back",
                    '."',
                    "\n",
                    "\n",
                    "F",
                    "owler",
                    " out",
                    "played",
                    " his",
                    " partners",
                    " Rory",
                    " Mc",
                    "Il",
                    "roy",
                    " (",
                    "74",
                    ")",
                    " and",
                    " Hen",
                    "rik",
                    " St",
                    "enson",
                    " (",
                    "72",
                ],
                token_pair_coordinates=[(50, 15)],
            ),
            AttentionTokenPairExample(
                tokens=[
                    " Club",
                    ":",
                    "\n",
                    "\n",
                    "1",
                    ".",
                    " World",
                    " renowned",
                    " golf",
                    " course",
                    "\n",
                    "\n",
                    "2",
                    ".",
                    " Vern",
                    " Mor",
                    "com",
                    " designed",
                    " golf",
                    " course",
                    "\n",
                    "\n",
                    "3",
                    ".",
                    " Great",
                    " family",
                    " holiday",
                    " destination",
                    "\n",
                    "\n",
                    "4",
                    ".",
                    " Play",
                    " amid",
                    " our",
                    " resident",
                    " Eastern",
                    " Grey",
                    " k",
                    "ang",
                    "aroo",
                    " population",
                    "\n",
                    "\n",
                    "5",
                    ".",
                    " Terr",
                    "ific",
                    " friendly",
                    " staff",
                    "\n",
                    "\n",
                    "6",
                    ".",
                    " Natural",
                    " pictures",
                    "que",
                    " bush",
                    " setting",
                    "\n",
                    "\n",
                    "7",
                    ".",
                    " L",
                ],
                token_pair_coordinates=[(9, 8)],
            ),
            AttentionTokenPairExample(
                tokens=[
                    "615",
                    " rpm",
                    " on",
                    " average",
                    ").",
                    " As",
                    " a",
                    " result",
                    ",",
                    " each",
                    " player",
                    " was",
                    " hitting",
                    " longer",
                    " drives",
                    " on",
                    " their",
                    " best",
                    " shots",
                    ",",
                    " while",
                    " achieving",
                    " a",
                    " stra",
                    "ighter",
                    " ball",
                    " flight",
                    " that",
                    " was",
                    " less",
                    " affected",
                    " by",
                    " wind",
                    ".",
                    "\n",
                    "\n",
                    "Every",
                    " Golf",
                    "WR",
                    "X",
                    " Member",
                    " gained",
                    " yard",
                    "age",
                    " with",
                    " a",
                    " new",
                    " Taylor",
                    "Made",
                    " driver",
                    ";",
                    " the",
                    " largest",
                    " distance",
                    " gain",
                    " was",
                    " an",
                    " impressive",
                    " +",
                    "10",
                    ".",
                    "1",
                    " yards",
                    ",",
                ],
                token_pair_coordinates=[(47, 37)],
            ),
            AttentionTokenPairExample(
                tokens=[
                    " of",
                    " being",
                    "?",
                    " Well",
                    ",",
                    " having",
                    " perfected",
                    " the",
                    " art",
                    " of",
                    " swimming",
                    ",",
                    " Phelps",
                    " has",
                    " moved",
                    " on",
                    " to",
                    " another",
                    " cherished",
                    " summer",
                    " past",
                    "ime",
                    " –",
                    " golf",
                    ".",
                    " Here",
                    " he",
                    " is",
                    " participating",
                    " in",
                    " the",
                    " Dun",
                    "hill",
                    " Links",
                    " Championship",
                    " at",
                    " Kings",
                    "b",
                    "arn",
                    "s",
                    " in",
                    " Scotland",
                    " today",
                    ".",
                    " The",
                    " greens",
                    " over",
                    " there",
                    " are",
                    " really",
                    " big",
                    ",",
                    " so",
                    " the",
                    " opportunity",
                    " for",
                    " 50",
                    "-",
                    "yard",
                    " put",
                    "ts",
                    " exist",
                    ".",
                    " Of",
                ],
                token_pair_coordinates=[(45, 23)],
            ),
            AttentionTokenPairExample(
                tokens=[
                    "OTUS",
                    " is",
                    " getting",
                    " to",
                    " see",
                    " aliens",
                    ".",
                    "\n",
                    "\n",
                    "RELATED",
                    ":",
                    " Barack",
                    " Obama",
                    " joins",
                    " second",
                    " D",
                    ".",
                    "C",
                    ".-",
                    "area",
                    " golf",
                    " club",
                    "\n",
                    "\n",
                    '"',
                    "He",
                    " goes",
                    ",",
                    " '",
                    "they",
                    "'re",
                    " freaking",
                    " crazy",
                    " looking",
                    ".'",
                    " And",
                    " then",
                    " he",
                    " walks",
                    " up",
                    ",",
                    " makes",
                    " his",
                    " put",
                    "t",
                    ",",
                    " turns",
                    " back",
                    ",",
                    " walks",
                    " off",
                    " the",
                    " green",
                    ",",
                    " leaves",
                    " it",
                    " at",
                    " that",
                    " and",
                    " gives",
                    " me",
                    " a",
                    " wink",
                    ',"',
                ],
                token_pair_coordinates=[(52, 20)],
            ),
        ],
        explanation='attends to the token "golf" from golf-related tokens',
    ),
    # gpt2-xl, layer 1, head 10
    AttentionHeadFewShotExample(
        simulation_examples=[
            AttentionSimulationExample(
                token_pair_example_index=0,
                token_pair_coordinates=(37, 36),
                label=0,
            ),
            AttentionSimulationExample(
                token_pair_example_index=0,
                token_pair_coordinates=(14, 12),
                label=1,
            ),
        ],
        token_pair_examples=[
            AttentionTokenPairExample(
                tokens=[
                    " security",
                    " by",
                    " requiring",
                    " the",
                    " user",
                    " to",
                    " enter",
                    " a",
                    " numeric",
                    " code",
                    " sent",
                    " to",
                    " his",
                    " or",
                    " her",
                    " cellphone",
                    " in",
                    " addition",
                    " to",
                    " a",
                    " password",
                    ".",
                    " A",
                    " lot",
                    " of",
                    " websites",
                    " have",
                    " offered",
                    " this",
                    " feature",
                    " for",
                    " years",
                    ",",
                    " but",
                    " Int",
                    "uit",
                    " just",
                    " made",
                    " it",
                    " widely",
                    " available",
                    " earlier",
                    " this",
                    " year",
                    ".",
                    "\n",
                    "\n",
                    '"',
                    "When",
                    " you",
                    " give",
                    " your",
                    " most",
                    " sensitive",
                    " data",
                    " and",
                    " that",
                    " of",
                    " your",
                    " family",
                    " to",
                    " a",
                    " company",
                    ",",
                ],
                token_pair_coordinates=[(14, 12)],
            ),
            AttentionTokenPairExample(
                tokens=[
                    " 3",
                    " months",
                    ",",
                    " they",
                    " separated",
                    " the",
                    " men",
                    " and",
                    " women",
                    " here",
                    ".",
                    " I",
                    " don",
                    "'t",
                    " know",
                    " where",
                    " they",
                    " took",
                    " the",
                    " men",
                    " and",
                    " the",
                    " children",
                    ",",
                    " but",
                    " they",
                    " took",
                    " us",
                    " women",
                    " to",
                    " Syria",
                    ".",
                    " They",
                    " kept",
                    " us",
                    " in",
                    " an",
                    " underground",
                    " prison",
                    ".",
                    " My",
                    " only",
                    " wish",
                    " is",
                    " that",
                    " my",
                    " children",
                    " and",
                    " husband",
                    " escape",
                    " ISIS",
                    ".",
                    " They",
                    " brought",
                    " us",
                    " here",
                    " from",
                    " Raqqa",
                    ",",
                    " my",
                    " sisters",
                    " from",
                    " the",
                    " PKK",
                ],
                token_pair_coordinates=[(8, 6)],
            ),
            AttentionTokenPairExample(
                tokens=[
                    " an",
                    " emphasis",
                    " on",
                    " the",
                    " pursuit",
                    " of",
                    " power",
                    " despite",
                    " interpersonal",
                    " costs",
                    '."',
                    "\n",
                    "\n",
                    "The",
                    " study",
                    ",",
                    " which",
                    " involved",
                    " over",
                    " 600",
                    " young",
                    " men",
                    " and",
                    " women",
                    ",",
                    " makes",
                    " a",
                    " strong",
                    " case",
                    " for",
                    " assessing",
                    " such",
                    " traits",
                    " as",
                    ' "',
                    "r",
                    "uth",
                    "less",
                    " ambition",
                    ',"',
                    ' "',
                    "dis",
                    "comfort",
                    " with",
                    " leadership",
                    '"',
                    " and",
                    ' "',
                    "hub",
                    "rist",
                    "ic",
                    " pride",
                    '"',
                    " to",
                    " understand",
                    " psychopath",
                    "ologies",
                    ".",
                    "\n",
                    "\n",
                    "The",
                    " researchers",
                    " looked",
                    " at",
                ],
                token_pair_coordinates=[(23, 21)],
            ),
            AttentionTokenPairExample(
                tokens=[
                    " 4",
                    " hours",
                    ".",
                    " These",
                    " results",
                    ",",
                    " differ",
                    " between",
                    " men",
                    " and",
                    " women",
                    ",",
                    " however",
                    ".",
                    " We",
                    " can",
                    " see",
                    " that",
                    " although",
                    " both",
                    " groups",
                    " have",
                    " a",
                    " large",
                    " cluster",
                    " of",
                    " people",
                    " at",
                    " exactly",
                    " 40",
                    " hours",
                    " per",
                    " week",
                    ",",
                    " there",
                    " are",
                    " more",
                    " men",
                    " reporting",
                    " hours",
                    " above",
                    " 40",
                    ",",
                    " whereas",
                    " there",
                    " are",
                    " more",
                    " women",
                    " reporting",
                    " hours",
                    " below",
                    " 40",
                    ".",
                    " Result",
                    " 3",
                    ":",
                    " Male",
                    " Hours",
                    " Work",
                    "ed",
                    " [",
                    "Info",
                    "]",
                    " Owner",
                ],
                token_pair_coordinates=[(10, 8)],
            ),
            AttentionTokenPairExample(
                tokens=[
                    " they",
                    " were",
                    " perceived",
                    " as",
                    " more",
                    " emotional",
                    ",",
                    " which",
                    " made",
                    " participants",
                    " more",
                    " confident",
                    " in",
                    " their",
                    " own",
                    " opinion",
                    '."',
                    "\n",
                    "\n",
                    "Ms",
                    " Sal",
                    "erno",
                    " said",
                    " both",
                    " men",
                    " and",
                    " women",
                    " reacted",
                    " in",
                    " the",
                    " same",
                    " way",
                    " to",
                    " women",
                    " expressing",
                    " themselves",
                    " angrily",
                    ".",
                    "\n",
                    "\n",
                    '"',
                    "Particip",
                    "ants",
                    " confidence",
                    " in",
                    " their",
                    " own",
                    " verdict",
                    " dropped",
                    " significantly",
                    " after",
                    " male",
                    " hold",
                    "outs",
                    " expressed",
                    " anger",
                    ',"',
                    " the",
                    " paper",
                    "'s",
                    " findings",
                    " stated",
                    ".",
                    "\n",
                ],
                token_pair_coordinates=[(26, 24)],
            ),
        ],
        explanation="attends to male-related tokens from paired female-related tokens",
    ),
    # gpt2-xl, layer 1, head 3
    AttentionHeadFewShotExample(
        token_pair_examples=[
            AttentionTokenPairExample(
                tokens=[
                    "\n",
                    "********************************",
                    "************",
                    "***",
                    "\n",
                    "\n",
                    "V",
                    "iet",
                    "namese",
                    " Ministry",
                    " of",
                    " Foreign",
                    " Affairs",
                    " spokesperson",
                    " Le",
                    " Hai",
                    " Bin",
                    "h",
                    " is",
                    " seen",
                    " in",
                    " this",
                    " file",
                    " photo",
                    ".",
                    " .",
                    " Tu",
                    "oi",
                    " Tre",
                    "\n",
                    "\n",
                    "The",
                    " Ministry",
                    " of",
                    " Foreign",
                    " Affairs",
                    " has",
                    " ordered",
                    " a",
                    " thorough",
                    " investigation",
                    " into",
                    " a",
                    " case",
                    " in",
                    " which",
                    " a",
                    " Vietnamese",
                    " fisherman",
                    " was",
                    " shot",
                    " dead",
                    " on",
                    " his",
                    " boat",
                    " in",
                    " Vietnam",
                    "'s",
                    " Tru",
                    "ong",
                    " Sa",
                    " (",
                    "Spr",
                    "at",
                ],
                token_pair_coordinates=[(1, 1)],
            ),
            AttentionTokenPairExample(
                tokens=[
                    " J",
                    "okin",
                    "en",
                    " tells",
                    " a",
                    " much",
                    " different",
                    " story",
                    ".",
                    " He",
                    " almost",
                    " sounded",
                    " like",
                    " a",
                    " pitch",
                    "man",
                    ".",
                    "\n",
                    "\n",
                    '"',
                    "All",
                    " the",
                    " staff",
                    ",",
                    " team",
                    " service",
                    " guys",
                    ",",
                    " all",
                    " the",
                    " trainers",
                    ",",
                    " they",
                    "'re",
                    " unbelievable",
                    " guys",
                    ',"',
                    " said",
                    " J",
                    "okin",
                    "en",
                    ".",
                    ' "',
                    "It",
                    "'s",
                    " not",
                    " just",
                    " the",
                    " players",
                    ",",
                    " it",
                    "'s",
                    " the",
                    " staff",
                    " around",
                    " the",
                    " team",
                    ".",
                    " I",
                    " feel",
                    " really",
                    " bad",
                    " for",
                    " them",
                ],
                token_pair_coordinates=[(1, 1)],
            ),
            AttentionTokenPairExample(
                tokens=[
                    " a",
                    " Pv",
                    "E",
                    " game",
                    ",",
                    " we",
                    " probably",
                    " would",
                    " use",
                    " it",
                    " but",
                    " based",
                    " on",
                    " the",
                    " tests",
                    " we",
                    "'ve",
                    " run",
                    " on",
                    " it",
                    ",",
                    " that",
                    " wouldn",
                    "'t",
                    " be",
                    " our",
                    " first",
                    " choice",
                    " for",
                    " a",
                    " live",
                    " R",
                    "v",
                    "R",
                    " game",
                    ".",
                    " Now",
                    ",",
                    " could",
                    " we",
                    " use",
                    " it",
                    " for",
                    " prototyp",
                    "ing",
                    "?",
                    " Yep",
                    ",",
                    " we",
                    " are",
                    " already",
                    " doing",
                    " that",
                    ".",
                    " Second",
                    ",",
                    " as",
                    " to",
                    " other",
                    " engines",
                    " there",
                    " are",
                    " both",
                    " financial",
                ],
                token_pair_coordinates=[(1, 1)],
            ),
            AttentionTokenPairExample(
                tokens=[
                    "-",
                    "tun",
                    "er",
                    " is",
                    " also",
                    " custom",
                    "isable",
                    " for",
                    " hassle",
                    "-",
                    "free",
                    " experimentation",
                    ".",
                    " The",
                    " St",
                    "rix",
                    " X",
                    "399",
                    "-",
                    "E",
                    " Gaming",
                    " takes",
                    " up",
                    " to",
                    " three",
                    " double",
                    "-",
                    "wide",
                    " cards",
                    " in",
                    " SLI",
                    " or",
                    " Cross",
                    "Fire",
                    "X",
                    ".",
                    " Primary",
                    " graphics",
                    " slots",
                    " are",
                    " protected",
                    " by",
                    " Safe",
                    "Slot",
                    " from",
                    " damages",
                    " that",
                    " heavy",
                    " GPU",
                    " cool",
                    "ers",
                    " can",
                    " potentially",
                    " cause",
                    ".",
                    "\n",
                    "\n",
                    "Personal",
                    "ised",
                    " RGB",
                    " lighting",
                    " is",
                    " made",
                    " possible",
                ],
                token_pair_coordinates=[(1, 1)],
            ),
            AttentionTokenPairExample(
                tokens=[
                    " to",
                    " abolish",
                    " such",
                    " a",
                    " complex",
                    "?",
                    " Are",
                    " there",
                    " ways",
                    " ve",
                    "gans",
                    " can",
                    " eat",
                    " more",
                    " sustain",
                    "ably",
                    "?",
                    " What",
                    " are",
                    " some",
                    " of",
                    " the",
                    " health",
                    " challenges",
                    " for",
                    " new",
                    " ve",
                    "gans",
                    ",",
                    " and",
                    " how",
                    " can",
                    " we",
                    " raise",
                    " awareness",
                    " of",
                    " these",
                    " issues",
                    " so",
                    " that",
                    ",",
                    " for",
                    " instance",
                    ",",
                    " medical",
                    " professionals",
                    " are",
                    " more",
                    " supportive",
                    " of",
                    " vegan",
                    "ism",
                    "?",
                    "\n",
                    "\n",
                    "Moreover",
                    ",",
                    " it",
                    " is",
                    " essential",
                    " that",
                    " ve",
                    "gans",
                    " differentiate",
                ],
                token_pair_coordinates=[(1, 1)],
            ),
        ],
        explanation="attends from the second token in the sequence to the second token in the sequence",
    ),
]
