def predict()

in sparse_autoencoder/explanations.py [0:0]


    def predict(self, tokens: list[str]) -> list[float]:
        """Predict one activation per token by matching patterns stored in ``self.trie``.

        ``self.trie`` is indexed like a nested mapping (presumably dicts of
        dicts; confirm against the builder). For each position ``i``, the trie
        is walked backwards over ``tokens[i], tokens[i - 1], ..., tokens[0]``.
        At every step the literal token edge is taken if present. Otherwise
        the ``_ANY_TOKEN`` wildcard edge is taken. If neither exists, the
        prediction is ``0``. The walk never backtracks to the wildcard after
        taking a literal edge that later dead-ends.

        The first node reached that carries ``_SALIENCY_KEY`` supplies the
        prediction. If the walk consumes ``tokens[0]`` without finding one,
        a ``_START_TOKEN`` child of the current node (if any) supplies it;
        that child is required to carry ``_SALIENCY_KEY``. Failing all of
        that, the prediction is ``0``.

        Args:
            tokens: the token sequence to score.

        Returns:
            A list the same length as ``tokens``. Unmatched positions hold
            the integer ``0``.
        """

        def score_at(end: int) -> float:
            # Walk from tokens[end] back towards the beginning of the sequence.
            node = self.trie
            for pos in reversed(range(end + 1)):
                tok = tokens[pos]
                if tok in node:
                    node = node[tok]
                elif _ANY_TOKEN in node:
                    node = node[_ANY_TOKEN]
                else:
                    # No literal or wildcard edge: this position matches nothing.
                    return 0
                if _SALIENCY_KEY in node:
                    return node[_SALIENCY_KEY]
            # We consumed tokens[0] (the start of the sequence) without hitting a
            # saliency value; a pattern anchored at sequence start may still apply.
            if _START_TOKEN in node:
                anchored = node[_START_TOKEN]
                assert _SALIENCY_KEY in anchored
                return anchored[_SALIENCY_KEY]
            return 0

        return [score_at(idx) for idx in range(len(tokens))]