in variant-prediction/predict.py [0:0]
import pandas as pd
import torch
from tqdm import tqdm

from esm import pretrained, MSATransformer

# read_msa, label_row, and compute_pppl are helpers defined elsewhere in predict.py.


def main(args):
# Load the deep mutational scan
df = pd.read_csv(args.dms_input)
    # Run inference with each requested model; each adds one score column (keyed by model_location) to df
for model_location in args.model_location:
model, alphabet = pretrained.load_model_and_alphabet(model_location)
model.eval()
        if torch.cuda.is_available() and not args.nogpu:
            model = model.cuda()
            print("Transferred model to GPU")
        # Place inference inputs on the same device as the model (stays on CPU when --nogpu is set)
        device = next(model.parameters()).device
        batch_converter = alphabet.get_batch_converter()
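        # Two code paths: MSA Transformer models score mutations against an MSA,
        # while single-sequence models (e.g. ESM-1v) score them against the
        # wild-type sequence alone.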
if isinstance(model, MSATransformer):
data = [read_msa(args.msa_path, args.msa_samples)]
            assert (
                args.scoring_strategy == "masked-marginals"
            ), "MSA Transformer only supports the masked-marginals scoring strategy"
batch_labels, batch_strs, batch_tokens = batch_converter(data)
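            # For MSA inputs batch_tokens has shape (1, num_sequences, num_columns).
            # Masked marginals: mask each column of the query row in turn and run
            # one forward pass per position.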
all_token_probs = []
for i in tqdm(range(batch_tokens.size(2))):
batch_tokens_masked = batch_tokens.clone()
                batch_tokens_masked[0, 0, i] = alphabet.mask_idx  # mask column i of the query (first) sequence
with torch.no_grad():
token_probs = torch.log_softmax(
                        model(batch_tokens_masked.to(device))["logits"], dim=-1
)
                all_token_probs.append(token_probs[:, 0, i])  # log-probs over the vocabulary at query position i
            token_probs = torch.cat(all_token_probs, dim=0).unsqueeze(0)  # (1, num_columns, vocab)
df[model_location] = df.apply(
lambda row: label_row(
row[args.mutation_col], args.sequence, token_probs, alphabet, args.offset_idx
),
axis=1,
)
else:
data = [
("protein1", args.sequence),
]
batch_labels, batch_strs, batch_tokens = batch_converter(data)
if args.scoring_strategy == "wt-marginals":
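                # wt-marginals: a single forward pass on the wild-type sequence; every
                # variant is scored from the same per-position distribution, so this is
                # the cheapest strategy.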
with torch.no_grad():
                    token_probs = torch.log_softmax(model(batch_tokens.to(device))["logits"], dim=-1)
df[model_location] = df.apply(
lambda row: label_row(
row[args.mutation_col],
args.sequence,
token_probs,
alphabet,
args.offset_idx,
),
axis=1,
)
elif args.scoring_strategy == "masked-marginals":
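                # masked-marginals: mask each wild-type position in turn and keep the
                # log-probs at the masked position, i.e. one forward pass per position.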
all_token_probs = []
for i in tqdm(range(batch_tokens.size(1))):
batch_tokens_masked = batch_tokens.clone()
batch_tokens_masked[0, i] = alphabet.mask_idx
with torch.no_grad():
token_probs = torch.log_softmax(
                            model(batch_tokens_masked.to(device))["logits"], dim=-1
)
                    all_token_probs.append(token_probs[:, i])  # log-probs over the vocabulary at position i
                token_probs = torch.cat(all_token_probs, dim=0).unsqueeze(0)  # (1, seq_len, vocab)
df[model_location] = df.apply(
lambda row: label_row(
row[args.mutation_col],
args.sequence,
token_probs,
alphabet,
args.offset_idx,
),
axis=1,
)
elif args.scoring_strategy == "pseudo-ppl":
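                # pseudo-ppl: each full mutant sequence is scored by pseudo-perplexity
                # (mask one position at a time and sum the log-probs of the true residues),
                # which needs one forward pass per position per variant; slowest strategy.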
tqdm.pandas()
df[model_location] = df.progress_apply(
lambda row: compute_pppl(
row[args.mutation_col], args.sequence, model, alphabet, args.offset_idx
),
axis=1,
)
df.to_csv(args.dms_output)
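
# For orientation: label_row (defined elsewhere in predict.py) turns one mutation
# string such as "A42G" plus the per-position log-probs into a score. A minimal
# sketch of the assumed behavior, not the exact implementation; the "1 + idx"
# offset assumes the batch converter prepends a BOS token:
#
# def label_row(row, sequence, token_probs, alphabet, offset_idx):
#     wt, idx, mt = row[0], int(row[1:-1]) - offset_idx, row[-1]
#     assert sequence[idx] == wt, "listed wild type does not match the sequence"
#     wt_encoded, mt_encoded = alphabet.get_idx(wt), alphabet.get_idx(mt)
#     score = token_probs[0, 1 + idx, mt_encoded] - token_probs[0, 1 + idx, wt_encoded]
#     return score.item()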