def get_response2D()

in compert/api.py [0:0]


    def get_response2D(
        self,
        datasets,
        perturbations,
        covar,
        doses=None,
        contvar_min=None,
        contvar_max=None,
        n_points=10,
        ncells_max=100,
        fixed_drugs='',
        fixed_doses=''
        ):
        """Decode the 2D dose-response surface for a pair of perturbations.

        For every dose pair ``(d1, d2)`` in ``itertools.product(doses, doses)``,
        ``self.predict`` is run on a sample of control cells. The mean
        predicted expression is then compared with the mean control expression.

        Parameters
        ----------
        datasets : dict
            Mapping of split name to dataset. Only ``datasets['test_control']``
            is read. It must expose ``genes`` (a torch tensor indexed by row)
            and ``cell_types_names`` (compared element-wise with ``covar``).
        perturbations : sequence of str
            Exactly two perturbation names. They are joined with ``'+'`` to
            form the perturbation string passed to ``predict``.
        covar : str
            Covariate for which to compute the dose response. Control cells
            are taken from this covariate. If there are none, a warning is
            printed and control cells from all covariates are used instead.
        doses : array-like (default: None)
            Dose values used for each of the two perturbations. If None,
            ``n_points`` values are generated with ``np.linspace`` on
            ``[contvar_min, contvar_max]``.
        contvar_min : float (default: None, meaning ``self.min_dose``)
            Minimum dose value for the default grid.
        contvar_max : float (default: None, meaning ``self.max_dose``)
            Maximum dose value for the default grid.
        n_points : int (default: 10)
            Number of grid points per perturbation for the default grid.
        ncells_max : int (default: 100)
            Number of control cells to sample, capped at the number of
            available control cells. Sampling uses ``np.random.choice`` with
            its default ``replace=True``.
        fixed_drugs : str (default: '')
            Appended verbatim to the ``'p0+p1'`` perturbation string. It is
            concatenated directly, so a leading ``'+'`` is presumably
            required (confirm against ``predict``).
        fixed_doses : str (default: '')
            Appended verbatim to every ``'d1+d2'`` dose string, in the same
            way as ``fixed_drugs``.

        Returns
        -------
        pd.DataFrame
            One row per dose pair, with these columns:

            - ``perturbations[0]`` and ``perturbations[1]``: the two doses.
            - ``'response'``: Euclidean norm of (mean control expression
              minus mean predicted expression).
            - One column per entry of ``self.var_names``: mean predicted
              minus mean control expression for that variable.

            The frame has no rows when the control-name guard below
            triggers.
        """
        assert len(perturbations) == 2, "You should provide a list of 2 perturbations."
        # Accept any length-2 sequence (e.g. a tuple) for the column concatenation.
        perturbations = list(perturbations)

        if contvar_min is None:
            contvar_min = self.min_dose
        if contvar_max is None:
            contvar_max = self.max_dose

        self.model.eval()
        if doses is None:
            doses = np.linspace(contvar_min, contvar_max, n_points)

        control_set = datasets['test_control']
        genes_control = control_set.genes[
            control_set.cell_types_names == covar].clone().detach()
        if len(genes_control) < 1:
            print('Warning! Not enough control cells for this covariate. '
                  'Taking control cells from all covariates.')
            genes_control = control_set.genes

        # NOTE(review): np.random.choice samples *with* replacement by default,
        # so duplicate cells are possible even though ncells_max is capped at
        # the number of available cells. Pass replace=False if that was the intent.
        ncells_max = min(ncells_max, len(genes_control))
        idx = torch.LongTensor(np.random.choice(len(genes_control), ncells_max))
        genes_control = genes_control[idx]

        control_avg = genes_control.mean(dim=0).cpu().clone().detach().numpy().reshape(-1)
        # Loop-invariant: the same sampled control cells feed every prediction.
        genes_control_np = genes_control.cpu().detach().numpy()

        columns = perturbations + ['response'] + list(self.var_names)
        drug = perturbations[0] + '+' + perturbations[1]

        rows = []
        # NOTE(review): ``drug`` always contains '+', so it can never equal one
        # of these control names and this guard is always true. The behavior
        # is kept as-is; confirm whether the intent was to skip pairs where
        # either perturbation is a control.
        if drug not in ['Vehicle', 'EGF', 'unst', 'control', 'ctrl']:
            for dose_pair in itertools.product(doses, repeat=2):
                df = pd.DataFrame(data={
                    self.covars_key: [covar],
                    self.perturbation_key: [drug + fixed_drugs],
                    self.dose_key: [f"{dose_pair[0]}+{dose_pair[1]}" + fixed_doses],
                })

                gene_means, _, _ = self.predict(
                    genes_control_np, df, return_anndata=False)

                predicted_data = np.mean(gene_means, axis=0).reshape(-1)
                rows.append(
                    list(dose_pair)
                    + [np.linalg.norm(control_avg - predicted_data)]
                    + list(predicted_data - control_avg))

        # Build the frame once instead of growing it row by row with .loc.
        return pd.DataFrame(rows, columns=columns)