def new_model()

in vis_sandbox.py [0:0]


def _classify_surface_core(pos, num_res, radius=10.0, core_min=24, surface_max=16):
    """Label residues 0..num_res-1 as "core", "surface" or "unlabeled" by neighbor count.

    Distances are measured between atom slot 2 of each residue in ``pos``
    (presumably C-alpha -- TODO confirm against parse_dense_format). The neighbor
    count includes the residue itself. Residues with at least ``core_min``
    neighbors within ``radius`` are "core"; those with at most ``surface_max``
    are "surface"; everything in between is "unlabeled".
    """
    anchors = pos[:, 2]
    labels = []
    for idx in range(num_res):
        dist = np.sqrt(np.square(anchors[idx : idx + 1] - anchors).sum(axis=1))
        neighbors = (dist < radius).sum()
        if neighbors >= core_min:
            labels.append("core")
        elif neighbors <= surface_max:
            labels.append("surface")
        else:
            labels.append("unlabeled")
    return labels


def _rotation_averaged_energies(model, crops, n_rotations):
    """Return the mean ``model`` energy of each crop over ``n_rotations`` random rotations.

    Args:
        model: energy model, invoked as ``model.forward`` on a batch of crops.
        crops: non-empty list of CUDA tensors, each (1, max_size, features), whose
            last three features are xyz coordinates.
        n_rotations: number of random SO(3) rotations (scipy ``special_ortho_group``)
            applied to every crop.

    Returns:
        NumPy array of shape (len(crops),) with the rotation-averaged energy per crop.
    """
    stacked = torch.stack(crops)  # (n_crops, 1, max_size, features)
    n_crops = stacked.shape[0]
    rot_matrix = torch.from_numpy(special_ortho_group(3).rvs(n_rotations)).float().cuda()

    # One copy of every crop per rotation; matmul broadcasts rot_matrix over n_crops.
    rotated = stacked.repeat(1, n_rotations, 1, 1)  # (n_crops, n_rotations, max_size, features)
    rotated[:, :, :, -3:] = torch.matmul(rotated[:, :, :, -3:], rot_matrix)

    # Fold crops and rotations into one batch dimension for a single forward pass.
    flat = rotated.reshape(n_crops * n_rotations, *rotated.shape[2:])

    # Inference only: no_grad avoids building an autograd graph (and .numpy() would
    # raise on a tensor that requires grad).
    with torch.no_grad():
        energies = model.forward(flat)

    # Rows are crop-major, so (n_crops, -1) regroups them per crop. Using the real crop
    # count (not a fixed batch size) keeps a short final batch correct.
    return energies.reshape(n_crops, -1).mean(1).cpu().numpy()


def new_model(model, FLAGS, node_embed):
    """Scan every valid side-chain chi angle of one PDB entry and score it with ``model``.

    For each residue and each valid chi angle (per ``chis_valid``), the angle is
    offset by every delta in ``range(-180, 180, 3)``. Each offset structure is then
    processed in four steps:

    1. Rebuild it with ``rotate_dihedral_fast``.
    2. Crop the ``FLAGS.max_size`` atoms nearest atom slot 4 of that residue.
    3. Center the crop's coordinates at the origin.
    4. Score the crop under 100 random rotations; the mean energy is stored.

    Args:
        model: energy model, called as ``model.forward`` on CUDA tensors.
        FLAGS: config object; reads ``pdb_name``, ``max_size`` and ``outdir``.
        node_embed: ignored -- it is overwritten by the pickle loaded for
            ``FLAGS.pdb_name``. Kept for interface compatibility.

    Returns:
        None. Writes ``{FLAGS.outdir}/{pdb_name}_rot_energies.p`` (creating
        ``FLAGS.outdir`` if needed). The pickle holds a dict with keys
        ``all_energies`` (residues x 4 x 360, indexed by delta % 360; unsampled or
        invalid entries are 0), ``chis_valid``, ``chis_target_initial``,
        ``avg_chi_angle_energy`` (4 x 360), ``res`` and ``surface_core_type``.
    """
    BATCH_SIZE = 120
    N_ROTATIONS = 100
    pdb_name = FLAGS.pdb_name  # e.g. '6mdw.0'
    pickle_file = f"/private/home/yilundu/dataset/mmcif/mmCIF/{pdb_name[1:3]}/{pdb_name}.p"
    # NOTE(review): the ``node_embed`` argument is discarded; the structure is always
    # reloaded from disk. pickle.load must only be used on trusted files.
    with open(pickle_file, "rb") as f:
        (node_embed,) = pickle.load(f)
    par, child, pos, pos_exist, res, chis_valid = parse_dense_format(node_embed)
    angles = compute_dihedral(par, child, pos, pos_exist)

    # Dihedral columns: backbone in [:4], side-chain chi angles in [4:8].
    chis_target_initial = angles[:, 4:8].copy()

    NUM_RES = len(res)
    # Indexed [residue, chi (4 possible), angle delta % 360]. Zeros rather than
    # np.empty: only valid chis and multiples of the 3-degree step are written, and
    # uninitialized memory may hold NaN/inf, which would survive the masked average
    # below (NaN * 0 == NaN).
    all_energies = np.zeros((NUM_RES, 4, 360))

    surface_core_type = _classify_surface_core(pos, NUM_RES)

    for idx in tqdm(range(NUM_RES)):
        for chi_num in range(4):
            if not chis_valid[idx, chi_num]:
                continue

            for angle_deltas in batch(range(-180, 180, 3), BATCH_SIZE):
                # Materialize once: the deltas are iterated below and reused as indices.
                angle_deltas = list(angle_deltas)
                crops = []
                for angle_delta in angle_deltas:
                    chis_target = chis_target_initial.copy()  # fresh copy per delta
                    # Offset only this residue's chi; every other chi keeps its initial value.
                    chis_target[idx, chi_num] += angle_delta

                    # pos_new is n residues x 20 atoms x 3 (xyz)
                    pos_new = rotate_dihedral_fast(
                        angles, par, child, pos, pos_exist, chis_target, chis_valid, idx
                    )
                    node_neg_embed = reencode_dense_format(node_embed, pos_new, pos_exist)

                    # Crop the FLAGS.max_size rows nearest atom slot 4 of the rotated
                    # residue (presumably the first side-chain atom -- TODO confirm);
                    # the last three embedding columns are xyz coordinates.
                    pos_chosen = pos_new[idx, 4]
                    close_idx = np.argsort(
                        np.square(node_neg_embed[:, -3:] - pos_chosen).sum(axis=1)
                    )
                    node_embed_short = node_neg_embed[close_idx[: FLAGS.max_size]].copy()

                    # Center the crop's coordinates at the origin.
                    node_embed_short[:, -3:] = node_embed_short[:, -3:] - np.mean(
                        node_embed_short[:, -3:], axis=0
                    )
                    crops.append(
                        torch.from_numpy(node_embed_short).float().cuda().unsqueeze(0)
                    )

                # Store delta d at column d % 360; this makes explicit the negative-index
                # wraparound that plain indexing with negative deltas performs.
                cols = np.asarray(angle_deltas) % 360
                all_energies[idx, chi_num, cols] = _rotation_averaged_energies(
                    model, crops, N_ROTATIONS
                )

    # Mean energy per (chi, delta) over the residues where that chi is valid. A chi that
    # is never valid yields NaN (0 / 0); the warning is silenced because that is expected.
    valid = chis_valid[:NUM_RES, :4]
    with np.errstate(invalid="ignore", divide="ignore"):
        avg_chi_angle_energy = (all_energies * valid[:, :, None]).sum(0) / np.expand_dims(
            valid.sum(0), 1
        )
    output = {
        "all_energies": all_energies,
        "chis_valid": chis_valid,
        "chis_target_initial": chis_target_initial,
        "avg_chi_angle_energy": avg_chi_angle_energy,  # (4, 360): one curve per chi
        "res": res,
        "surface_core_type": surface_core_type,
    }

    # Dump the output; `with` guarantees the pickle is flushed and the file closed.
    os.makedirs(FLAGS.outdir, exist_ok=True)
    output_path = osp.join(FLAGS.outdir, f"{pdb_name}_rot_energies.p")
    with open(output_path, "wb") as f:
        pickle.dump(output, f)