in compert/api.py [0:0]
def get_response(
    self,
    datasets,
    doses=None,
    contvar_min=None,
    contvar_max=None,
    n_points=50,
    ncells_max=100,
    perturbations=None,
    control_name='test_control'
):
    """Decoded dose response data frame.

    For every covariate value in ``self.unique_сovars`` and every
    non-control perturbation, predict gene expression at each dose starting
    from that covariate's control cells, and record the difference between
    the mean prediction and the mean of those control cells.

    Note: puts ``self.model`` into eval mode (``self.model.eval()``) and does
    not restore the previous mode.

    Parameters
    ----------
    datasets : dict-like
        Mapping from dataset name to dataset. ``datasets[control_name]`` must
        provide ``genes`` (a 2-D torch tensor; rows are sampled as cells and
        columns must line up with ``self.var_names``), ``cell_types_names``
        (compared with ``==`` against each covariate value to build a row
        mask) and ``ctrl_name`` (checked with ``in`` to skip control
        perturbations).
    doses : array-like (default: None)
        Dose values. If None, ``n_points`` values are generated on a grid
        ``np.linspace(contvar_min, contvar_max, n_points)``.
    contvar_min : float (default: None)
        Minimum dose for the default grid; None means ``self.min_dose``.
    contvar_max : float (default: None)
        Maximum dose for the default grid; None means ``self.max_dose``.
    n_points : int (default: 50)
        Number of dose points for the default grid.
    ncells_max : int (default: 100)
        Maximum number of control cells used per covariate. If more are
        available, this many are drawn at random without replacement.
    perturbations : iterable (default: None)
        Perturbations to evaluate; None means ``self.unique_perts``. Entries
        contained in ``datasets[control_name].ctrl_name`` are skipped.
    control_name : str (default: 'test_control')
        Key of the control dataset inside ``datasets``.

    Returns
    -------
    pd.DataFrame
        One row per (covariate, perturbation, dose), with columns
        ``[self.covars_key, self.perturbation_key, self.dose_key, 'response']
        + list(self.var_names)``. The gene columns hold the mean predicted
        expression minus the control mean; ``'response'`` is the Euclidean
        norm of that difference vector.
    """
    if contvar_min is None:
        contvar_min = self.min_dose
    if contvar_max is None:
        contvar_max = self.max_dose
    self.model.eval()
    if doses is None:
        doses = np.linspace(contvar_min, contvar_max, n_points)
    if perturbations is None:
        perturbations = self.unique_perts

    columns = [self.covars_key,
               self.perturbation_key,
               self.dose_key,
               'response'] + list(self.var_names)
    control_ds = datasets[control_name]
    # Materialise once: a generator would otherwise be exhausted after the
    # first covariate. Control perturbations are excluded from the response.
    perts_to_score = [p for p in perturbations
                      if p not in control_ds.ctrl_name]

    # Collect plain rows and build the frame once at the end; growing a
    # DataFrame row by row with .loc copies it on every insertion.
    rows = []
    # NOTE(review): 'unique_сovars' appears to contain a Cyrillic 'с'
    # (U+0441). It is kept byte-identical so it keeps matching the class
    # attribute; confirm and normalise both spellings together.
    for ct in self.unique_сovars:
        genes_control = control_ds.genes[
            control_ds.cell_types_names == ct].clone().detach()
        if len(genes_control) == 0:
            print('Warning! Not enough control cells for this covariate. '
                  'Taking control cells from all covariates.')
            genes_control = control_ds.genes
        if ncells_max < len(genes_control):
            idx = torch.LongTensor(np.random.choice(
                len(genes_control), ncells_max, replace=False))
            genes_control = genes_control[idx]
        control_avg = (genes_control.mean(dim=0)
                       .cpu().clone().detach().numpy().reshape(-1))
        # Same input for every (perturbation, dose) pair of this covariate.
        genes_np = genes_control.cpu().detach().numpy()
        for drug in perts_to_score:
            for dose in doses:
                df = pd.DataFrame(data={self.covars_key: [ct],
                                        self.perturbation_key: [drug],
                                        self.dose_key: [dose]})
                gene_means, _, _ = self.predict(
                    genes_np, df, return_anndata=False)
                delta = np.mean(gene_means, axis=0).reshape(-1) - control_avg
                rows.append([ct, drug, dose, np.linalg.norm(delta)]
                            + list(delta))
    return pd.DataFrame(rows, columns=columns)