# variant-prediction/predict.py
def compute_pppl(row, sequence, model, alphabet, offset_idx):
    """Score a point mutation by summed masked-marginal log-probabilities.

    Applies the mutation described by ``row`` to ``sequence``, then masks each
    interior position in turn and accumulates the model's log-probability of
    the residue at that position (a pseudo-perplexity-style score).

    Args:
        row: Mutation string like ``"A123T"`` — wildtype residue, position,
            mutant residue.
        sequence: Wild-type protein sequence.
        model: Masked language model; ``model(tokens)["logits"]`` must yield
            per-position vocabulary logits.
        alphabet: Tokenizer-like object providing ``get_batch_converter()``,
            ``get_idx()``, and ``mask_idx``.
        offset_idx: Offset subtracted from the position in ``row`` to obtain a
            0-based index into ``sequence``.

    Returns:
        float: Sum of per-position log-probabilities of the mutated sequence.

    Raises:
        AssertionError: If the wildtype residue in ``row`` does not match
            ``sequence`` at the given position.
    """
    wt, idx, mt = row[0], int(row[1:-1]) - offset_idx, row[-1]
    if sequence[idx] != wt:
        # Explicit raise instead of a bare `assert` so the validation still
        # runs under `python -O`; same exception type and message as before.
        raise AssertionError("The listed wildtype does not match the provided sequence")

    # Apply the point mutation.
    sequence = sequence[:idx] + mt + sequence[(idx + 1) :]

    # Encode the mutated sequence.
    data = [
        ("protein1", sequence),
    ]
    batch_converter = alphabet.get_batch_converter()
    batch_labels, batch_strs, batch_tokens = batch_converter(data)

    # Run on whatever device the model lives on. The original hard-coded
    # `.cuda()` inside the loop, which crashed on CPU-only hosts and moved
    # the tensor every iteration; behavior on GPU is unchanged.
    device = next(model.parameters()).device
    batch_tokens = batch_tokens.to(device)

    # Mask each interior position and accumulate its log-probability.
    # NOTE(review): if the batch converter prepends a BOS token, token
    # position i corresponds to sequence[i - 1], not sequence[i] — confirm
    # this indexing against the alphabet's prepend_bos setting.
    log_probs = []
    with torch.no_grad():  # hoisted: one no-grad context for the whole loop
        for i in range(1, len(sequence) - 1):
            batch_tokens_masked = batch_tokens.clone()
            batch_tokens_masked[0, i] = alphabet.mask_idx
            token_probs = torch.log_softmax(model(batch_tokens_masked)["logits"], dim=-1)
            log_probs.append(token_probs[0, i, alphabet.get_idx(sequence[i])].item())
    return sum(log_probs)