def get_response()

in compert/api.py [0:0]


    def get_response(
        self,
        datasets,
        doses=None,
        contvar_min=None,
        contvar_max=None,
        n_points=50,
        ncells_max=100,
        perturbations=None,
        control_name='test_control'
        ):
        """Decoded dose response data frame.

        For every covariate in ``self.unique_сovars``, every non-control
        perturbation and every dose, the control cells of that covariate are
        passed through ``self.predict`` and the mean prediction is compared
        with the mean of the control cells themselves.

        Parameters
        ----------
        datasets : mapping
            Indexed with ``control_name``; the selected entry must expose
            ``genes`` (indexable tensor of control expression),
            ``cell_types_names`` (compared element-wise with each covariate)
            and ``ctrl_name`` (container of control perturbation names).
        doses : iterable (default: None)
            Dose values. If None, values are generated on a grid of
            n_points in the range [contvar_min, contvar_max].
        contvar_min : float (default: None)
            Minimum dose of the default grid. If None, ``self.min_dose``.
        contvar_max : float (default: None)
            Maximum dose of the default grid. If None, ``self.max_dose``.
        n_points : int (default: 50)
            Number of dose points of the default grid.
        ncells_max : int (default: 100)
            Maximum number of control cells used per covariate. Larger
            control populations are randomly subsampled without replacement
            (``np.random.choice``).
        perturbations : list (default: None)
            Perturbations to evaluate. If None, ``self.unique_perts``.
            Entries found in the control dataset's ``ctrl_name`` are skipped.
        control_name : str (default: 'test_control')
            Key of the control dataset inside ``datasets``.

        Returns
        -------
        pd.DataFrame
            One row per (covariate, perturbation, dose). Columns are
            ``self.covars_key``, ``self.perturbation_key``, ``self.dose_key``,
            'response' (Euclidean norm of the per-gene difference between the
            mean prediction and the mean control) and one column per entry
            of ``self.var_names`` holding that per-gene difference.
        """

        if contvar_min is None:
            contvar_min = self.min_dose
        if contvar_max is None:
            contvar_max = self.max_dose

        self.model.eval()
        if doses is None:
            doses = np.linspace(contvar_min, contvar_max, n_points)

        if perturbations is None:
            perturbations = self.unique_perts

        columns = [self.covars_key,
                   self.perturbation_key,
                   self.dose_key,
                   'response'] + list(self.var_names)

        control = datasets[control_name]
        # Control perturbations have no response relative to themselves;
        # filter them out once instead of once per covariate.
        treated = [drug for drug in perturbations
                   if drug not in control.ctrl_name]

        # Rows are collected in a list and turned into a frame once at the
        # end: growing a DataFrame with ``.loc[i] = ...`` copies data on
        # every insertion and leaves the column dtypes version-dependent.
        rows = []
        # NOTE(review): the attribute name below appears to contain a
        # non-ASCII (Cyrillic) 'с'; it is kept byte-for-byte so that it keeps
        # matching its definition elsewhere in the class -- confirm.
        for ct in self.unique_сovars:
            genes_control = \
                control.genes[control.cell_types_names == ct].clone().detach()
            if len(genes_control) < 1:
                print('Warning! Not enough control cells for this covariate. '
                      'Taking control cells from all covariates.')
                genes_control = control.genes

            if ncells_max < len(genes_control):
                # Random subsample to bound the cost of each predict() call.
                idx = torch.LongTensor(np.random.choice(
                    len(genes_control), ncells_max, replace=False))
                genes_control = genes_control[idx]

            control_avg = \
                genes_control.mean(dim=0).cpu().detach().numpy().reshape(-1)
            # The numpy view of the control cells does not depend on the
            # perturbation or the dose, so it is converted only once.
            control_np = genes_control.cpu().detach().numpy()

            for drug in treated:
                for dose in doses:
                    df = pd.DataFrame(data={self.covars_key: [ct],
                                            self.perturbation_key: [drug],
                                            self.dose_key: [dose]})

                    gene_means, _, _ = self.predict(
                        control_np, df, return_anndata=False)
                    delta = \
                        np.mean(gene_means, axis=0).reshape(-1) - control_avg

                    rows.append(
                        [ct, drug, dose, float(np.linalg.norm(delta))]
                        + delta.tolist())

        return pd.DataFrame(rows, columns=columns)