in compert/api.py [0:0]
def get_response2D(
    self,
    datasets,
    perturbations,
    covar,
    doses=None,
    contvar_min=None,
    contvar_max=None,
    n_points=10,
    ncells_max=100,
    fixed_drugs='',
    fixed_doses=''
):
    """Decoded dose-response data frame for a pair of perturbations.

    For every combination ``(d1, d2)`` of the dose grid, the model predicts
    expression for a sample of control cells of ``covar`` under the combined
    perturbation ``perturbations[0] + '+' + perturbations[1]`` at dose
    ``"d1+d2"``. The prediction is compared with the control average.

    Parameters
    ----------
    datasets : dict
        Mapping of dataset splits; only ``datasets['test_control']`` is used.
        It must expose ``genes`` (a torch tensor, one row per cell) and
        ``cell_types_names`` (one covariate label per cell).
    perturbations : list
        List (or tuple) of length 2 of perturbation names.
    covar : str
        Name of the covariate for which to compute the dose response.
    doses : array-like (default: None)
        Dose values used for both perturbations. If None, ``n_points``
        values are generated on a grid in ``[contvar_min, contvar_max]``.
    contvar_min : float (default: None)
        Minimum dose of the default grid; ``self.min_dose`` when None.
    contvar_max : float (default: None)
        Maximum dose of the default grid; ``self.max_dose`` when None.
    n_points : int (default: 10)
        Number of dose points per perturbation for the default grid.
    ncells_max : int (default: 100)
        Maximum number of distinct control cells sampled (without
        replacement).
    fixed_drugs : str (default: '')
        Suffix appended to the combined perturbation name, e.g. ``'+drugC'``.
    fixed_doses : str (default: '')
        Suffix appended to every dose string, e.g. ``'+0.1'``.

    Returns
    -------
    pd.DataFrame
        One row per dose combination (``itertools.product`` order). Columns:
        the two perturbation names (holding the doses); ``'response'``, the
        L2 norm of predicted mean minus control mean; and one column per
        entry of ``self.var_names``, holding predicted mean minus control
        mean.

    Raises
    ------
    ValueError
        If ``datasets['test_control']`` contains no cells at all.
    """
    assert len(perturbations) == 2, "You should provide a list of 2 perturbations."
    if contvar_min is None:
        contvar_min = self.min_dose
    if contvar_max is None:
        contvar_max = self.max_dose
    self.model.eval()
    if doses is None:
        doses = np.linspace(contvar_min, contvar_max, n_points)

    control = datasets['test_control']
    covar_mask = np.asarray(control.cell_types_names == covar)
    genes_control = control.genes[covar_mask].clone().detach()
    if len(genes_control) < 1:
        print('Warning! Not enough control cells for this covariate. '
              'Taking control cells from all covariates.')
        genes_control = control.genes
    if len(genes_control) < 1:
        raise ValueError("No control cells available in datasets['test_control'].")

    # ncells_max is capped at the number of available cells, so sample
    # without replacement: every selected control cell is distinct.
    ncells_max = min(ncells_max, len(genes_control))
    idx = np.random.choice(len(genes_control), ncells_max, replace=False)
    genes_control = genes_control[torch.as_tensor(idx, dtype=torch.long)]
    control_avg = genes_control.mean(dim=0).cpu().clone().detach().numpy().reshape(-1)
    # The model input is the same sampled cells for every dose: convert once.
    genes_input = genes_control.cpu().detach().numpy()

    columns = list(perturbations) + ['response'] + list(self.var_names)
    drug = perturbations[0] + '+' + perturbations[1]
    rows = []
    # NOTE(review): ``drug`` always contains '+', so it can never equal one of
    # these control names and this check never skips. It is kept for
    # backward compatibility -- confirm the intended handling of controls.
    if drug not in ['Vehicle', 'EGF', 'unst', 'control', 'ctrl']:
        for dose_pair in itertools.product(doses, doses):
            df = pd.DataFrame(data={
                self.covars_key: [covar],
                self.perturbation_key: [drug + fixed_drugs],
                self.dose_key: [f"{dose_pair[0]}+{dose_pair[1]}" + fixed_doses],
            })
            gene_means, _, _ = self.predict(genes_input, df, return_anndata=False)
            predicted_data = np.mean(gene_means, axis=0).reshape(-1)
            rows.append(
                list(dose_pair)
                + [np.linalg.norm(control_avg - predicted_data)]
                + list(predicted_data - control_avg)
            )
    # Build the frame once instead of enlarging it row by row with .loc.
    return pd.DataFrame(rows, columns=columns)