def rotamer_trials()

in vis_sandbox.py [0:0]


# Sampling strategies understood by rotamer_trials (selected via FLAGS.sample_mode).
_ROTAMER_SAMPLE_MODES = ("weighted_gauss", "gmm", "rosetta")


def _mean_and_sem(values):
    """Return ``(mean, std / sqrt(n))`` of ``values``, or ``(nan, nan)`` when empty."""
    if len(values) == 0:
        nan = float("nan")
        return nan, nan
    return np.mean(values), np.std(values) / len(values) ** 0.5


def _count_neighbors(pos, idx, radius=10):
    """Count residues whose atom-4 position lies within ``radius`` of residue ``idx``.

    Coordinates of atom 4 that are exactly zero are replaced, element-wise, by
    the matching coordinate of atom 3 before measuring distances.
    NOTE(review): atoms 4 and 3 are presumably CB and CA, with zeros marking a
    missing CB -- confirm against the dense format definition.
    """
    anchor = pos[:, 4].copy()
    missing = anchor == 0
    anchor[missing] = pos[:, 3][missing]
    dists = np.linalg.norm(pos[idx : idx + 1, 4] - anchor, axis=1)
    return int((dists < radius).sum())


def _build_candidate_embeds(
    node_embed, angles, par, child, pos, pos_exist, chis_valid, idx, chis_list, neg_sample
):
    """Build one re-encoded structure per candidate chi set for residue ``idx``.

    Returns ``(embeds, chis)``. ``chis`` is a copy of ``chis_list`` that is
    cyclically padded up to ``neg_sample`` entries when fewer were sampled, so
    that ``chis[j]`` is the chi set used for ``embeds[j]``. The caller's
    ``chis_list`` is never mutated.

    Raises:
        ValueError: if ``chis_list`` is empty.
    """
    chis = list(chis_list)
    n_sampled = len(chis)
    if n_sampled == 0:
        raise ValueError("No candidate chi angles were sampled for residue index {}".format(idx))

    valid = chis_valid[idx, :4]
    embeds = []
    for i in range(neg_sample):
        if i >= n_sampled:
            # Fewer candidates than requested: reuse earlier ones cyclically so
            # every residue contributes the same number of rows to the batch.
            embeds.append(embeds[i % n_sampled])
            chis.append(chis[i % n_sampled])
            continue

        chis_target = angles[:, 4:8].copy()
        # Only overwrite the chi angles flagged valid for this residue.
        chis_target[idx] = chis[i] * valid + (1 - valid) * chis_target[idx]
        pos_new = rotate_dihedral_fast(
            angles, par, child, pos, pos_exist, chis_target, chis_valid, idx
        )
        embeds.append(reencode_dense_format(node_embed, pos_new, pos_exist))

    embeds = np.array(embeds)
    # Keep the 64 entries along axis 1 closest to the residue's atom-4 position;
    # the last three features of each entry are treated as xyz coordinates.
    dist = np.square(embeds[:, :, -3:] - pos[idx : idx + 1, 4:5, :]).sum(axis=2)
    close_idx = np.argsort(dist)
    embeds = np.take_along_axis(embeds, close_idx[:, :64, None], axis=1)
    # Rescale the coordinates and centre them on their mean.
    embeds[:, :, -3:] = embeds[:, :, -3:] / 10.0
    embeds[:, :, -3:] = embeds[:, :, -3:] - np.mean(embeds[:, :, -3:], axis=1, keepdims=True)
    return embeds, chis


def _score_batch(model, FLAGS, so3, pending):
    """Pick the lowest-energy candidate for each pending residue and score it.

    ``pending`` is a list of dicts with keys ``res_name``, ``gt_chis``,
    ``valid``, ``burial``, ``chis`` and ``embeds``. Returns a list of
    ``(res_name, burial, rotamer_score)`` tuples in the same order.
    """
    rotations = FLAGS.rotations
    n_entries = len(pending)
    batch = np.concatenate([entry["embeds"] for entry in pending])
    base_shape = batch.shape

    # Evaluate every candidate under `rotations` random rotations and average
    # the energies over them.
    batch = np.tile(batch[:, None, :, :], (1, rotations, 1, 1))
    rot_matrix = so3.rvs(rotations)
    if rotations == 1:
        # rvs(1) returns a single 3x3 matrix; restore the leading axis.
        rot_matrix = rot_matrix[None, :, :]
    batch[:, :, :, -3:] = np.matmul(batch[:, :, :, -3:], rot_matrix[None, :, :, :])
    batch = batch.reshape((-1, *base_shape[1:]))

    # The original code always evaluates on a CUDA device.
    feed = torch.from_numpy(batch).float().cuda()
    with torch.no_grad():
        if FLAGS.ensemble > 1:
            # In ensemble mode `model` is an iterable of models; sum their energies.
            energy = 0
            for member in model:
                energy = energy + member.forward(feed)
        else:
            energy = model.forward(feed)

    energy = energy.view(n_entries, -1, rotations).mean(dim=2)
    select_idx = torch.argmin(energy, dim=1).cpu().numpy()

    results = []
    for entry, choice in zip(pending, select_idx):
        rotamer_score, _ = compute_rotamer_score_planar(
            entry["gt_chis"], entry["chis"][choice], entry["valid"], entry["res_name"]
        )
        results.append((entry["res_name"], entry["burial"], rotamer_score))
    return results


def rotamer_trials(model, FLAGS, test_dataset):
    """Evaluate rotamer recovery of an energy model over a set of test structures.

    For every residue other than gly/ala (excluding the first and last residue
    of each structure), candidate chi angles are sampled from the rotamer
    library, each candidate is re-encoded into a structure, and the candidate
    with the lowest model energy is compared to the original chi angles via
    ``compute_rotamer_score_planar``. Running means and standard errors are
    printed after every file: overall, buried, surface and per amino acid.

    Args:
        model: A model exposing ``forward``; an iterable of such models when
            ``FLAGS.ensemble > 1`` (their energies are summed).
        FLAGS: Configuration object; reads ``sample_mode`` (one of
            ``weighted_gauss``, ``gmm``, ``rosetta``), ``ensemble``,
            ``rotations`` and ``neg_sample``.
        test_dataset: Object whose ``files`` attribute lists pickle files, each
            holding a 1-tuple with the dense node embedding.

    Returns:
        None. Results are reported with ``print``.

    Raises:
        ValueError: if ``FLAGS.sample_mode`` is not a known sampling mode, or
            if no candidate chi angles could be sampled for a residue.
    """
    # Fail fast instead of hitting an undefined/stale `chis_list` mid-run.
    if FLAGS.sample_mode not in _ROTAMER_SAMPLE_MODES:
        raise ValueError(
            "Unknown FLAGS.sample_mode {!r}; expected one of {}".format(
                FLAGS.sample_mode, _ROTAMER_SAMPLE_MODES
            )
        )

    test_files = test_dataset.files
    # NOTE(review): this shuffles test_dataset.files in place, as the original did.
    random.shuffle(test_files)
    db = load_rotamor_library()
    so3 = special_ortho_group(3)

    nminibatch = 4
    neg_sample = FLAGS.neg_sample

    rotamer_scores_total = []
    surface_scores_total = []
    buried_scores_total = []
    amino_recovery_total = {k.lower(): [] for k, _ in kvs.items()}

    for test_file in tqdm(test_files):
        # NOTE(security): pickle executes arbitrary code on load; only use
        # trusted dataset files here.
        with open(test_file, "rb") as handle:
            (node_embed,) = pickle.load(handle)

        # Skip empty entries before anything tries to parse them.
        if node_embed is None:
            continue

        par, child, pos, pos_exist, res, chis_valid = parse_dense_format(node_embed)
        angles = compute_dihedral(par, child, pos, pos_exist)

        pending = []  # residues waiting to be evaluated as one mini-batch
        scored = []  # (res_name, burial, rotamer_score) for this file

        n_amino = pos.shape[0]
        for idx in range(1, n_amino - 1):
            res_name = res[idx]
            if res_name in ("gly", "ala"):
                continue

            neighbors = _count_neighbors(pos, idx)
            if neighbors >= 24:
                burial = "buried"
            elif neighbors < 16:
                burial = "surface"
            else:
                burial = "neutral"

            # Threshold passed to the exhaustive sampler; stricter for buried residues.
            tresh = 0.98 if neighbors >= 24 else 0.95

            if FLAGS.sample_mode == "weighted_gauss":
                chis_list = interpolated_sample_normal(
                    db, angles[idx, 1], angles[idx, 2], res_name, neg_sample, uniform=False
                )
            elif FLAGS.sample_mode == "gmm":
                chis_list = mixture_sample_normal(
                    db, angles[idx, 1], angles[idx, 2], res_name, neg_sample, uniform=False
                )
            else:  # "rosetta"
                chis_list = exhaustive_sample(
                    db, angles[idx, 1], angles[idx, 2], res_name, tresh=tresh
                )

            embeds, chis = _build_candidate_embeds(
                node_embed, angles, par, child, pos, pos_exist, chis_valid, idx, chis_list,
                neg_sample,
            )
            pending.append(
                {
                    "res_name": res_name,
                    "gt_chis": angles[idx, 4:8],
                    "valid": chis_valid[idx, :4],
                    "burial": burial,
                    "chis": chis,
                    "embeds": embeds,
                }
            )

            if len(pending) == nminibatch:
                scored.extend(_score_batch(model, FLAGS, so3, pending))
                pending = []

        # Flush the last partial mini-batch. Doing this after the loop ensures
        # it is not lost when the final residue is gly/ala and was skipped.
        if pending:
            scored.extend(_score_batch(model, FLAGS, so3, pending))

        amino_recovery = {k.lower(): [] for k, _ in kvs.items()}
        rotamer_scores = []
        surface_scores = []
        buried_scores = []
        for res_name, burial, rotamer_score in scored:
            rotamer_scores.append(rotamer_score)
            amino_recovery[str(res_name)].append(rotamer_score)
            if burial == "buried":
                buried_scores.append(rotamer_score)
            elif burial == "surface":
                surface_scores.append(rotamer_score)

        # Only record per-file means that exist; an empty mean is NaN and would
        # poison every later running average.
        if rotamer_scores:
            rotamer_scores_total.append(np.mean(rotamer_scores))
        if buried_scores:
            buried_scores_total.append(np.mean(buried_scores))
        if surface_scores:
            surface_scores_total.append(np.mean(surface_scores))

        for k, v in amino_recovery.items():
            if v:
                amino_recovery_total[k].append(np.mean(v))

        print(
            "Obtained a rotamer recovery score of ",
            *_mean_and_sem(rotamer_scores_total),
        )
        print(
            "Obtained a buried recovery score of ",
            *_mean_and_sem(buried_scores_total),
        )
        print(
            "Obtained a surface recovery score of ",
            *_mean_and_sem(surface_scores_total),
        )
        for k, v in amino_recovery_total.items():
            print(
                "per amino acid recovery of {} score of ".format(k),
                *_mean_and_sem(v),
            )