in compert/api.py [0:0]
def get_response_reference(
    self,
    datasets,
    perturbations=None
):
    """Computes reference values of the response.

    For each entry of ``self.seen_covars_perts`` in the
    ``'training_treated'`` and ``'ood'`` splits, the mean expression of the
    cells whose ``pert_categories`` equals that entry is compared with the
    mean expression of the control cells of the same covariate.

    Parameters
    ----------
    datasets : dict
        Mapping indexed by split name; the keys ``'training_control'``,
        ``'training_treated'`` and ``'ood'`` are read. The control entry
        must expose ``genes`` and ``cell_types_names``; the other two must
        expose ``genes`` and ``pert_categories``.
        NOTE(review): the previous docstring named the element type
        ``CompPertDataset`` -- confirm against the caller.
    perturbations : list (default: None)
        List of perturbations for dose response. When None,
        ``self.unique_perts`` is used.

    Returns
    -------
    pd.DataFrame
        One row per entry that has at least one matching cell, with the
        columns ``self.covars_key``, ``self.perturbation_key``,
        ``self.dose_key``, ``'split'``, ``'num_cells'``, ``'response'``
        (Euclidean norm of the difference to the control mean) followed by
        one column per name in ``self.var_names`` holding the per-gene
        difference to the control mean.
    """
    if perturbations is None:
        perturbations = self.unique_perts

    columns = [
        self.covars_key,
        self.perturbation_key,
        self.dose_key,
        'split',
        'num_cells',
        'response',
    ] + list(self.var_names)
    reference_response_curve = pd.DataFrame(columns=columns)

    dataset_ctr = datasets['training_control']
    # The control mean depends only on the covariate, so compute it once per
    # covariate instead of once per perturbation entry.
    control_avg_by_ct = {}

    i = 0
    for split in ['training_treated', 'ood']:
        dataset = datasets[split]
        for pert in self.seen_covars_perts[split]:
            # Entries are split into exactly three '_'-separated fields:
            # covariate, drug and dose.
            ct, drug, dose_val = pert.split('_')
            if drug not in perturbations:
                continue

            # A dose containing '+' is kept as a string; anything else is
            # converted to a float.
            dose = dose_val if '+' in dose_val else float(dose_val)

            if ct not in control_avg_by_ct:
                genes_control = dataset_ctr.genes[
                    dataset_ctr.cell_types_names == ct]
                if len(genes_control) < 1:
                    print('Warning! Not enough control cells for this '
                          'covariate. Taking control cells from all '
                          'covariates.')
                    genes_control = dataset_ctr.genes
                control_avg_by_ct[ct] = (
                    genes_control.detach().mean(dim=0)
                    .cpu().numpy().reshape(-1))
            control_avg = control_avg_by_ct[ct]

            idx = np.where(dataset.pert_categories == pert)[0]
            if not len(idx):
                continue

            # detach() + cpu() so the conversion also works for tensors that
            # require grad or live on another device, as on the control path.
            y_true = (
                dataset.genes[idx, :].detach().cpu().numpy().mean(axis=0))
            delta = y_true - control_avg
            reference_response_curve.loc[i] = [
                ct, drug, dose, split, len(idx), np.linalg.norm(delta)
            ] + list(delta)
            i += 1

    return reference_response_curve