in compert/train.py [0:0]
def evaluate_r2(autoencoder, dataset, genes_control, *, min_cells=30):
    """
    Measure R2 quality metrics for a ComPert `autoencoder`.

    The autoencoder is asked to translate `genes_control` into each of the
    perturbation categories found in `dataset.pert_categories`. For every
    category with strictly more than `min_cells` cells, the per-gene mean and
    variance of the predictions (averaged over the control cells) are scored
    with R2 against the observed per-gene mean and variance of that category.
    This is done over all genes and over the category's differentially
    expressed (_de) genes.

    Args:
        autoencoder: model exposing ``predict(genes, drugs, cell_types)``. The
            first ``dim`` output columns are treated as predicted means and
            the remaining columns as predicted variances.
        dataset: object exposing ``pert_categories``, ``var_names``,
            ``de_genes`` (indexed by category), ``drugs``, ``cell_types`` and
            ``genes``.
        genes_control: tensor whose ``size(0)`` is the number of control
            cells and whose ``size(1)`` (``dim``) is the number of genes.
        min_cells: only categories containing strictly more than this many
            cells are scored (default 30, the previously hard-coded value).

    Returns:
        list: ``[mean_score, mean_score_de, var_score, var_score_de]``. Each
        entry is the average R2 over the scored categories, or -1 if no
        category could be scored for that metric.
    """
    mean_score, var_score, mean_score_de, var_score_de = [], [], [], []
    num, dim = genes_control.size(0), genes_control.size(1)
    for pert_category in np.unique(dataset.pert_categories):
        # pert_category contains 'celltype_perturbation_dose' info
        idx = np.where(dataset.pert_categories == pert_category)[0]
        # estimate metrics only for reasonably-sized drug/cell-type combos
        if len(idx) <= min_cells:
            continue

        # The first cell's drug / cell-type encoding is used as the
        # representative for the category, broadcast to every control cell.
        # `repeat` allocates new memory, so no extra clone is needed.
        # `reshape` (rather than `view`) also works on non-contiguous rows.
        first = int(idx[0])
        emb_drugs = dataset.drugs[first].reshape(1, -1).repeat(num, 1)
        emb_cts = dataset.cell_types[first].reshape(1, -1).repeat(num, 1)
        genes_predict = autoencoder.predict(
            genes_control, emb_drugs, emb_cts).detach().cpu()

        # predicted means and variances, averaged over control cells; kept as
        # numpy so they index consistently with the numpy `de_idx` below
        yp_m = genes_predict[:, :dim].mean(0).numpy()
        yp_v = genes_predict[:, dim:].mean(0).numpy()

        # true means and variances
        y_true = dataset.genes[idx, :].numpy()
        yt_m = y_true.mean(axis=0)
        yt_v = y_true.var(axis=0)

        mean_score.append(r2_score(yt_m, yp_m))
        var_score.append(r2_score(yt_v, yp_v))

        # DE genes are looked up only for categories that are actually scored
        de_idx = np.where(
            dataset.var_names.isin(
                np.array(dataset.de_genes[pert_category])))[0]
        # R2 is undefined for fewer than two samples: sklearn raises on an
        # empty set and returns nan for a single one, which would turn the
        # averaged DE metric into nan. Skip such categories instead.
        if len(de_idx) >= 2:
            mean_score_de.append(r2_score(yt_m[de_idx], yp_m[de_idx]))
            var_score_de.append(r2_score(yt_v[de_idx], yp_v[de_idx]))

    return [np.mean(s) if len(s) else -1
            for s in [mean_score, mean_score_de, var_score, var_score_de]]