def predict()

in compert/api.py [0:0]


    def predict(
        self,
        genes,
        df,
        uncertainty=True,
        return_anndata=True,
        sample=False,
        n_samples=10
        ):
        """Predict values of control 'genes' for the conditions specified in df.

        Every row of ``df`` describes one condition. For each condition the
        model is run on all control cells, and the per-condition results are
        stacked in the row order of ``df``.

        Parameters
        ----------
        genes : np.array
            Control cells, a 2-D array with one row per cell. The number of
            columns is used to split the model output into means (first half
            of the columns) and variances (remaining columns).
        df : pd.DataFrame
            Values for perturbations and covariates to generate. Must contain
            the columns named by ``self.perturbation_key``, ``self.dose_key``
            and ``self.covars_key``. Rows are read by position, so the frame
            may carry any index.
        uncertainty: bool (default: True)
            Compute uncertainties for the generated cells. Adds the columns
            ``uncertainty_cosine``, ``uncertainty_euclidean``,
            ``closest_cond_cosine`` and ``closest_cond_euclidean`` to the
            returned observations.
        return_anndata : bool, optional (default: True)
            Return embedding wrapped into anndata object.
        sample : bool (default: False)
            If sample is True, returns samples from gaussian distribution with
            mean and variance estimated by the model. Otherwise, returns just
            means and variances estimated by the model.
        n_samples : int (default: 10)
            Number of samples to sample if sampling is True.

        Returns
        -------
        If return_anndata is True, returns anndata structure. Otherwise, returns
        np.arrays for gene_means, gene_vars and a data frame for the corresponding
        conditions df_obs.

        Raises
        ------
        ValueError
            If ``df`` contains no rows.

        Notes
        -----
        When ``sample`` is True, ``gene_means`` and ``df_obs`` hold
        ``n_samples`` rows per control cell and condition, whereas
        ``gene_vars`` still holds a single row per control cell and condition.
        For that reason the ``"variance"`` layer is only attached to the
        anndata object when ``sample`` is False.

        """
        if len(df) == 0:
            raise ValueError(
                "df must contain at least one condition to predict."
            )

        self.model.eval()
        num = genes.shape[0]
        dim = genes.shape[1]
        device = self.model.device
        genes = torch.Tensor(genes).to(device)
        if sample:
            print(
                'Careful! These are sampled values! Better use means and '
                'variances for downstream tasks!'
            )

        # Loop-invariant widths of the drug and covariate encodings.
        n_drugs = self.drug_ohe.shape[1]
        n_covars = self.covars_ohe.shape[1]
        # Number of observation rows generated per condition.
        rows_per_condition = num * n_samples if sample else num

        gene_means_list = []
        gene_vars_list = []
        df_list = []

        # Iterate rows by position: label-based lookup with a running counter
        # breaks as soon as df does not carry a default 0..n-1 index.
        for _, row in df.iterrows():
            comb_name = row[self.perturbation_key]
            dose_name = row[self.dose_key]
            covar_name = row[self.covars_key]

            covar_ohe = torch.Tensor(
                self.covars_dict[covar_name]
            ).to(device)

            drug_ohe = torch.Tensor(
                self.get_drug_encoding_(
                    comb_name,
                    doses=dose_name
                )
            ).to(device)

            # Broadcast the single condition encoding over all control cells.
            drugs = drug_ohe.expand([num, n_drugs])
            covars = covar_ohe.expand([num, n_covars])

            # The output is detached right away, so no autograd graph is needed.
            with torch.no_grad():
                gene_reconstructions = self.model.predict(
                    genes,
                    drugs,
                    covars
                )
            gene_reconstructions = (
                gene_reconstructions.cpu().clone().detach().numpy()
            )
            cond_means = gene_reconstructions[:, :dim]
            cond_vars = gene_reconstructions[:, dim:]

            cond_obs = pd.DataFrame(
                [row.values] * rows_per_condition,
                columns=df.columns
            )

            if sample:
                # Normal is parameterised by the standard deviation (scale),
                # so the estimated variance has to be square-rooted.
                dist = torch.distributions.normal.Normal(
                    torch.Tensor(cond_means),
                    torch.Tensor(cond_vars).sqrt(),
                )
                # (n_samples, num, dim) -> (n_samples * num, dim)
                gene_means_list.append(
                    dist
                    .sample(torch.Size([n_samples]))
                    .cpu()
                    .detach()
                    .numpy()
                    .reshape(-1, dim)
                )
            else:
                gene_means_list.append(cond_means)

            if uncertainty:
                cos_dist, eucl_dist, closest_cond_cos, closest_cond_eucl =\
                    self.compute_uncertainty(
                    cov=covar_name,
                    pert=comb_name,
                    dose=dose_name
                )
                cond_obs = cond_obs.assign(
                    uncertainty_cosine=cos_dist,
                    uncertainty_euclidean=eucl_dist,
                    closest_cond_cosine=closest_cond_cos,
                    closest_cond_euclidean=closest_cond_eucl
                )

            df_list.append(cond_obs)
            gene_vars_list.append(cond_vars)

        gene_means = np.concatenate(gene_means_list)
        gene_vars = np.concatenate(gene_vars_list)
        df_obs = pd.concat(df_list)
        # Drop the per-condition pieces to keep peak memory down.
        del df_list, gene_means_list, gene_vars_list

        if return_anndata:
            adata = sc.AnnData(gene_means)
            adata.var_names = self.var_names
            adata.obs = df_obs
            if not sample:
                # Row counts only match the means when nothing was sampled.
                adata.layers["variance"] = gene_vars

            adata.obs.index = adata.obs.index.astype(str)  # type fix
            return adata
        else:
            return gene_means, gene_vars, df_obs