in compert/api.py [0:0]
def _benchmark_moments(training, pert_category_predict, pert_category):
    """Return the per-gene (mean, variance) pair used as benchmark prediction.

    A plain category uses the statistics of its own training cells. A
    combined entry ``'a+b'`` passes the statistics of the training categories
    ``a`` and ``b`` to ``linear_interp``, together with the third
    ``'_'``-separated field of ``a``, ``b`` and ``pert_category`` parsed as
    floats.
    """
    if '+' not in pert_category_predict:
        idx_pred = np.where(
            training.pert_categories == pert_category_predict)[0]
        # Progress output: benchmark category and its number of training cells.
        print(pert_category_predict, len(idx_pred))
        y_pred = training.genes[idx_pred, :].numpy()
        return y_pred.mean(axis=0), y_pred.var(axis=0)

    # Exactly one '+' is expected; anything else fails the unpacking below.
    pert1, pert2 = pert_category_predict.split('+')
    idx_pred1 = np.where(training.pert_categories == pert1)[0]
    idx_pred2 = np.where(training.pert_categories == pert2)[0]
    y_pred1 = training.genes[idx_pred1, :].numpy()
    y_pred2 = training.genes[idx_pred2, :].numpy()

    # The third '_'-separated field is the value that ends up in the dose
    # column of the score table; it serves as the interpolation coordinate.
    x1 = float(pert1.split('_')[2])
    x2 = float(pert2.split('_')[2])
    x = float(pert_category.split('_')[2])

    yp_m = linear_interp(
        y_pred1.mean(axis=0), y_pred2.mean(axis=0), x1, x2, x)
    yp_v = linear_interp(
        y_pred1.var(axis=0), y_pred2.var(axis=0), x1, x2, x)
    return yp_m, yp_v


def evaluate_r2_benchmark(
    compert_api,
    datasets,
    pert_category,
    pert_category_list
):
    """Score benchmark categories against one out-of-distribution category.

    The per-gene mean and variance of the ``datasets['ood']`` cells labelled
    ``pert_category`` are compared, via ``r2_score``, with the mean and
    variance derived from ``datasets['training']`` for every entry of
    ``pert_category_list``. Each comparison is done on all genes and on the
    genes listed in ``datasets['ood'].de_genes[pert_category]``.

    Args:
        compert_api: Object providing ``covars_key``, ``perturbation_key``
            and ``dose_key``; these name the first three result columns.
        datasets: Mapping with the keys ``'ood'`` and ``'training'``. Both
            values expose ``pert_categories`` and ``genes``; the ``'ood'``
            value also exposes ``var_names`` and ``de_genes``.
        pert_category: Category to evaluate. Its ``'_'``-separated parts
            fill the first three result columns.
        pert_category_list: Benchmark categories. An entry is either a
            single category or two categories joined by ``'+'``.

    Returns:
        A ``pandas.DataFrame`` with one row per entry of
        ``pert_category_list``. The ``'benchmark'`` column holds the entry
        itself and ``'method'`` holds the literal ``'benchmark'``. The frame
        is empty when no ``'ood'`` cell matches ``pert_category``.

    NOTE(review): a benchmark category without training cells produces NaN
    statistics, which ``r2_score`` is expected to reject - confirm.
    """
    scores = pd.DataFrame(columns=[compert_api.covars_key,
                                   compert_api.perturbation_key,
                                   compert_api.dose_key,
                                   'R2_mean', 'R2_mean_DE',
                                   'R2_var', 'R2_var_DE',
                                   'num_cells', 'benchmark', 'method'])

    ood = datasets['ood']
    # Resolved before the emptiness check so that a failing lookup for an
    # unknown category still surfaces.
    de_idx = np.where(
        ood.var_names.isin(np.array(ood.de_genes[pert_category])))[0]
    idx = np.where(ood.pert_categories == pert_category)[0]

    # Nothing to score: return before reducing over a zero-row matrix, which
    # would only emit NumPy RuntimeWarnings and yield unused NaN vectors.
    if len(idx) == 0:
        return scores

    # True means and variances.
    y_true = ood.genes[idx, :].numpy()
    yt_m = y_true.mean(axis=0)
    yt_v = y_true.var(axis=0)

    for icond, pert_category_predict in enumerate(pert_category_list):
        yp_m, yp_v = _benchmark_moments(
            datasets['training'], pert_category_predict, pert_category)

        mean_score = r2_score(yt_m, yp_m)
        var_score = r2_score(yt_v, yp_v)
        mean_score_de = r2_score(yt_m[de_idx], yp_m[de_idx])
        var_score_de = r2_score(yt_v[de_idx], yp_v[de_idx])

        scores.loc[icond] = pert_category.split('_') + [
            mean_score, mean_score_de, var_score, var_score_de,
            len(idx), pert_category_predict, 'benchmark']

    return scores