in sparse_autoencoder/explanations.py [0:0]
def predict(self, tokens: list[str]) -> list[float]:
    """Predict an activation for every token by matching trie patterns that end at it.

    For each position the trie is walked backwards, starting at that token and
    moving towards the beginning of the sequence. At every step an exact token
    edge is preferred over the ``_ANY_TOKEN`` wildcard edge. If the exact edge
    later dead-ends, the walk does not go back to try the wildcard. The walk
    stops at the first node carrying ``_SALIENCY_KEY``, and that value becomes
    the prediction.

    If neither an exact nor a wildcard edge exists, the prediction is 0.
    If the walk consumes every token back to the start of the sequence without
    finding a saliency value, a ``_START_TOKEN`` child of the final node supplies
    it. That child must then carry ``_SALIENCY_KEY``, which is asserted. Otherwise
    the prediction is 0.

    Args:
        tokens: The token sequence to score.

    Returns:
        One prediction per input token, in order (``0`` where nothing matches).
    """

    def score_position(end):
        # Walk the trie from tokens[end] back towards tokens[0].
        node = self.trie
        for pos in reversed(range(end + 1)):
            tok = tokens[pos]
            if tok in node:
                node = node[tok]
            elif _ANY_TOKEN in node:
                node = node[_ANY_TOKEN]
            else:
                # No exact or wildcard edge: no pattern can match here.
                return 0
            if _SALIENCY_KEY in node:
                return node[_SALIENCY_KEY]
        # The whole prefix was consumed without a saliency value. A pattern
        # anchored at the start of the sequence may still match.
        if _START_TOKEN in node:
            node = node[_START_TOKEN]
            assert _SALIENCY_KEY in node
            return node[_SALIENCY_KEY]
        return 0

    # Exactly one prediction per token, by construction.
    return [score_position(end) for end in range(len(tokens))]