in bench/generation/metrics/perplexity.py [0:0]
def _process_batch(self, i, n_ctx, n_batch, tokens, nll, count):
"""
Processes each batch of tokens.
Parameters
----------
i : int
The batch index.
n_ctx : int
The context size.
n_batch : int
The batch size.
tokens : torch.Tensor
The tokenized text.
nll : float
The current negative log likelihood.
count : int
The current count of processed tokens.
Returns
-------
float
The updated negative log likelihood.
int
The updated count of processed tokens.
"""
    start = i * n_ctx
    end = start + n_ctx

    # Number of forward passes needed to cover one chunk (ceiling division)
    num_batches = (n_ctx + n_batch - 1) // n_batch

    logits = []

    for j in range(num_batches):
        batch_start = start + j * n_batch
        batch_size = min(end - batch_start, n_batch)

        token_org = tokens[0][batch_start].item()

        if j == 0:
            # Replace the first token with the BOS token
            tokens[0][batch_start] = self._tokenizer.bos_token_id

        # Compute the logits for the current batch of tokens
        batch_logits = self._compute_batch_logits(tokens, batch_start, batch_size)

        # Restore the original token
        tokens[0][batch_start] = token_org

        logits.append(batch_logits)
    # We rely on the fact that attention in the forward pass only looks at previous
    # tokens, so the logits returned for each position are an accurate representation
    # of what the model would have predicted at that point.
    #
    # For example, with a context window of 512 we compute perplexity for each of the
    # last 256 tokens of the chunk; the prompt was split into context-window-sized
    # chunks (one per call to this method) so that the entire prompt is processed.
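    # Illustration (editorial note, assuming a single sub-batch of n_ctx tokens, so that
    # logits[0] has shape (1, n_ctx, vocab_size)): with n_ctx = 1024 the loop below runs
    # over j = 512 .. 1022, each time comparing the distribution predicted at position j
    # against the actual token at position start + j + 1.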
    for j in range(min(512, n_ctx // 2), n_ctx - 1):
        tok_logits = logits[0][0][j].cpu().numpy()
        # Compute the probability of the next token
        prob = self.softmax(tok_logits)[tokens[0][start + j + 1]]
        # Update the negative log likelihood and the count of processed tokens,
        # guarding against log(0) for zero probabilities
        if prob > 0:
            nll += -np.log(prob)
        count += 1

    return nll, count
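

# --------------------------------------------------------------------------------------
# Editorial illustration (not part of the original file): a minimal, self-contained
# sketch of how the `nll` and `count` values accumulated above turn into a perplexity
# score, i.e. perplexity = exp(nll / count). The probabilities below are made up; in the
# method above they come from `self.softmax(tok_logits)`.
import numpy as np  # numpy is already imported as `np` at the top of the original file


def _example_perplexity_from_probs(probs):
    # Sum the negative log likelihood of each predicted token, then exponentiate the
    # mean to obtain the perplexity.
    nll = float(np.sum(-np.log(probs)))
    count = len(probs)
    return float(np.exp(nll / count))


# With a uniform next-token probability of 0.25 the perplexity is exactly 4.0.
assert abs(_example_perplexity_from_probs([0.25, 0.25, 0.25, 0.25]) - 4.0) < 1e-9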