in vis_sandbox.py [0:0]
def _print_mean_and_stderr(label, values):
    """Print ``label`` followed by the mean and the standard error of ``values``."""
    print(label, np.mean(values), np.std(values) / len(values) ** 0.5)


def _sample_candidate_chis(FLAGS, db, angles, res, idx, tresh):
    """Draw candidate chi angles for residue ``idx`` according to ``FLAGS.sample_mode``.

    Raises:
        ValueError: if ``FLAGS.sample_mode`` is not one of ``weighted_gauss``,
            ``gmm`` or ``rosetta``.
    """
    mode = FLAGS.sample_mode
    if mode == "weighted_gauss":
        return interpolated_sample_normal(
            db, angles[idx, 1], angles[idx, 2], res[idx], FLAGS.neg_sample, uniform=False
        )
    if mode == "gmm":
        return mixture_sample_normal(
            db, angles[idx, 1], angles[idx, 2], res[idx], FLAGS.neg_sample, uniform=False
        )
    if mode == "rosetta":
        return exhaustive_sample(db, angles[idx, 1], angles[idx, 2], res[idx], tresh=tresh)
    # Previously an unknown mode silently reused the candidates of the previous
    # residue (or raised NameError on the first one).
    raise ValueError(
        "Unknown sample_mode {!r}; expected 'weighted_gauss', 'gmm' or 'rosetta'".format(mode)
    )


def _build_candidate_embeds(
    node_embed, par, child, pos, pos_exist, chis_valid, angles, idx, candidates, neg_sample
):
    """Re-encode the structure once per candidate rotamer of residue ``idx``.

    ``candidates`` is padded in place (cyclically) up to ``neg_sample`` entries so
    that index ``i`` of the returned array always corresponds to ``candidates[i]``.

    Returns:
        Array with one entry per candidate (``neg_sample`` of them), each cropped
        to the 64 nodes closest to atom slot 4 of residue ``idx``, with the last
        three features divided by 10 and mean-centred per candidate.

    Raises:
        ValueError: if ``candidates`` is empty (nothing to pad from).
    """
    n_sampled = len(candidates)
    if n_sampled == 0:
        raise ValueError("Rotamer sampler returned no candidates for residue index {}".format(idx))

    embeds = []
    for i in range(neg_sample):
        if i >= n_sampled:
            # Fewer candidates than requested: repeat earlier ones cyclically.
            embeds.append(embeds[i % n_sampled])
            candidates.append(candidates[i % n_sampled])
            continue

        chis_target = angles[:, 4:8].copy()
        valid = chis_valid[idx, :4]
        # Only overwrite the chi angles flagged valid; keep the others as they are.
        chis_target[idx] = candidates[i] * valid + (1 - valid) * chis_target[idx]
        pos_new = rotate_dihedral_fast(
            angles, par, child, pos, pos_exist, chis_target, chis_valid, idx
        )
        embeds.append(reencode_dense_format(node_embed, pos_new, pos_exist))

    embeds = np.array(embeds)
    # Keep the 64 nodes whose last three features (presumably xyz coordinates --
    # TODO confirm) are closest to atom slot 4 of the residue being evaluated.
    dist = np.square(embeds[:, :, -3:] - pos[idx : idx + 1, 4:5, :]).sum(axis=2)
    close_idx = np.argsort(dist)
    embeds = np.take_along_axis(embeds, close_idx[:, :64, None], axis=1)
    # In-place slice assignment keeps the array dtype unchanged.
    embeds[:, :, -3:] = embeds[:, :, -3:] / 10.0
    embeds[:, :, -3:] = embeds[:, :, -3:] - np.mean(embeds[:, :, -3:], axis=1, keepdims=True)
    return embeds


def _select_lowest_energy(model, FLAGS, so3, embed_batches):
    """Return, for each queued residue, the index of its lowest-energy candidate.

    Every candidate is evaluated under ``FLAGS.rotations`` random rotations drawn
    from ``so3`` and the energies are averaged over those rotations.
    """
    rotations = FLAGS.rotations
    n_entries = len(embed_batches)
    stacked = np.concatenate(embed_batches)
    per_sample_shape = stacked.shape[1:]

    # Replicate each sample once per rotation, then rotate the last three features.
    tiled = np.tile(stacked[:, None, :, :], (1, rotations, 1, 1))
    rot_matrix = so3.rvs(rotations)
    if rotations == 1:
        # A single draw comes back without the leading sample axis.
        rot_matrix = rot_matrix[None, :, :]
    tiled[:, :, :, -3:] = np.matmul(tiled[:, :, :, -3:], rot_matrix[None, :, :, :])
    feed = torch.from_numpy(tiled.reshape((-1, *per_sample_shape))).float().cuda()

    with torch.no_grad():
        if FLAGS.ensemble > 1:
            # ``model`` is an iterable of models; their energies are summed.
            energy = 0
            for member in model:
                energy = energy + member.forward(feed)
        else:
            energy = model.forward(feed)

    energy = energy.view(n_entries, -1, rotations).mean(dim=2)
    return torch.argmin(energy, dim=1).cpu().numpy()


def _score_pending(model, FLAGS, so3, pending):
    """Score queued residues and return ``(res_name, burial, rotamer_score)`` triples."""
    select_idx = _select_lowest_energy(model, FLAGS, so3, [entry[-1] for entry in pending])
    results = []
    for entry, chosen in zip(pending, select_idx):
        res_name, burial, gt_chi, valid_chi, candidates, _ = entry
        rotamer_score, _ = compute_rotamer_score_planar(
            gt_chi, candidates[chosen], valid_chi[:4], res_name
        )
        results.append((res_name, burial, rotamer_score))
    return results


def rotamer_trials(model, FLAGS, test_dataset):
    """Evaluate rotamer recovery of ``model`` over the files of ``test_dataset``.

    For every interior residue that is not glycine or alanine, candidate chi
    angles are sampled from the rotamer library, the structure is re-encoded once
    per candidate, and the candidate with the lowest model energy (averaged over
    random rotations) is compared with the ground-truth chi angles via
    ``compute_rotamer_score_planar``. Per-file mean scores are accumulated and a
    summary (mean and standard error) is printed overall, for buried residues,
    for surface residues and per amino acid.

    Args:
        model: Object exposing ``forward(tensor)``; when ``FLAGS.ensemble > 1`` an
            iterable of such objects whose energies are summed.
        FLAGS: Configuration object; reads ``ensemble``, ``rotations``,
            ``neg_sample`` and ``sample_mode`` (one of ``weighted_gauss``,
            ``gmm``, ``rosetta``).
        test_dataset: Object whose ``files`` attribute is a list of pickle paths.
            The list is shuffled in place.

    Returns:
        None. Results are printed to stdout.

    Raises:
        ValueError: on an unknown ``FLAGS.sample_mode`` or when the sampler
            returns no candidates for a residue.
    """
    test_files = test_dataset.files
    random.shuffle(test_files)
    db = load_rotamor_library()
    so3 = special_ortho_group(3)
    nminibatch = 4  # number of residues evaluated per forward pass

    rotamer_scores_total = []
    surface_scores_total = []
    buried_scores_total = []
    amino_recovery_total = {name.lower(): [] for name in kvs}

    for test_file in tqdm(test_files):
        # NOTE(security): pickle executes arbitrary code on load; only use with
        # trusted test files.
        with open(test_file, "rb") as handle:
            (node_embed,) = pickle.load(handle)
        if node_embed is None:
            # Check before parsing; the parse below cannot handle a missing entry.
            continue

        par, child, pos, pos_exist, res, chis_valid = parse_dense_format(node_embed)
        angles = compute_dihedral(par, child, pos, pos_exist)

        # Reference position per residue for neighbour counting: atom slot 4,
        # falling back (element-wise) to atom slot 3 where slot 4 is exactly zero.
        # Depends only on ``pos``, so it is computed once per file.
        ref_pos = np.where(pos[:, 4] == [0, 0, 0], pos[:, 3], pos[:, 4])

        scored = []  # (res_name, burial, rotamer_score) for this file
        pending = []  # residues queued for the next forward pass
        n_amino = pos.shape[0]

        for idx in range(1, n_amino - 1):
            res_name = res[idx]
            if res_name in ("gly", "ala"):
                continue

            # Count residues within distance 10 (presumably Angstrom -- TODO
            # confirm) and classify the residue by that neighbour count.
            neighbors = int(
                (np.linalg.norm(pos[idx : idx + 1, 4] - ref_pos, axis=1) < 10).sum()
            )
            if neighbors >= 24:
                burial, tresh = "buried", 0.98
            elif neighbors < 16:
                burial, tresh = "surface", 0.95
            else:
                burial, tresh = "neutral", 0.95

            # Copy so that padding never mutates the sampler's own return value.
            candidates = list(_sample_candidate_chis(FLAGS, db, angles, res, idx, tresh))
            embeds = _build_candidate_embeds(
                node_embed,
                par,
                child,
                pos,
                pos_exist,
                chis_valid,
                angles,
                idx,
                candidates,
                FLAGS.neg_sample,
            )
            pending.append(
                (res_name, burial, angles[idx, 4:8], chis_valid[idx, :4], candidates, embeds)
            )

            if len(pending) == nminibatch:
                scored.extend(_score_pending(model, FLAGS, so3, pending))
                pending = []

        # Flush the partial last batch. Doing this after the loop (rather than on
        # the last index) also covers files whose final residue is gly/ala.
        if pending:
            scored.extend(_score_pending(model, FLAGS, so3, pending))

        amino_recovery = {name.lower(): [] for name in kvs}
        for res_name, _, score in scored:
            amino_recovery[str(res_name)].append(score)

        rotamer_scores = [score for _, _, score in scored]
        buried_scores = [score for _, burial, score in scored if burial == "buried"]
        surface_scores = [score for _, burial, score in scored if burial == "surface"]

        # Skip empty groups so a file without such residues does not inject NaN.
        if rotamer_scores:
            rotamer_scores_total.append(np.mean(rotamer_scores))
        if buried_scores:
            buried_scores_total.append(np.mean(buried_scores))
        if surface_scores:
            surface_scores_total.append(np.mean(surface_scores))
        for name, values in amino_recovery.items():
            if values:
                amino_recovery_total[name].append(np.mean(values))

    # NOTE(review): the summary is printed once after all files; confirm this
    # matches the intended reporting cadence.
    _print_mean_and_stderr("Obtained a rotamer recovery score of ", rotamer_scores_total)
    _print_mean_and_stderr("Obtained a buried recovery score of ", buried_scores_total)
    _print_mean_and_stderr("Obtained a surface recovery score of ", surface_scores_total)
    for name, values in amino_recovery_total.items():
        _print_mean_and_stderr("per amino acid recovery of {} score of ".format(name), values)