in vis_sandbox.py [0:0]
def _make_pair_batch(
    base_embed,
    separation,
    residue_index1,
    residue_index2,
    atom_index1,
    atom_index2,
    n_rotations,
    rot_matrices,
):
    """Build a rotated, centered model input that contains one probe atom pair.

    A copy of ``base_embed`` is taken and its last two rows are overwritten
    with the probe pair:

    - Row ``-2`` is placed at (1.0, 0.5, 0.5).
    - Row ``-1`` is placed at (1.0 + separation, 0.5, 0.5), so the two atoms
      are ``separation`` apart.
    - In both rows, column 0 receives the residue index, column 1 the atom
      index and column 2 the residue counter. The last three columns hold the
      coordinates.

    The copy is then processed as follows:

    1. It is tiled ``n_rotations`` times.
    2. Each copy's coordinates are right-multiplied by the rotation matrices.
    3. The batch is moved to the GPU.
    4. The coordinates are centered on the mean over all rows of each copy.

    ``base_embed`` itself is not modified.

    Args:
        base_embed: CPU torch tensor indexed as (rows, features); used as the
            template.
        separation: Distance between the two probe atoms, in model
            coordinate units.
        residue_index1: Residue code written to row -1.
        residue_index2: Residue code written to row -2.
        atom_index1: Atom code written to row -1.
        atom_index2: Atom code written to row -2.
        n_rotations: Number of rotated copies in the batch.
        rot_matrices: Output of ``special_ortho_group(3).rvs(n_rotations)``.

    Returns:
        CUDA ``torch.Tensor`` of shape (n_rotations, rows, features).
    """
    embed = base_embed.clone()
    embed[-2, -3:] = torch.Tensor([1.0, 0.5, 0.5])
    embed[-1, -3:] = torch.Tensor([1.0 + separation, 0.5, 0.5])
    embed[-1, 0] = residue_index1
    embed[-2, 0] = residue_index2
    embed[-1, 1] = atom_index1
    embed[-2, 1] = atom_index2
    embed[-1, 2] = 6  # res_counter
    embed[-2, 2] = 6  # res_counter

    # Rotation is applied in NumPy, then the batch is converted to torch and
    # centered. This keeps the original order of operations and dtypes.
    batch = np.tile(embed[None, :, :], (n_rotations, 1, 1))
    batch[:, :, -3:] = np.matmul(batch[:, :, -3:], rot_matrices)
    batch = torch.Tensor(batch).cuda()
    batch[:, :, -3:] = batch[:, :, -3:] - batch[:, :, -3:].mean(dim=1, keepdim=True)
    return batch


def _plot_energy_curve(
    eps, energies, n_rotations, residue_name1, residue_name2, atom_name1, atom_name2
):
    """Save an energy-vs-distance line plot under ``distance_plots/`` and clear the figure."""
    plt.plot(eps, energies)
    plt.xlabel("Atom Distance")
    plt.ylabel("Energy")
    plt.title(
        f"{n_rotations} rots: {atom_name1}, {atom_name2} distance in {residue_name1}/{residue_name2}"
    )
    plt.savefig(
        f"distance_plots/{n_rotations}_{atom_name1}_{atom_name2}_in_{residue_name1}_{residue_name2}_distance.png"
    )
    plt.clf()


def pair_model(model, FLAGS, node_embed):
    """Probe how the model's energy changes with the distance between two atoms.

    ``node_embed[0]`` is used as a template. For every residue pair and atom
    pair drawn from a small fixed subset, the last two rows of the template
    are replaced by a probe atom pair. The pair separation is swept over 100
    values in [0.1, 1.0]. Each configuration is evaluated under
    ``n_rotations`` random rotations, and the mean model output is recorded.

    NOTE(review): only the last two rows are edited; the remaining rows of the
    template are left as they are. Whether the other atoms are "far away" from
    the probe pair therefore depends on the input.

    Args:
        model: Object whose ``forward(batch)`` returns a tensor; its mean is
            taken as the energy.
        FLAGS: Configuration object; ``FLAGS.outdir`` is the directory for the
            pickle output.
        node_embed: Indexable whose element 0 is a torch tensor indexed as
            (rows, features). Presumably a batch of node embeddings — confirm.
            The last three feature columns are treated as coordinates.

    Side effects:
        - Writes one PNG per (residue pair, atom pair) to ``distance_plots/``,
          creating the directory if needed.
        - Pickles a dict to ``FLAGS.outdir/atom_distances.p``. Keys come from
          ``make_key``; values are ``(eps, energies)`` lists.
    """
    import os  # needed for os.makedirs on the plot directory

    # Codes written into columns 0/1 are positions in these lists.
    # NOTE(review): assumes this ordering matches the encoding used to build
    # node_embed — confirm.
    atom_names = ["X", "C", "N", "O", "S"]
    residue_names = [
        "ALA",
        "ARG",
        "ASN",
        "ASP",
        "CYS",
        "GLU",
        "GLN",
        "GLY",
        "HIS",
        "ILE",
        "LEU",
        "LYS",
        "MET",
        "PHE",
        "PRO",
        "SER",
        "THR",
        "TRP",
        "TYR",
        "VAL",
    ]
    energies_output_dict = {}

    def make_key(n_rotations, residue_name1, residue_name2, atom_name1, atom_name2):
        return f"{n_rotations}_{residue_name1}_{residue_name2}_{atom_name1}_{atom_name2}"

    # Template copy. detach()/cpu() lets np.tile convert it to a NumPy array
    # even when the input requires grad or lives on the GPU.
    base_embed = node_embed[0].detach().cpu().clone()

    # plt.savefig does not create missing directories.
    os.makedirs("distance_plots", exist_ok=True)

    for n_rotations in [5]:
        # Random rotations to average over. They are sampled once and reused
        # for every configuration.
        so3 = special_ortho_group(3)
        rot_matrix_neg = so3.rvs(n_rotations)
        residue_names_proc = ["ALA", "TYR", "LEU"]
        atom_names_proc = ["C", "N", "O"]
        for residue_name1, residue_name2 in itertools.product(residue_names_proc, repeat=2):
            residue_index1 = residue_names.index(residue_name1)
            residue_index2 = residue_names.index(residue_name2)
            for atom_name1, atom_name2 in itertools.product(atom_names_proc, repeat=2):
                atom_index1 = atom_names.index(atom_name1)
                atom_index2 = atom_names.index(atom_name2)
                eps = []
                energies = []
                for i in np.linspace(0.1, 1.0, 100):
                    batch = _make_pair_batch(
                        base_embed,
                        i,
                        residue_index1,
                        residue_index2,
                        atom_index1,
                        atom_index2,
                        n_rotations,
                        rot_matrix_neg,
                    )
                    energy = model.forward(batch).mean()
                    # Plotted distance is the coordinate separation times 10.
                    # NOTE(review): presumably converts model units to Å — confirm.
                    eps.append(i * 10)
                    energies.append(energy.item())
                key = make_key(n_rotations, residue_name1, residue_name2, atom_name1, atom_name2)
                energies_output_dict[key] = (eps, energies)
                # Optionally make plots here -- potentially add conditions to
                # avoid making too many.
                _plot_energy_curve(
                    eps,
                    energies,
                    n_rotations,
                    residue_name1,
                    residue_name2,
                    atom_name1,
                    atom_name2,
                )

    output_path = osp.join(FLAGS.outdir, "atom_distances.p")
    # Context manager guarantees the pickle is flushed and the handle closed.
    with open(output_path, "wb") as f:
        pickle.dump(energies_output_dict, f)