in vis_sandbox.py [0:0]
def _surface_core_labels(pos, n_residues):
    """Classify residues as "core", "surface" or "unlabeled" by neighbor count.

    Each residue is represented by one anchor atom: the atom at index 4 of
    ``pos``, or the atom at index 3 when all three coordinates of atom 4 are
    zero (zero coordinates are treated as an absent atom). The original
    comments describe these as c-beta with a c-alpha fallback for glycine.
    For every residue the number of anchors closer than 10 (in the units of
    ``pos``) is counted; the residue itself is included in that count.

    Args:
        pos: array indexed as ``pos[residue, atom, xyz]``. It is not modified.
        n_residues: number of leading residues to label.

    Returns:
        A list of ``n_residues`` strings: "core" when the count is >= 24,
        "surface" when it is <= 16, and "unlabeled" otherwise.
    """
    # Build the anchor table once; it does not depend on the residue index.
    anchors = np.copy(pos[:, 4])
    # Per-atom mask: substitute only when the whole xyz triple is zero, so a
    # present atom that merely has one coordinate equal to 0 is left intact.
    missing = np.all(anchors == 0, axis=-1)
    anchors[missing] = pos[:, 3][missing]

    labels = []
    for anchor in anchors[:n_residues]:
        dist = np.sqrt(np.square(anchor - anchors).sum(axis=1))
        neighbors = (dist < 10).sum()
        if neighbors >= 24:
            labels.append("core")
        elif neighbors <= 16:
            labels.append("surface")
        else:
            labels.append("unlabeled")
    return labels


def make_tsne(model, FLAGS, node_embed):
    """
    Grab representations for each of the residues in a pdb and pickle them.

    For every residue, the ``FLAGS.max_size`` atoms closest to the residue's
    atom at index 4 are selected, centered on their mean position, randomly
    rotated, and passed through ``model``. The hidden state and energy of the
    first rotation are kept. Each residue is also labeled as core / surface /
    unlabeled from its neighbor count.

    Args:
        model: object exposing ``forward(batch, return_hidden=True)`` that
            returns ``(energies, hidden)``, both indexable by rotation along
            their first dimension.
        FLAGS: configuration object; ``pdb_name``, ``max_size`` and ``outdir``
            are read.
        node_embed: ignored. The value is overwritten by the data loaded from
            disk; the parameter is kept so existing callers keep working.

    Returns:
        None. A dict with keys "res", "surface_core_type", "all_hiddens" and
        "all_energies" is pickled to
        ``<FLAGS.outdir>/<pdb_name>_representations.p``.

    Note:
        Tensors are moved with ``.cuda()``, so a CUDA device is required.
    """
    pdb_name = FLAGS.pdb_name
    pickle_file = MMCIF_PATH + f"/mmCIF/{pdb_name[1:3]}/{pdb_name}.p"
    # NOTE(security): unpickling can execute arbitrary code. Only load files
    # that were generated locally / come from a trusted source.
    with open(pickle_file, "rb") as handle:
        (node_embed,) = pickle.load(handle)
    _, _, pos, _, res, _ = parse_dense_format(node_embed)

    # One fixed set of random rotations is shared by every residue window.
    n_rotations = 2
    rot_matrix = special_ortho_group(3).rvs(n_rotations)
    rot_matrix = torch.from_numpy(rot_matrix).float().cuda()

    all_hiddens = []
    all_energies = []
    for idx in range(len(res)):
        # Sort the atoms by how far away they are from this residue's atom at
        # index 4 (described originally as the first atom on the sidechain).
        pos_chosen = pos[idx, 4]
        sq_dist = np.square(node_embed[:, -3:] - pos_chosen).sum(axis=1)
        close_idx = np.argsort(sq_dist)

        # Keep the FLAGS.max_size closest atoms; copy so that centering does
        # not write back into node_embed.
        window = node_embed[close_idx[: FLAGS.max_size]].copy()

        # Center the window: subtract the mean x, y, z of the selected atoms.
        window[:, -3:] = window[:, -3:] - np.mean(window[:, -3:], axis=0)

        # Replicate the window once per rotation and rotate the coordinates.
        batch = torch.from_numpy(window).float().cuda()
        batch = batch[None, :, :].repeat(n_rotations, 1, 1)
        batch[:, :, -3:] = torch.matmul(batch[:, :, -3:], rot_matrix)

        energies, hidden = model.forward(batch, return_hidden=True)

        # Keep the first rotation only. Detach so the autograd graph of each
        # forward pass is released and the later .numpy() conversion works
        # even when the model parameters require gradients.
        all_hiddens.append(hidden[0].detach())
        all_energies.append(energies[0].detach())

    surface_core_type = _surface_core_labels(pos, len(res))

    output = {
        "res": res,
        "surface_core_type": surface_core_type,
        "all_hiddens": torch.stack(all_hiddens).cpu().numpy(),
        "all_energies": torch.stack(all_energies).cpu().numpy(),
    }

    # Dump the output
    output_path = osp.join(FLAGS.outdir, f"{pdb_name}_representations.p")
    os.makedirs(FLAGS.outdir, exist_ok=True)
    with open(output_path, "wb") as handle:
        pickle.dump(output, handle)