def pair_model()

in vis_sandbox.py [0:0]


def pair_model(model, FLAGS, node_embed):
    """Scan the model's energy as two atoms are moved apart and save the curves.

    The last two rows of ``node_embed[0]`` are overwritten with a synthetic
    atom pair: one atom is fixed at ``(1.0, 0.5, 0.5)`` and the other is placed
    at ``(1.0 + d, 0.5, 0.5)`` for 100 values of ``d`` in ``[0.1, 1.0]``. For
    each separation the whole structure is copied ``n_rotations`` times, every
    copy's coordinates are multiplied by a random rotation matrix, the
    coordinates are mean-centred, and the model output is averaged.

    Args:
        model: Object whose ``forward`` accepts the stacked CUDA tensor and
            returns a tensor; its mean is recorded as the energy.
        FLAGS: Object exposing ``outdir``, the directory the pickle is
            written to.
        node_embed: Indexable whose first element is a 2-D torch tensor. The
            code writes column 0 as a residue index, column 1 as an atom
            index, column 2 as a residue counter and the last three columns
            as xyz coordinates. Presumably rows are atoms -- confirm against
            the caller. The input is not modified; a clone is edited.

    Side effects:
        Writes one PNG per residue/atom combination into ``distance_plots/``
        and pickles ``{key: (distances, energies)}`` to
        ``<FLAGS.outdir>/atom_distances.p``. Both directories are created if
        they do not exist.
    """
    # Local import: only ``os.path`` (as ``osp``) is known to be in scope here.
    import os

    # Vocabularies used to turn names into the integer codes written into the
    # embedding (position in the list == code).
    atom_names = ["X", "C", "N", "O", "S"]
    residue_names = [
        "ALA",
        "ARG",
        "ASN",
        "ASP",
        "CYS",
        "GLU",
        "GLN",
        "GLY",
        "HIS",
        "ILE",
        "LEU",
        "LYS",
        "MET",
        "PHE",
        "PRO",
        "SER",
        "THR",
        "TRP",
        "TYR",
        "VAL",
    ]

    energies_output_dict = {}

    def make_key(n_rotations, residue_name1, residue_name2, atom_name1, atom_name2):
        return f"{n_rotations}_{residue_name1}_{residue_name2}_{atom_name1}_{atom_name2}"

    # Create the output locations up front so a missing directory fails (or is
    # fixed) before any model evaluation is spent, not after.
    os.makedirs("distance_plots", exist_ok=True)
    if FLAGS.outdir:
        os.makedirs(FLAGS.outdir, exist_ok=True)

    # Pristine copy of the first element; every distance step starts from it.
    node_embed_orig = node_embed[0].clone()

    # Try different combinations
    for n_rotations in [5]:
        # Random rotations to average the energy over.
        so3 = special_ortho_group(3)
        rot_matrix_neg = so3.rvs(n_rotations)

        residue_names_proc = ["ALA", "TYR", "LEU"]
        atom_names_proc = ["C", "N", "O"]
        for residue_name1, residue_name2 in itertools.product(residue_names_proc, repeat=2):
            for atom_name1, atom_name2 in itertools.product(atom_names_proc, repeat=2):
                eps = []
                energies = []

                residue_index1 = residue_names.index(residue_name1)
                residue_index2 = residue_names.index(residue_name2)
                atom_index1 = atom_names.index(atom_name1)
                atom_index2 = atom_names.index(atom_name2)

                for i in np.linspace(0.1, 1.0, 100):
                    pair_embed = node_embed_orig.clone()
                    # Second-to-last row is the fixed atom, last row is the
                    # atom displaced by ``i`` along the first coordinate.
                    pair_embed[-2, -3:] = torch.Tensor([1.0, 0.5, 0.5])
                    pair_embed[-1, -3:] = torch.Tensor([1.0 + i, 0.5, 0.5])
                    pair_embed[-1, 0] = residue_index1
                    pair_embed[-2, 0] = residue_index2
                    pair_embed[-1, 1] = atom_index1
                    pair_embed[-2, 1] = atom_index2
                    pair_embed[-1, 2] = 6  # res_counter
                    pair_embed[-2, 2] = 6  # res_counter

                    # One copy of the structure per rotation, each rotated by
                    # its own matrix (batched matmul over the leading axis).
                    rotated = np.tile(pair_embed[None, :, :], (n_rotations, 1, 1))
                    rotated[:, :, -3:] = np.matmul(rotated[:, :, -3:], rot_matrix_neg)

                    node_embed_feed = torch.Tensor(rotated).cuda()
                    # Centre the coordinates of every copy on their mean.
                    coords = node_embed_feed[:, :, -3:]
                    node_embed_feed[:, :, -3:] = coords - coords.mean(dim=1, keepdim=True)

                    energy = model.forward(node_embed_feed).mean()

                    # Distances are stored scaled by 10 (the plot's x axis).
                    eps.append(i * 10)
                    energies.append(energy.item())

                key = make_key(n_rotations, residue_name1, residue_name2, atom_name1, atom_name2)
                energies_output_dict[key] = (eps, energies)

                # Optionally make plots here -- potentially add conditions to avoid making too many
                plt.plot(eps, energies)
                plt.xlabel("Atom Distance")
                plt.ylabel("Energy")
                plt.title(
                    f"{n_rotations} rots: {atom_name1}, {atom_name2} distance in {residue_name1}/{residue_name2}"
                )
                plt.savefig(
                    f"distance_plots/{n_rotations}_{atom_name1}_{atom_name2}_in_{residue_name1}_{residue_name2}_distance.png"
                )
                plt.clf()

    # Back to outside: persist every curve. The context manager guarantees
    # the file is flushed and closed even if pickling raises.
    output_path = osp.join(FLAGS.outdir, "atom_distances.p")
    with open(output_path, "wb") as output_file:
        pickle.dump(energies_output_dict, output_file)