in compert/api.py [0:0]
def predict(
    self,
    genes,
    df,
    uncertainty=True,
    return_anndata=True,
    sample=False,
    n_samples=10
):
    """Predict expression of the control cells `genes` under every condition
    (perturbation, dose, covariate) listed in `df`.

    Parameters
    ----------
    genes : np.array
        Control cells, one row per cell and one column per gene.
    df : pd.DataFrame
        One row per condition to generate. Must contain the columns named by
        ``self.perturbation_key``, ``self.dose_key`` and ``self.covars_key``.
        Rows are accessed by position, so any index is accepted.
    uncertainty : bool (default: True)
        Compute uncertainties for the generated cells. This adds the columns
        ``uncertainty_cosine``, ``uncertainty_euclidean``,
        ``closest_cond_cosine`` and ``closest_cond_euclidean`` to the
        returned observations.
    return_anndata : bool, optional (default: True)
        Return the predictions wrapped into an AnnData object.
    sample : bool (default: False)
        If True, return samples drawn from a Gaussian distribution whose mean
        and variance are estimated by the model. Otherwise, return the
        estimated means and variances themselves.
    n_samples : int (default: 10)
        Number of samples drawn per control cell when `sample` is True.

    Returns
    -------
    If return_anndata is True, an AnnData object. Its X holds the predicted
    means (or samples), and it has a "variance" layer when `sample` is False.
    Otherwise, a tuple (gene_means, gene_vars, df_obs): two np.arrays and a
    data frame describing the condition of every row of gene_means.
    When `sample` is True, gene_means has ``num * n_samples`` rows per
    condition, while gene_vars keeps ``num`` rows per condition.

    Raises
    ------
    ValueError
        If `df` contains no conditions.
    """
    if len(df) == 0:
        raise ValueError("df must contain at least one condition to predict.")

    # NOTE: the model is left in eval mode after this call.
    self.model.eval()
    num = genes.shape[0]
    dim = genes.shape[1]
    genes = torch.Tensor(genes).to(self.model.device)
    if sample:
        print('Careful! These are sampled values! Better use means and '
              'variances for downstream tasks!')
    # Every generated cell of a condition gets one copy of the condition row.
    rows_per_condition = num * n_samples if sample else num

    gene_means_list = []
    gene_vars_list = []
    df_list = []
    for i in range(len(df)):
        # Positional access: df.loc[i] breaks for frames without a RangeIndex.
        row = df.iloc[i]
        comb_name = row[self.perturbation_key]
        dose_name = row[self.dose_key]
        covar_name = row[self.covars_key]
        covar_ohe = torch.Tensor(
            self.covars_dict[covar_name]
        ).to(self.model.device)
        drug_ohe = torch.Tensor(
            self.get_drug_encoding_(comb_name, doses=dose_name)
        ).to(self.model.device)
        # Broadcast the single condition encoding to every control cell.
        drugs = drug_ohe.expand([num, self.drug_ohe.shape[1]])
        covars = covar_ohe.expand([num, self.covars_ohe.shape[1]])
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            gene_reconstructions = self.model.predict(
                genes,
                drugs,
                covars
            ).detach().cpu().numpy()
        # The first `dim` output columns are means, the rest are variances.
        means = gene_reconstructions[:, :dim]
        variances = gene_reconstructions[:, dim:]

        # Repeat the condition row by position; unlike rebuilding it from
        # `.values`, this keeps the original column dtypes.
        cond_df = df.iloc[[i] * rows_per_condition].reset_index(drop=True)
        if sample:
            # Normal is parameterised by the standard deviation, not variance.
            dist = torch.distributions.normal.Normal(
                torch.Tensor(means),
                torch.Tensor(variances).sqrt(),
            )
            gene_means_list.append(
                dist
                .sample(torch.Size([n_samples]))
                .cpu()
                .detach()
                .numpy()
                .reshape(-1, dim)
            )
        else:
            gene_means_list.append(means)

        if uncertainty:
            cos_dist, eucl_dist, closest_cond_cos, closest_cond_eucl = \
                self.compute_uncertainty(
                    cov=covar_name,
                    pert=comb_name,
                    dose=dose_name
                )
            cond_df = cond_df.assign(
                uncertainty_cosine=cos_dist,
                uncertainty_euclidean=eucl_dist,
                closest_cond_cosine=closest_cond_cos,
                closest_cond_euclidean=closest_cond_eucl
            )
        df_list.append(cond_df)
        gene_vars_list.append(variances)

    gene_means = np.concatenate(gene_means_list)
    gene_vars = np.concatenate(gene_vars_list)
    # ignore_index gives unique row labels, and hence unique AnnData obs_names.
    df_obs = pd.concat(df_list, ignore_index=True)
    del df_list, gene_means_list, gene_vars_list
    if return_anndata:
        adata = sc.AnnData(gene_means)
        adata.var_names = self.var_names
        adata.obs = df_obs
        if not sample:
            adata.layers["variance"] = gene_vars
        adata.obs.index = adata.obs.index.astype(str)  # type fix
        return adata
    return gene_means, gene_vars, df_obs