{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Analyzing inference pipeline runs\n",
"\n",
"\n",
"As an AlphaFold inference pipeline executes inference workflow steps, information about a step's outcome, including artifacts and artifact metadata generated by a step, is tracked in Vertex ML Metadata. For example, a model prediction step tracks the locations of raw prediction and protein structure files and properties like ranking confidence. The tracked information ensures reproducibility and supports detailed run analysis including lineage tracing. \n",
"\n",
"\n",
"In this notebook, you will explore how to retrieve pipeline run metadata from Vertex AI Metadata, explore its properties and visualize generated artifacts."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Install and import required packages"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%cd /home/jupyter/vertex-ai-alphafold-inference-pipeline/src\n",
"%pip install .\n",
"%cd /home/jupyter/vertex-ai-alphafold-inference-pipeline\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"! pip install py3dmol\n",
"! pip install dm-tree\n",
"! pip install matplotlib\n",
"! pip install biopython\n",
"! pip install ipywidgets\n",
"! pip install jax\n",
"! pip install jaxlib\n",
"! pip install dm-haiku"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import pickle\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import py3Dmol\n",
"from IPython import display\n",
"from ipywidgets import GridspecLayout\n",
"from ipywidgets import Output\n",
"from google.cloud import aiplatform_v1 as vertex_ai\n",
"from google.cloud import storage\n",
"from pathlib import Path\n",
"\n",
"from src.analysis import notebook_utils\n",
"from src.analysis import parsers\n",
"from src.analysis import utils"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Pipeline metadata exploration\n",
"\n",
"In this section of the notebook, you will use Vertex AI SDK to retrieve information about pipeline runs, including information about artifacts and metadata tracked by a run."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define API clients to interact with Vertex AI Pipelines metadata"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Change the variables in the next cell according to your environment's definition."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"PROJECT_ID = 'YOUR PROJECT ID' # Replace with your project ID\n",
"REGION = 'YOUR REGION' # Replace with your region\n",
"API_ENDPOINT = \"{}-aiplatform.googleapis.com\".format(REGION)\n",
"\n",
"# Create Pipelines and Metadata service clients\n",
"client_pipeline = vertex_ai.PipelineServiceClient(client_options={\"api_endpoint\": API_ENDPOINT})\n",
"client_metadata = vertex_ai.MetadataServiceClient(client_options={\"api_endpoint\": API_ENDPOINT})"
]
},
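{
"cell_type": "markdown",
"metadata": {},
"source": [
"The clients above connect to the regional Vertex AI endpoint, and most requests in this notebook take a `parent` resource name derived from the same project ID and region. As a minimal sketch, the hypothetical helper `vertex_resource_paths` below builds the endpoint, the `parent` used by `ListPipelineJobsRequest`, and the default metadata store path used with `ListArtifactsRequest`. The function name and dictionary keys are illustrative and not part of the Vertex AI SDK; `my-project` and `us-central1` are placeholders.\n",
"\n",
"```python\n",
"def vertex_resource_paths(project_id: str, region: str) -> dict:\n",
"    \"\"\"Builds the resource strings this notebook passes to Vertex AI (hypothetical helper).\"\"\"\n",
"    parent = f'projects/{project_id}/locations/{region}'\n",
"    return {\n",
"        # Regional API endpoint passed in client_options\n",
"        'api_endpoint': f'{region}-aiplatform.googleapis.com',\n",
"        # Parent resource for pipeline job requests\n",
"        'parent': parent,\n",
"        # Default Vertex ML Metadata store in the same project and region\n",
"        'metadata_store': f'{parent}/metadataStores/default',\n",
"    }\n",
"\n",
"\n",
"paths = vertex_resource_paths('my-project', 'us-central1')\n",
"print(paths['metadata_store'])\n",
"# projects/my-project/locations/us-central1/metadataStores/default\n",
"```"
]
},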
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Listing all running pipelines\n",
"\n",
"You can list all pipelines in a given state. For example all pipelines that are still running. \n",
"The following cell creates a request to list all the running Pipelines."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"FILTER = 'state=\"PIPELINE_STATE_RUNNING\"'\n",
"\n",
"list_pipelines_request = vertex_ai.ListPipelineJobsRequest(\n",
" parent=f'projects/{PROJECT_ID}/locations/{REGION}',\n",
" filter=FILTER\n",
")\n",
"\n",
"list_pipelines = list(client_pipeline.list_pipeline_jobs(list_pipelines_request))\n",
"\n",
"if list_pipelines:\n",
" print(f'There are {len(list_pipelines)} pipeline(s) running.')\n",
" for pipeline in list_pipelines:\n",
" print(f'ID: {pipeline.name} - State: {pipeline.state.name}')\n",
"else:\n",
" print('No pipelines running.')"
]
},
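{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `filter` string is an expression over pipeline job fields, and the Vertex AI API reference describes combining expressions with the `AND` and `OR` operators. As a minimal sketch, assuming `OR` is accepted between `state` comparisons, the hypothetical helper `state_filter` below builds a filter for jobs in any of several states, for example all finished runs:\n",
"\n",
"```python\n",
"def state_filter(states):\n",
"    \"\"\"Returns a filter matching pipeline jobs in any of the given states (hypothetical helper).\"\"\"\n",
"    if not states:\n",
"        raise ValueError('At least one state is required.')\n",
"    return ' OR '.join(f'state=\"{state}\"' for state in states)\n",
"\n",
"\n",
"finished_filter = state_filter(['PIPELINE_STATE_SUCCEEDED', 'PIPELINE_STATE_FAILED'])\n",
"print(finished_filter)\n",
"# state=\"PIPELINE_STATE_SUCCEEDED\" OR state=\"PIPELINE_STATE_FAILED\"\n",
"```\n",
"\n",
"To use it, pass the result as `filter=` to `vertex_ai.ListPipelineJobsRequest`, the same way the previous cell passes `FILTER`."
]
},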
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Listing all pipelines with a specific label\n",
"\n",
"You can search for all pipelines annotated with a specific label. For example, you can list all pipelines that were grouped under a given experiment.\n",
"\n",
"Set the `experiment_id` variable to a label you used to a annotate pipeline run for the monomer optimized pipeline."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"experiment_id = 'YOUR LABEL' # Replace with your label (lower case)\n",
"FILTER = f'labels.experiment_id=\"{experiment_id}\"'\n",
"\n",
"list_pipelines_request = vertex_ai.ListPipelineJobsRequest(\n",
" parent=f'projects/{PROJECT_ID}/locations/{REGION}',\n",
" filter=FILTER\n",
")\n",
"\n",
"list_pipelines = list(client_pipeline.list_pipeline_jobs(list_pipelines_request))\n",
"\n",
"print(f'Number of Pipeline(s) found: {len(list_pipelines)}')\n",
"\n",
"if list_pipelines:\n",
" print(f'Printing pipeline(s) with label `experiment_id` equals to `{experiment_id}`.')\n",
" for pipeline in list_pipelines:\n",
" print(f'Id: {pipeline.name} - State: {pipeline.state.name}')\n",
"else:\n",
" print('No pipelines found.')"
]
},
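{
"cell_type": "markdown",
"metadata": {},
"source": [
"The previous cell asks for a lowercase label because Vertex AI label keys and values may only contain lowercase letters, digits, underscores, and dashes (international characters are also allowed), up to 64 characters. A label typed with uppercase letters simply matches no pipelines. As a minimal sketch, the hypothetical helper `label_filter` below checks a simplified, ASCII-only version of those rules before building the label filter, so such a typo fails loudly:\n",
"\n",
"```python\n",
"import re\n",
"\n",
"# Simplified ASCII-only label check; Vertex AI also accepts international characters.\n",
"_LABEL_PATTERN = re.compile(r'[a-z0-9_-]{1,64}')\n",
"\n",
"\n",
"def label_filter(key: str, value: str) -> str:\n",
"    \"\"\"Builds a label filter after validating the key and value (hypothetical helper).\"\"\"\n",
"    for part in (key, value):\n",
"        if not _LABEL_PATTERN.fullmatch(part):\n",
"            raise ValueError(f'Invalid label component: {part!r}')\n",
"    return f'labels.{key}=\"{value}\"'\n",
"\n",
"\n",
"print(label_filter('experiment_id', 'monomer-run-1'))\n",
"# labels.experiment_id=\"monomer-run-1\"\n",
"```"
]
},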
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Retrieving artifact metadata\n",
"\n",
"You can retrieve metadata for artifacts generated by a pipeline. For example, you can retrieve ranking confidence from a raw prediction artifact generated by a model predict step.\n",
"\n",
"You will construct a data structure containing all information generated by a pipeline.\n",
"\n",
"You need to use a full ID of the pipeline in the format `projects/<PROJECT NUMBER>/locations/<REGION>/pipelineJobs/<PIPELINE NAME>`. Set the `pipeline_id` variable to the ID of a successful run returned by the previous cell.\n",
"\n",
"**IMPORTANT**: Please change the variables `pipeline_id` and `is_monomer` according to the definitions of the pipeline you chose."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Change to your pipeline ID\n",
"pipeline_id = 'YOUR PIPELINE ID'\n",
"\n",
"# Change to False if multimer\n",
"is_monomer = True"
]
},
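{
"cell_type": "markdown",
"metadata": {},
"source": [
"A common mistake is to set `pipeline_id` to the short job name instead of the full resource name. As a minimal sketch, the hypothetical validator `parse_pipeline_job_name` below uses a regular expression derived from the format described above to catch this before `get_pipeline_job` is called. The example ID is made up.\n",
"\n",
"```python\n",
"import re\n",
"\n",
"# Format: projects/<PROJECT NUMBER>/locations/<REGION>/pipelineJobs/<PIPELINE NAME>\n",
"_PIPELINE_JOB_NAME = re.compile(\n",
"    r'projects/(?P<project>[^/]+)/locations/(?P<region>[^/]+)/pipelineJobs/(?P<job>[^/]+)')\n",
"\n",
"\n",
"def parse_pipeline_job_name(name: str) -> dict:\n",
"    \"\"\"Splits a full pipeline job resource name into its parts (hypothetical helper).\"\"\"\n",
"    match = _PIPELINE_JOB_NAME.fullmatch(name)\n",
"    if match is None:\n",
"        raise ValueError(f'Expected a full pipeline job resource name, got {name!r}')\n",
"    return match.groupdict()\n",
"\n",
"\n",
"print(parse_pipeline_job_name('projects/123456789/locations/us-central1/pipelineJobs/alphafold-job-001'))\n",
"# {'project': '123456789', 'region': 'us-central1', 'job': 'alphafold-job-001'}\n",
"```\n",
"\n",
"For example, `parse_pipeline_job_name(pipeline_id)` raises a `ValueError` while `pipeline_id` still holds the placeholder value."
]
},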
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if is_monomer:\n",
" model_type_to_use = notebook_utils.ModelType.MONOMER\n",
"else:\n",
" model_type_to_use = notebook_utils.ModelType.MULTIMER\n",
"\n",
"get_request = vertex_ai.GetPipelineJobRequest(\n",
" name=pipeline_id\n",
")\n",
"pipeline_job = client_pipeline.get_pipeline_job(get_request)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"After you have the data structure you can browse it and retrieve the required properties. In the following example, you will retrieve the `ranking_confidence` values tracked by model prediction steps."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Retrieve all tasks with name = predict\n",
"predict_tasks = [i for i in pipeline_job.job_detail.task_details if i.task_name == 'predict']\n",
"formated_predict_tasks = []\n",
"\n",
"for t in predict_tasks:\n",
" task_id = t.task_id\n",
" parent_task_id = t.parent_task_id\n",
" model_name = t.execution.metadata['input:model_name']\n",
" ranking_confidence = t.outputs['raw_prediction'].artifacts[0].metadata['ranking_confidence']\n",
" uri = t.outputs['raw_prediction'].artifacts[0].uri\n",
"\n",
" formated_predict_tasks.append(\n",
" {\n",
" 'task_id': task_id,\n",
" 'parent_task_id': parent_task_id,\n",
" 'model_name': model_name,\n",
" 'ranking_confidence': ranking_confidence,\n",
" 'uri': uri\n",
" }\n",
" )\n",
"\n",
"# Sort predict tasks by ranking confidence\n",
"sorted_predict = sorted(formated_predict_tasks, key=lambda x: x['ranking_confidence'], reverse=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print('Preditions ranking:')\n",
"for prediction in sorted_predict:\n",
" print(prediction['model_name'], '=>', 'Ranking confidence:', prediction['ranking_confidence'])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Artifacts generated by Alphafold pipelines have a metadata property called `category`, which helps with artifact discovery and metadata retrieval. \n",
"Let's find all the Artifacts with the `category` property set to `msa`. These artifacts capture information about MSAs generated during feature engineering phase."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pipeline_ctx = pipeline_job.job_detail.pipeline_run_context.name\n",
"parent = f'projects/{PROJECT_ID}/locations/{REGION}/metadataStores/default'\n",
"\n",
"# Filter which artifacts to present\n",
"FILTER = f'in_context(\"{pipeline_ctx}\") AND display_name=\"msas\"'\n",
"\n",
"list_artifacts_request = vertex_ai.ListArtifactsRequest(\n",
" parent=parent,\n",
" filter=FILTER\n",
")\n",
"\n",
"msas = list(client_metadata.list_artifacts(request=list_artifacts_request))"
]
},
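{
"cell_type": "markdown",
"metadata": {},
"source": [
"The previous cell selects MSA artifacts by display name. Vertex ML Metadata filters can also compare fields inside an artifact's metadata with the `metadata.<field>.string_value` traversal syntax. As a sketch, assuming the MSA artifacts store `category` as a string equal to `msa`, the hypothetical helper `category_filter` below builds a filter on that property instead:\n",
"\n",
"```python\n",
"def category_filter(context_name: str, category: str) -> str:\n",
"    \"\"\"Builds a ListArtifacts filter on the `category` metadata field (hypothetical helper).\"\"\"\n",
"    return (f'in_context(\"{context_name}\") '\n",
"            f'AND metadata.category.string_value=\"{category}\"')\n",
"\n",
"\n",
"print(category_filter('my-run-context', 'msa'))\n",
"# in_context(\"my-run-context\") AND metadata.category.string_value=\"msa\"\n",
"```\n",
"\n",
"To try it in this notebook, pass `filter=category_filter(pipeline_ctx, 'msa')` to `vertex_ai.ListArtifactsRequest`."
]
},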
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can now display all the properties associated with MSA artifacts. \n",
"In this case you are reading the last file found with `bfd` in its name."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"bucket_name = msas[0].uri.replace('gs://', '').split(sep='/')[0]\n",
"blob_prefix = '/'.join(msas[0].uri.replace('gs://', '').split(sep='/')[1:])\n",
"\n",
"storage_client = storage.Client()\n",
"blobs = storage_client.list_blobs(bucket_name, prefix=blob_prefix+'/')\n",
"\n",
"for b in blobs:\n",
" if 'bfd' in b.name:\n",
" msa_uri = 'gs://' + bucket_name + '/' + b.name\n",
" msa_filename = Path(msa_uri)"
]
},
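{
"cell_type": "markdown",
"metadata": {},
"source": [
"The previous cell splits the artifact URI into a bucket name and an object prefix, then keeps the last object whose name contains `bfd`. Cloud Storage lists objects in lexicographic order of their names, so the last match is the lexicographically last name. As a minimal sketch, the two steps can be written as pure functions (`split_gcs_uri` and `last_matching` are hypothetical names, and the sample object names are made up), which makes them testable without Cloud Storage access:\n",
"\n",
"```python\n",
"def split_gcs_uri(uri: str) -> tuple:\n",
"    \"\"\"Splits 'gs://bucket/path/to/object' into ('bucket', 'path/to/object').\"\"\"\n",
"    if not uri.startswith('gs://'):\n",
"        raise ValueError(f'Not a Cloud Storage URI: {uri!r}')\n",
"    bucket, _, blob_path = uri[len('gs://'):].partition('/')\n",
"    return bucket, blob_path\n",
"\n",
"\n",
"def last_matching(names, needle: str):\n",
"    \"\"\"Returns the last name that contains `needle`, or None if none does.\"\"\"\n",
"    matches = [name for name in names if needle in name]\n",
"    return matches[-1] if matches else None\n",
"\n",
"\n",
"print(split_gcs_uri('gs://my-bucket/runs/123/msas'))  # ('my-bucket', 'runs/123/msas')\n",
"print(last_matching(['msas/bfd_1.a3m', 'msas/bfd_2.a3m', 'msas/uniref90.sto'], 'bfd'))  # msas/bfd_2.a3m\n",
"```"
]
},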
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since an artifact entry in Vertex Metadata contains a link to the location of artifact file(s) in Google Cloud storage, you can retrieve the artifact for further, local analysis.\n",
"\n",
"For example, you can download an MSA file and use AlphaFold notebook widgets to visualize it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"! gsutil cp {msa_uri} ."
]
},
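{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an alternative to `gsutil`, you can download the file with the `google-cloud-storage` client that the notebook already imports. As a minimal sketch, the hypothetical helper `download_gcs_file` below uses `Client.bucket`, `Bucket.blob`, and `Blob.download_to_filename`; the client is passed in so the function can also be exercised with a stand-in object.\n",
"\n",
"```python\n",
"from pathlib import Path\n",
"\n",
"\n",
"def download_gcs_file(client, uri: str, dest_dir: str = '.') -> Path:\n",
"    \"\"\"Downloads one Cloud Storage object into dest_dir and returns the local path.\"\"\"\n",
"    if not uri.startswith('gs://'):\n",
"        raise ValueError(f'Not a Cloud Storage URI: {uri!r}')\n",
"    bucket_name, _, blob_name = uri[len('gs://'):].partition('/')\n",
"    local_path = Path(dest_dir) / Path(blob_name).name\n",
"    # bucket() and blob() only build references; download_to_filename() transfers the data.\n",
"    client.bucket(bucket_name).blob(blob_name).download_to_filename(str(local_path))\n",
"    return local_path\n",
"\n",
"\n",
"# In this notebook (requires Cloud Storage access):\n",
"# local_msa_path = download_gcs_file(storage.Client(), msa_uri)\n",
"```"
]
},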
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next cell can take up to 30 seconds to execute."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with open(msa_filename.name, 'r') as fp:\n",
" msa_file = fp.read()\n",
"\n",
"if msa_filename.suffix == '.a3m':\n",
" msa_parsed = parsers.parse_a3m(msa_file)\n",
"elif msa_filename.suffix == '.sto':\n",
" msa_parsed = parsers.parse_stockholm(msa_file)\n",
"elif msa_filename.suffix == '.hhr':\n",
" msa_parsed = parsers.parse_hhr(msa_file)\n",
"notebook_utils.show_msa_info([msa_parsed], 0)"
]
},
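{
"cell_type": "markdown",
"metadata": {},
"source": [
"`parsers.parse_a3m` comes from this repository's `src.analysis` package. To illustrate what an A3M parser has to do, here is a minimal sketch, not the repository's implementation. In A3M, uppercase letters and `-` are aligned to query columns, while lowercase letters are insertions relative to the query. Dropping the insertions yields rows of equal length, and counting the insertions before each aligned column yields a deletion matrix; AlphaFold's own A3M parser returns both. The function name and the toy alignment are hypothetical.\n",
"\n",
"```python\n",
"def parse_a3m_minimal(a3m_text: str):\n",
"    \"\"\"Returns (aligned_sequences, deletion_matrix) for A3M text. Illustrative sketch only.\"\"\"\n",
"    raw_sequences = []\n",
"    for line in a3m_text.splitlines():\n",
"        line = line.strip()\n",
"        if not line or line.startswith('#'):\n",
"            continue\n",
"        if line.startswith('>'):\n",
"            raw_sequences.append('')  # start a new record; descriptions are ignored here\n",
"        elif raw_sequences:\n",
"            raw_sequences[-1] += line  # a sequence may span several lines\n",
"\n",
"    aligned_sequences, deletion_matrix = [], []\n",
"    for seq in raw_sequences:\n",
"        deletions, pending = [], 0\n",
"        for char in seq:\n",
"            if char.islower():\n",
"                pending += 1  # insertion relative to the query: counted, then dropped\n",
"            else:\n",
"                deletions.append(pending)\n",
"                pending = 0\n",
"        aligned_sequences.append(''.join(c for c in seq if not c.islower()))\n",
"        deletion_matrix.append(deletions)\n",
"    return aligned_sequences, deletion_matrix\n",
"\n",
"\n",
"toy_a3m = '''>query\n",
"MKV\n",
">hit_1\n",
"MaaK-\n",
"'''\n",
"seqs, deletions = parse_a3m_minimal(toy_a3m)\n",
"print(seqs)       # ['MKV', 'MK-']\n",
"print(deletions)  # [[0, 0, 0], [0, 2, 0]]\n",
"```"
]
},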
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Visualize Prediction"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next you will visualise the prediction & confidence.\n",
"You need to download two files for this step:\n",
" - Prediction pickle file\n",
" - Relaxed protein PDB file\n",
"\n",
"If the multimer model has been used, it will show the structure coloured by chain."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Find the `Relax` task to get its URI\n",
"for t in pipeline_job.job_detail.task_details:\n",
" if t.parent_task_id == sorted_predict[0]['parent_task_id'] and 'condition' in t.task_name:\n",
" condition_relax_parent_id = t.task_id\n",
" break\n",
"\n",
"for t in pipeline_job.job_detail.task_details:\n",
" if t.parent_task_id == condition_relax_parent_id:\n",
" top_relax_predict = t.outputs['relaxed_protein'].artifacts[0].uri\n",
" break"
]
},
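{
"cell_type": "markdown",
"metadata": {},
"source": [
"The previous cell relies on the task hierarchy the pipeline creates: the relax step runs inside a condition task that shares a parent with the top-ranked predict task. As a minimal sketch, the hypothetical function `find_relaxed_uri` below performs the same two-hop lookup but raises a clear `LookupError` instead of leaving a variable undefined when no match exists. The `types.SimpleNamespace` objects stand in for `task_details` entries; their IDs, names, and URI are invented.\n",
"\n",
"```python\n",
"from types import SimpleNamespace\n",
"\n",
"\n",
"def find_relaxed_uri(task_details, predict_parent_id):\n",
"    \"\"\"Returns the relaxed_protein URI from the condition branch next to a predict task.\"\"\"\n",
"    condition_id = next(\n",
"        (t.task_id for t in task_details\n",
"         if t.parent_task_id == predict_parent_id and 'condition' in t.task_name),\n",
"        None)\n",
"    if condition_id is None:\n",
"        raise LookupError('No condition task found next to the predict task.')\n",
"    for t in task_details:\n",
"        if t.parent_task_id == condition_id:\n",
"            return t.outputs['relaxed_protein'].artifacts[0].uri\n",
"    raise LookupError('No relax task found under the condition task.')\n",
"\n",
"\n",
"# Stand-ins that expose only the fields used above\n",
"relaxed = SimpleNamespace(artifacts=[SimpleNamespace(uri='gs://my-bucket/relaxed_protein.pdb')])\n",
"tasks = [\n",
"    SimpleNamespace(task_id=2, parent_task_id=1, task_name='predict', outputs={}),\n",
"    SimpleNamespace(task_id=3, parent_task_id=1, task_name='condition-2', outputs={}),\n",
"    SimpleNamespace(task_id=4, parent_task_id=3, task_name='relax',\n",
"                    outputs={'relaxed_protein': relaxed}),\n",
"]\n",
"print(find_relaxed_uri(tasks, predict_parent_id=1))  # gs://my-bucket/relaxed_protein.pdb\n",
"```"
]
},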
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"! gsutil cp {sorted_predict[0][\"uri\"]} .\n",
"! gsutil cp {top_relax_predict} ."
]
},
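{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next cell assigns each residue to a pLDDT color band with a nested loop and then broadcasts the band index across the residue's atoms with the atom mask. As a minimal NumPy sketch, assuming pLDDT values lie in [0, 100], `np.digitize` with `right=True` and the inner edges 50, 70, and 90 reproduces the loop's inclusive upper bounds: exactly 50 falls in band 0, 70 in band 1, and 90 in band 2. The helper name and sample arrays are hypothetical.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Inner band edges from PLDDT_BANDS: [0, 50], (50, 70], (70, 90], (90, 100]\n",
"BAND_EDGES = np.array([50, 70, 90])\n",
"\n",
"\n",
"def banded_b_factors_vectorized(plddt, atom_mask):\n",
"    \"\"\"Maps per-residue pLDDT to band indices 0-3 and spreads them over each residue's atoms.\"\"\"\n",
"    band_idx = np.digitize(np.asarray(plddt), BAND_EDGES, right=True)\n",
"    # (num_residues, 1) * (num_residues, num_atoms) -> (num_residues, num_atoms)\n",
"    return band_idx[:, None] * atom_mask\n",
"\n",
"\n",
"plddt = np.array([35.0, 50.0, 65.2, 70.0, 88.9, 90.0, 97.1])\n",
"mask = np.ones((7, 3))\n",
"print(banded_b_factors_vectorized(plddt, mask)[:, 0])  # [0. 0. 1. 1. 2. 2. 3.]\n",
"```"
]
},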
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with open('raw_prediction.pkl', 'rb') as fp:\n",
" raw_predictions = pickle.load(fp)\n",
"\n",
"with open('relaxed_protein.pdb', 'r') as fp:\n",
" relaxed_pdb = fp.read()\n",
"\n",
"# Color bands for visualizing plddt\n",
"PLDDT_BANDS = [(0, 50, '#FF7D45'),\n",
" (50, 70, '#FFDB13'),\n",
" (70, 90, '#65CBF3'),\n",
" (90, 100, '#0053D6')]\n",
"\n",
"banded_b_factors = []\n",
"final_atom_mask = raw_predictions['structure_module']['final_atom_mask']\n",
"\n",
"for plddt in raw_predictions['plddt']:\n",
" for idx, (min_val, max_val, _) in enumerate(PLDDT_BANDS):\n",
" if plddt >= min_val and plddt <= max_val:\n",
" banded_b_factors.append(idx)\n",
" break\n",
"banded_b_factors = np.array(banded_b_factors)[:, None] * final_atom_mask\n",
"to_visualize_pdb = utils.overwrite_b_factors(relaxed_pdb, banded_b_factors)\n",
"\n",
"show_sidechains = True\n",
"def plot_plddt_legend():\n",
" \"\"\"Plots the legend for pLDDT.\"\"\"\n",
" thresh = ['Very low (pLDDT < 50)',\n",
" 'Low (70 > pLDDT > 50)',\n",
" 'Confident (90 > pLDDT > 70)',\n",
" 'Very high (pLDDT > 90)']\n",
"\n",
" colors = [x[2] for x in PLDDT_BANDS]\n",
"\n",
" plt.figure(figsize=(2, 2))\n",
" for c in colors:\n",
" plt.bar(0, 0, color=c)\n",
" plt.legend(thresh, frameon=False, loc='center', fontsize=20)\n",
" plt.xticks([])\n",
" plt.yticks([])\n",
" ax = plt.gca()\n",
" ax.spines['right'].set_visible(False)\n",
" ax.spines['top'].set_visible(False)\n",
" ax.spines['left'].set_visible(False)\n",
" ax.spines['bottom'].set_visible(False)\n",
" plt.title('Model Confidence', fontsize=20, pad=20)\n",
" return plt\n",
"\n",
"# Show the structure coloured by chain if the multimer model has been used.\n",
"if model_type_to_use == notebook_utils.ModelType.MULTIMER:\n",
" multichain_view = py3Dmol.view(width=800, height=600)\n",
" multichain_view.addModelsAsFrames(to_visualize_pdb)\n",
" multichain_style = {'cartoon': {'colorscheme': 'chain'}}\n",
" multichain_view.setStyle({'model': -1}, multichain_style)\n",
" multichain_view.zoomTo()\n",
" multichain_view.show()\n",
"\n",
"# Color the structure by per-residue pLDDT\n",
"color_map = {i: bands[2] for i, bands in enumerate(PLDDT_BANDS)}\n",
"view = py3Dmol.view(width=800, height=600)\n",
"view.addModelsAsFrames(to_visualize_pdb)\n",
"style = {'cartoon': {'colorscheme': {'prop': 'b', 'map': color_map}}}\n",
"if show_sidechains:\n",
" style['stick'] = {}\n",
"view.setStyle({'model': -1}, style)\n",
"view.zoomTo()\n",
"\n",
"grid = GridspecLayout(1, 2)\n",
"out = Output()\n",
"with out:\n",
" view.show()\n",
"grid[0, 0] = out\n",
"\n",
"out = Output()\n",
"with out:\n",
" plot_plddt_legend().show()\n",
"grid[0, 1] = out\n",
"\n",
"display.display(grid)\n",
"\n",
"if 'predicted_aligned_error' in raw_predictions:\n",
" num_plots = 2\n",
" pae = raw_predictions['predicted_aligned_error']\n",
" max_pae = raw_predictions['max_predicted_aligned_error']\n",
"else:\n",
" num_plots = 1\n",
"\n",
"plt.figure(figsize=[8 * num_plots, 6])\n",
"plt.subplot(1, num_plots, 1)\n",
"plt.plot(raw_predictions['plddt'])\n",
"plt.title('Predicted LDDT')\n",
"plt.xlabel('Residue')\n",
"plt.ylabel('pLDDT')\n",
"\n",
"if num_plots == 2:\n",
" plt.subplot(1, 2, 2)\n",
" plt.imshow(pae, vmin=0., vmax=max_pae, cmap='Greens_r')\n",
" plt.colorbar(fraction=0.046, pad=0.04)\n",
" plt.title('Predicted Aligned Error')\n",
" plt.xlabel('Scored residue')\n",
" plt.ylabel('Aligned residue')"
]
}
],
"metadata": {
"environment": {
"kernel": "python3",
"name": "common-cpu.m102",
"type": "gcloud",
"uri": "gcr.io/deeplearning-platform-release/base-cpu:m102"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.15 | packaged by conda-forge | (default, Nov 22 2022, 08:49:35) \n[GCC 10.4.0]"
},
"vscode": {
"interpreter": {
"hash": "af0757ec56c03d1c123f7ed927d527de59726097999d70188569e1fb3d3de9b2"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}