def make_tsne()

in vis_sandbox.py [0:0]


def _label_surface_core(pos, n_res, radius=10, core_min=24, surface_max=16):
    """Label each residue as "core", "surface" or "unlabeled" by neighbor count.

    The reference atom of a residue is ``pos[i, 4]`` (per the original comments,
    the C-beta / first side-chain atom). When that atom is absent -- stored as
    an all-zero coordinate row -- ``pos[i, 3]`` (the C-alpha, e.g. for glycine)
    is used instead.

    Args:
        pos: numpy array indexed as ``pos[residue, atom]`` with a trailing
            axis of 3 coordinates.
        n_res: number of residues to label (the first ``n_res`` rows of ``pos``).
        radius: distance cutoff; reference atoms strictly closer than this are
            counted as neighbors. The count includes the residue itself.
        core_min: residues with at least this many neighbors are "core".
        surface_max: residues with at most this many neighbors are "surface".

    Returns:
        A list of ``n_res`` strings.
    """
    # Work on a copy so the caller's coordinates are never modified.
    ref_atoms = np.array(pos[:, 4], copy=True)
    # Only substitute when the *whole* coordinate row is zero; an element-wise
    # mask would also overwrite a real atom that merely has one coordinate
    # exactly equal to 0.
    missing = (ref_atoms == 0).all(axis=1)
    ref_atoms[missing] = pos[missing, 3]

    labels = []
    for idx in range(n_res):
        dist = np.sqrt(np.square(ref_atoms[idx] - ref_atoms).sum(axis=1))
        neighbors = (dist < radius).sum()

        if neighbors >= core_min:
            labels.append("core")
        elif neighbors <= surface_max:
            labels.append("surface")
        else:
            labels.append("unlabeled")
    return labels


def make_tsne(model, FLAGS, node_embed):
    """Dump per-residue hidden representations and energies for one pdb.

    For every residue, the ``FLAGS.max_size`` atoms closest to ``pos[idx, 4]``
    are selected, re-centered on their mean position, randomly rotated and fed
    through ``model``. The hidden state and energy of the first rotation are
    kept. Each residue is also labelled "core" / "surface" / "unlabeled" from
    its neighbor count. The result is pickled to
    ``FLAGS.outdir/<pdb_name>_representations.p``.

    Args:
        model: callable network; ``model.forward(x, return_hidden=True)`` must
            return ``(energies, hidden)``.
        FLAGS: options object providing ``pdb_name``, ``max_size`` and ``outdir``.
        node_embed: ignored -- it is overwritten by the array loaded from the
            pickle file. Kept only so existing callers keep working.

    Returns:
        None. The output is written to disk.
    """
    pdb_name = FLAGS.pdb_name
    pickle_file = MMCIF_PATH + f"/mmCIF/{pdb_name[1:3]}/{pdb_name}.p"
    # NOTE(security): pickle.load can execute arbitrary code; only point this
    # at trusted, locally generated files.
    with open(pickle_file, "rb") as in_file:
        (node_embed,) = pickle.load(in_file)
    _, _, pos, _, res, _ = parse_dense_format(node_embed)

    all_hiddens = []
    all_energies = []

    # Random rotations applied to every atom window (shared across residues).
    n_rotations = 2
    so3 = special_ortho_group(3)
    rot_matrix = so3.rvs(n_rotations)
    rot_matrix = torch.from_numpy(rot_matrix).float().cuda()

    for idx in range(len(res)):
        # Sort the atoms by how far away they are;
        # the sort key is the first atom on the sidechain.
        pos_chosen = pos[idx, 4]
        # The last three columns of node_embed are used as the coordinates.
        close_idx = np.argsort(np.square(node_embed[:, -3:] - pos_chosen).sum(axis=1))

        # Grab the FLAGS.max_size closest atoms.
        node_embed_short = node_embed[close_idx[: FLAGS.max_size]].copy()

        # Re-center the window so its mean x, y, z coordinate is 0.
        node_embed_short[:, -3:] = node_embed_short[:, -3:] - np.mean(
            node_embed_short[:, -3:], axis=0
        )
        node_embed_short = torch.from_numpy(node_embed_short).float().cuda()

        # One copy of the window per rotation, then rotate the coordinates.
        node_embed_short = node_embed_short[None, :, :].repeat(n_rotations, 1, 1)
        node_embed_short[:, :, -3:] = torch.matmul(node_embed_short[:, :, -3:], rot_matrix)

        # Compute the energies for the n_rotations copies of this atom window.
        energies, hidden = model.forward(node_embed_short, return_hidden=True)

        # Keep only the first rotation. Detach so the autograd graph is not
        # retained for every residue and so .numpy() below cannot fail on
        # tensors that require grad.
        all_hiddens.append(hidden[0].detach())
        all_energies.append(energies[0].detach())

    surface_core_type = _label_surface_core(pos, len(res))

    output = {
        "res": res,
        "surface_core_type": surface_core_type,
        "all_hiddens": torch.stack(all_hiddens).cpu().numpy(),
        "all_energies": torch.stack(all_energies).cpu().numpy(),
    }
    # Dump the output
    output_path = osp.join(FLAGS.outdir, f"{pdb_name}_representations.p")
    os.makedirs(FLAGS.outdir, exist_ok=True)
    with open(output_path, "wb") as out_file:
        pickle.dump(output, out_file)