def __getitem__()

in data.py [0:0]


    def __getitem__(self, index, forward=False):
        """Return one example built from the pickled protein at ``self.files[index]``.

        Files that cannot produce a valid example are skipped and the next file
        (cyclically, ``(index + 1) % len(self.files)``) is tried instead.

        Args:
            index: Position in ``self.files`` to start from.
            forward: When False and ``FLAGS.single`` is set, ``index`` is reset to 0.
                Retries after a rejected file never reset the index.

        Returns:
            ``split == "train"``: ``(node_embeds, node_embeds_negatives)``, two lists
            holding one float tensor per selected residue (``FLAGS.multisample``).
            ``split in ("val", "test")``: ``(node_embed_short, node_embed_negative,
            gt_chis, neg_chis, residue_name)`` for the first selected residue.
            Any other split: two empty lists.

        Raises:
            RuntimeError: if no valid example is found within the attempt budget.
        """
        FLAGS = self.FLAGS

        if FLAGS.single and not forward:
            index = 0

        # Iterate instead of recursing so a long run of rejected files cannot hit the
        # interpreter's recursion limit. Residue and chi sampling are random, so a
        # rejected file may succeed on a later visit; allow wrapping around small
        # datasets rather than giving up after a single pass.
        max_attempts = max(len(self.files), 1000)
        for _ in range(max_attempts):
            example = self._try_build_example(index)
            if example is not None:
                return example
            index = (index + 1) % len(self.files)

        raise RuntimeError(
            f"No valid example found after {max_attempts} attempts "
            f"over {len(self.files)} files"
        )

    def _try_build_example(self, index):
        """Build the example for ``self.files[index]``, or return None to skip the file."""
        FLAGS = self.FLAGS

        # node_embed: D x 6 (the last three columns are used as x, y, z coordinates)
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
        with open(self.files[index], "rb") as f:
            (node_embed,) = pickle.load(f)

        # Remove proteins with small numbers of atoms
        if node_embed.shape[0] < 20:
            return None

        # Remove invalid proteins whose first three columns exceed the accepted maxima
        col_max = node_embed.max(axis=0)
        if col_max[2] >= 21 or col_max[0] >= 20 or col_max[1] >= 5:
            return None

        par, child, pos, pos_exist, res, chis_valid = parse_dense_format(node_embed)
        if par is None or len(res) < 5:
            return None

        angles = compute_dihedral(par, child, pos, pos_exist)

        select_idxs = self._pick_residue_indices(res)
        if select_idxs is None:
            return None

        # Choose number of negative samples (fixed at 150 for val/test in train mode)
        if FLAGS.train and self.split in ["val", "test"]:
            neg_sample = 150
        else:
            neg_sample = FLAGS.neg_sample

        node_embeds = []
        node_embeds_negatives = []

        for idx in select_idxs:
            residue_example = self._build_residue_example(
                node_embed, par, child, pos, pos_exist, angles, chis_valid, res, idx, neg_sample
            )
            if residue_example is None:
                return None
            node_embed_short, node_embed_negative, gt_chis, neg_chis = residue_example

            if self.split == "train":
                node_embeds.append(node_embed_short)
                node_embeds_negatives.append(node_embed_negative)
            elif self.split in ["val", "test"]:
                return node_embed_short, node_embed_negative, gt_chis, neg_chis, res[idx]

        return node_embeds, node_embeds_negatives

    def _pick_residue_indices(self, res):
        """Randomly pick ``FLAGS.multisample`` residues whose name is not gly or ala.

        The first and last residues are never chosen. Candidates are visited in a
        random order and at most 1001 of them are inspected; returns None when not
        enough qualifying residues are found.
        """
        max_candidates = 1001
        perm = np.random.permutation(np.arange(1, len(res) - 1))
        select_idxs = []

        for idx in perm[:max_candidates]:
            if res[idx] == "gly" or res[idx] == "ala":
                continue
            select_idxs.append(idx)
            if len(select_idxs) == self.FLAGS.multisample:
                return select_idxs

        return None

    def _sample_chis_candidates(self, angles, pos, res, idx, neg_sample):
        """Return candidate chi angles for residue ``idx``, or None if there are none.

        Uses ``interpolated_sample_normal`` when ``self.weighted_gauss`` is set,
        ``mixture_sample_normal`` when ``self.gmm`` is set, and otherwise
        ``exhaustive_sample``, whose result is tiled to at least ``neg_sample``
        entries and shuffled in place.
        """
        if self.weighted_gauss:
            return interpolated_sample_normal(
                self.db,
                angles[idx, 1],
                angles[idx, 2],
                res[idx],
                neg_sample,
                uniform=self.uniform,
            )
        if self.gmm:
            return mixture_sample_normal(
                self.db,
                angles[idx, 1],
                angles[idx, 2],
                res[idx],
                neg_sample,
                uniform=self.uniform,
            )

        if self.split == "test":
            # Choose different thresholds of sampling depending on how many residues
            # lie within distance 10 (measured on atom slot 2) of the chosen residue.
            dist = np.sqrt(np.square(pos[idx : idx + 1, 2] - pos[:, 2]).sum(axis=1))
            neighbors = (dist < 10).sum()
            tresh = 0.95 if neighbors < 24 else 0.98
        else:
            tresh = 1.0

        chis_list = exhaustive_sample(
            self.db,
            angles[idx, 1],
            angles[idx, 2],
            res[idx],
            tresh=tresh,
            chi_mean=self.chi_mean,
        )

        if len(chis_list) < neg_sample:
            if len(chis_list) == 0:
                # Tiling an empty list would divide by zero; skip this file instead.
                return None
            repeat = neg_sample // len(chis_list) + 1
            chis_list = chis_list * repeat

        random.shuffle(chis_list)
        return chis_list

    def _build_residue_example(
        self, node_embed, par, child, pos, pos_exist, angles, chis_valid, res, idx, neg_sample
    ):
        """Build the positive crop and ``neg_sample`` chi-perturbed negatives for ``idx``.

        Returns ``(node_embed_short, node_embed_negative, gt_chis, neg_chis)`` with both
        embeddings as float tensors, or None when the file should be skipped (no
        candidate chis, or an atom of residue ``idx`` falls outside the
        ``FLAGS.max_size`` crop of some negative).
        """
        FLAGS = self.FLAGS

        chis_list = self._sample_chis_candidates(angles, pos, res, idx, neg_sample)
        if chis_list is None:
            return None

        gt_chis = [(angles[idx, 4:8], chis_valid[idx, :4])]
        valid = chis_valid[idx, :4]

        neg_samples = []
        neg_chis = []
        for i in range(neg_sample):
            chis_target = angles[:, 4:8].copy()
            # Replace only the chi angles flagged valid for this residue.
            chis_target[idx] = chis_list[i] * valid + (1 - valid) * chis_target[idx]
            pos_new = rotate_dihedral_fast(
                angles, par, child, pos, pos_exist, chis_target, chis_valid, idx
            )
            neg_samples.append(reencode_dense_format(node_embed, pos_new, pos_exist))
            neg_chis.append((chis_target[idx], chis_valid[idx, :4]))

        node_embed_negative = np.array(neg_samples)

        # Flat indices of residue ``idx``'s existing atoms, starting at the number of
        # existing atoms in residues before ``idx`` and padded to 20 slots with that
        # first index. They are identical for every negative, so build once and tile.
        nelem = pos_exist[:idx].sum()
        offset = pos_exist[idx].sum()
        atom_idx = np.concatenate(
            [np.arange(nelem, nelem + offset), np.ones(20 - offset) * nelem]
        )
        atom_idxs = np.tile(atom_idx, (neg_sample, 1))

        # Choose the FLAGS.max_size atoms closest to atom slot 4 of the chosen residue
        pos_chosen = pos[idx, 4]
        close_idx = np.argsort(np.square(node_embed[:, -3:] - pos_chosen).sum(axis=1))
        node_embed_short = node_embed[close_idx[: FLAGS.max_size]].copy()

        # NOTE(review): every negative is cropped around atom slot 4 of the *last*
        # rotated sample; this equals each negative's own center only if chi
        # rotations leave slot 4 fixed — confirm.
        pos_chosen = pos_new[idx, 4]
        close_idx_neg = np.argsort(
            np.square(node_embed_negative[:, :, -3:] - pos_chosen).sum(axis=2), axis=1
        )

        # Rank of each atom in its sample's distance ordering; skip the file if any
        # atom of residue ``idx`` would fall outside the crop.
        pos_code = np.argsort(close_idx_neg, axis=1)
        choose_idx = np.take_along_axis(pos_code, atom_idxs.astype(np.int32), axis=1)
        if choose_idx.max() >= FLAGS.max_size:
            return None

        node_embed_negative = np.take_along_axis(
            node_embed_negative, close_idx_neg[:, : FLAGS.max_size, None], axis=1
        )

        # Center the x, y, z coordinates of each embedding at 0
        node_embed_short[:, -3:] = node_embed_short[:, -3:] - np.mean(
            node_embed_short[:, -3:], axis=0
        )
        node_embed_negative[:, :, -3:] = node_embed_negative[:, :, -3:] - np.mean(
            node_embed_negative[:, :, -3:], axis=1, keepdims=True
        )

        if FLAGS.augment:
            # Augment with random rotations drawn from self.so3
            rot_matrix = self.so3.rvs(1)
            node_embed_short[:, -3:] = np.matmul(node_embed_short[:, -3:], rot_matrix)

            rot_matrix_neg = self.so3.rvs(node_embed_negative.shape[0])
            node_embed_negative[:, :, -3:] = np.matmul(
                node_embed_negative[:, :, -3:], rot_matrix_neg
            )

        # Scale coordinates down by a factor of 10
        node_embed_short[:, -3:] = node_embed_short[:, -3:] / 10.0
        node_embed_negative[:, :, -3:] = node_embed_negative[:, :, -3:] / 10.0

        node_embed_short = torch.from_numpy(node_embed_short).float()
        node_embed_negative = torch.from_numpy(node_embed_negative).float()

        return node_embed_short, node_embed_negative, gt_chis, neg_chis