def first_position_fn()

in sparse_autoencoder/explanations.py [0:0]


    def first_position_fn(tokens: list[list[float]]) -> list[list[float]]:
        """Return a per-position indicator that is 1.0 at index 0 and 0.0 elsewhere.

        Only the length of each inner sequence in ``tokens`` is read; the
        element values are ignored. Each output row has exactly the same
        length as its input row, so an empty input row yields an empty
        output row (rather than a spurious ``[1.0]``).

        Args:
            tokens: A list of sequences; one output row is produced per sequence.

        Returns:
            A list of float lists, where row ``i`` has ``len(tokens[i])`` entries,
            the first being 1.0 (if the row is non-empty) and the rest 0.0.
        """
        # float(True) == 1.0 and float(False) == 0.0, so this marks position 0
        # only; range(0) is empty, which keeps output length == input length.
        return [[float(pos == 0) for pos in range(len(toks))] for toks in tokens]