in compert/model_selection.py [0:0]
def get_best_plots(model_name, path='./results/plots'):
    """Produce the standard set of diagnostic plots for one trained model.

    Loads the checkpoint at ``model_name`` onto the CPU and rebuilds the
    model and datasets with ``prepare_compert``. It then writes plots of:

    - the training history;
    - the latent space;
    - the dose response;
    - prediction uncertainty.

    Every file is named ``<path>/<run_dir>/<run_dir>_<checkpoint_file>_*``.
    Here ``run_dir`` and ``checkpoint_file`` are the last two
    ``/``-separated components of ``model_name``.

    Which optional plots are drawn depends on ``DatasetSpecs(model_name)``:

    - ``perturbations_pair``: combination (2D) plots, skipped when ``None``.
    - ``selected_drugs`` / ``selected_cov``: single-drug uncertainty plots.
      ``selected_drugs`` must be non-empty.
    - ``plot_ref`` / ``target_genes``: per-gene reconstructed dose response.

    Args:
        model_name: Path to a checkpoint that ``torch.load`` returns as a
            ``(state, args, history)`` 3-tuple. It must contain at least
            one ``/``, because the parent directory name is used in the
            output file names.
        path: Root directory for the plots. The per-run subdirectory is
            created if missing.

    Returns:
        None. Plots are written to disk and progress is printed.
    """
    print('Start plotting for:', model_name)
    specs = DatasetSpecs(model_name)

    # The run directory (second-to-last path component) and the checkpoint
    # file name together make the plot file names unique across runs.
    name_parts = model_name.split('/')
    run_dir, checkpoint_file = name_parts[-2], name_parts[-1]
    folder = f"{path}/{run_dir}"
    Path(folder).mkdir(parents=True, exist_ok=True)
    plots_prefix = f"{folder}/{run_dir}_{checkpoint_file}"
    print('Plots are saved to: ', plots_prefix + '_*')

    # Load the checkpoint onto the CPU. NOTE: torch.load unpickles arbitrary
    # objects, so only point this at checkpoints from a trusted source.
    state, args, history = torch.load(
        model_name, map_location=torch.device('cpu'))

    # ---- Training history -------------------------------------------------
    pretty_history = ComPertHistory(history, fileprefix=plots_prefix)
    pretty_history.print_time()
    pretty_history.plot_losses()
    pretty_history.plot_metrics(epoch_min=100)

    # ---- Rebuild datasets and model from the saved args and weights -------
    autoencoder, datasets = prepare_compert(args, state_dict=state)
    compert_api = ComPertAPI(datasets, autoencoder)
    # The plotting helper can also be used on its own, outside this function.
    compert_plots = CompertVisuals(compert_api, fileprefix=plots_prefix,
                                   perts_palette=specs.perts_palette)

    # ---- Latent space -----------------------------------------------------
    # The embedding getters are called only for their side effects; their
    # return values were never used. Presumably they populate
    # compert_api.emb_perts / compert_api.emb_covars, which are read right
    # below -- TODO confirm.
    compert_api.get_drug_embeddings()
    compert_api.get_covars_embeddings()
    compert_plots.plot_latent_embeddings(
        compert_api.emb_perts,
        kind='perturbations',
        show_text=True)
    compert_plots.plot_latent_embeddings(compert_api.emb_covars, kind='covars')

    # ---- Latent dose response ---------------------------------------------
    latent_response = compert_api.latent_dose_response(perturbations=None)
    compert_plots.plot_contvar_response(
        latent_response,
        postfix='latent',
        var_name=compert_api.perturbation_key,
        title_name='Latent dose response')

    # ---- Drug-combination (2D) plots, only when a pair is configured ------
    if specs.perturbations_pair is not None:
        latent_dose_2D = compert_api.latent_dose_response2D(
            specs.perturbations_pair, n_points=100)
        compert_plots.plot_contvar_response2D(
            latent_dose_2D,
            postfix='latent2D',
            title_name='Latent dose-response')

        # NOTE(review): 'unique_сovars' contains a Cyrillic 'с' (U+0441), not
        # an ASCII 'c'. It is kept as-is because it must match the attribute
        # name defined on ComPertAPI exactly. Confirm, and rename both sides
        # to ASCII 'unique_covars' if appropriate.
        reconstructed_response2D = compert_api.get_response2D(
            datasets, specs.perturbations_pair, compert_api.unique_сovars[0])
        compert_plots.plot_contvar_response2D(
            reconstructed_response2D,
            title_name='Reconstructed dose-response 2D',
            logdose=False,
            postfix='reconstructed-dose-response2D')
        compert_plots.plot_contvar_response2D(
            reconstructed_response2D,
            title_name='Reconstructed dose-response 2D',
            logdose=True,
            postfix='log10-reconstructed-dose-response2D',
            xlims=(-3, 0), ylims=(-3, 0))

        # Called for the plot it writes; the returned frame was never used.
        drug_a, drug_b = specs.perturbations_pair[0], specs.perturbations_pair[1]
        pl.plot_uncertainty_comb_dose(
            compert_api=compert_api,
            cov=specs.selected_cov,
            pert=f'{drug_a}+{drug_b}',
            N=51,
            cond_key='treatment',
            filename=f'{compert_plots.fileprefix}_uncertainty_{drug_a}_{drug_b}.png',
            metric='cosine',
        )

    # ---- Single-drug uncertainty along a dose grid ------------------------
    # One uncertainty frame per selected drug, stacked into one frame.
    df_uncert = pd.concat([
        pl.plot_uncertainty_dose(
            compert_api,
            cov=specs.selected_cov,
            pert=drug,
            N=51,
            measured_points=compert_api.measured_points['all'],
            cond_key='condition',
            log=True,
            metric='cosine',
            filename=f'{compert_plots.fileprefix}_uncertainty_{drug}.png',
        )
        for drug in specs.selected_drugs
    ])

    # The measured doses of the first selected drug serve as reference points
    # and as tick labels for the log-scale axes (also reused by the
    # reconstructed dose-response plots further down).
    selected_drug = specs.selected_drugs[0]
    logscale_labels = compert_api.measured_points['all'][specs.selected_cov][selected_drug]
    df_ref = get_reference_from_combo([selected_drug], datasets)
    # Reference rows have both uncertainty columns set to 0.
    df_ref['uncertainty_cosine'] = 0
    df_ref['uncertainty_eucl'] = 0
    df_ref['condition'] = selected_drug
    # float() first because the dose values may not be numeric dtype.
    df_ref['log10-dose'] = [np.log10(float(d))
                            for d in df_ref[selected_drug].values]
    df_uncert['log10-dose'] = [np.log10(float(d))
                               for d in df_uncert['dose_val'].values]

    for unc in ['uncertainty_cosine', 'uncertainty_eucl']:
        pl.plot_dose_response(df_uncert,
                              'log10-dose',
                              'condition',
                              xlabelname='log10-dose',
                              df_ref=df_ref,
                              response_name=unc,
                              title_name='',
                              use_ref_response=True,
                              col_dict=compert_plots.perts_palette,
                              plot_vertical=False,
                              f1=4,
                              f2=3.3,
                              logscale=logscale_labels,
                              fname=f'{plots_prefix}_{unc}',
                              bbox=(1.6, 1.),
                              fontsize=13,
                              format='png')

    # ---- Reconstructed dose response per target gene ----------------------
    if specs.plot_ref:
        df_reference = compert_api.get_response_reference(datasets)
        reconstructed_response = compert_api.get_response(datasets)
        for gene in specs.target_genes:
            # Linear-dose and log10-dose variants of the same curve.
            compert_plots.plot_contvar_response(
                reconstructed_response,
                df_ref=df_reference,
                postfix='reconstructed-dose-response',
                figsize=(4, 3.3),
                bbox=(1.6, 1.),
                response_name=gene,
                xlabelname='dose',
                logdose=False,
                palette=compert_plots.perts_palette,
                title_name='')
            compert_plots.plot_contvar_response(
                reconstructed_response,
                postfix='log10-reconstructed-dose-response',
                df_ref=df_reference,
                figsize=(4, 3.3),
                bbox=(1.6, 1.),
                response_name=gene,
                xlabelname='log10-dose',
                logdose=True,
                palette=compert_plots.perts_palette,
                measured_points=logscale_labels,
                title_name='')