def compute_comb_emb()

in compert/api.py [0:0]


    def compute_comb_emb(self, thrh=30):
        """
        Generates an AnnData object containing all the latent vectors of the
        cov+dose*pert combinations seen during training.
        Called in api.compute_uncertainty(), stores the AnnData in self.comb_emb.

        Parameters
        ----------
        thrh : int, optional (default: 30)
            A training condition is kept only if its number of measured
            points in ``self.num_measured_points['training']`` is strictly
            greater than ``thrh``.

        Returns
        -------
        None
            The concatenated AnnData is stored in ``self.comb_emb``. Each
            retained condition contributes the row(s) of
            ``covariate embedding + perturbation embedding``, labelled in
            ``obs['cov_pert']`` with the condition key.

        Raises
        ------
        ValueError
            If ``parse_training_conditions()`` has not been run, if a drug
            combination has a different number of perturbations and doses,
            if a combination latent vector is not a single row, or if no
            training condition passes the ``thrh`` filter.
        """
        if self.seen_covars_perts['training'] is None:
            raise ValueError('Need to run parse_training_conditions() first!')

        emb_covars = self.get_covars_embeddings(return_anndata=True)

        # Generate adata with all cov+pert latent vect combinations
        tmp_ad_list = []
        for cov_pert in self.seen_covars_perts['training']:
            # Skip conditions with too few measured points.
            if self.num_measured_points['training'][cov_pert] <= thrh:
                continue

            # Condition keys are split into exactly "<cov>_<pert>_<dose>";
            # NOTE(review): this unpacking fails if a field contains '_'.
            cov_loop, pert_loop, dose_loop = cov_pert.split('_')
            cov_mask = emb_covars.obs.cell_type == cov_loop

            if '+' in pert_loop:
                # Drug combination: "<drugA>+<drugB>" with matching
                # "<doseA>+<doseB>"; the latent vector is the covariate
                # embedding plus the sum of each (drug, dose) embedding.
                pert_loop_list = pert_loop.split('+')
                dose_loop_list = dose_loop.split('+')
                if len(pert_loop_list) != len(dose_loop_list):
                    raise ValueError(
                        f'Mismatched number of perturbations and doses in {cov_pert!r}'
                    )

                # Drug embeddings depend on dose, so query them once per
                # distinct dose and tag each row as "<drug>_<dose>".
                emb_perts_loop = []
                for _dose in pd.Series(dose_loop_list).unique():
                    dose_ad = self.get_drug_embeddings(dose=float(_dose))
                    dose_ad.obs['pert_dose'] = dose_ad.obs.condition + '_' + _dose
                    emb_perts_loop.append(dose_ad)
                emb_perts_loop = emb_perts_loop[0].concatenate(emb_perts_loop[1:])

                wanted_pert_doses = [
                    pert + '_' + dose
                    for pert, dose in zip(pert_loop_list, dose_loop_list)
                ]
                comb_pert_vec = emb_perts_loop.X[
                    emb_perts_loop.obs.pert_dose.isin(wanted_pert_doses)
                ].sum(axis=0)
                X = emb_covars.X[cov_mask] + np.expand_dims(comb_pert_vec, axis=0)
                if X.shape[0] > 1:
                    raise ValueError('Error with comb computation')
            else:
                emb_perts = self.get_drug_embeddings(dose=float(dose_loop))
                X = (
                    emb_covars.X[cov_mask]
                    + emb_perts.X[emb_perts.obs.condition == pert_loop]
                )

            tmp_ad = sc.AnnData(X=X)
            tmp_ad.obs['cov_pert'] = '_'.join([cov_loop, pert_loop, dose_loop])
            # Append only conditions that passed the threshold. Appending
            # outside the filter used to duplicate the previous condition or
            # reference an unbound name.
            tmp_ad_list.append(tmp_ad)

        if not tmp_ad_list:
            raise ValueError(
                f'No training condition has more than {thrh} measured points; '
                'cannot build comb_emb (try a lower thrh).'
            )
        self.comb_emb = tmp_ad_list[0].concatenate(tmp_ad_list[1:])