# Analyzing inference pipeline runs


As an AlphaFold inference pipeline executes its workflow steps, Vertex ML Metadata tracks each step's outcome, including the artifacts the step generates and their metadata. For example, a model prediction step tracks the locations of the raw prediction and protein structure files, along with properties such as ranking confidence. The tracked information ensures reproducibility and supports detailed run analysis, including lineage tracing.


In this notebook, you will retrieve pipeline run metadata from Vertex ML Metadata, explore its properties, and visualize the generated artifacts.

### Install and import required packages

In [None]:
%cd /home/jupyter/vertex-ai-alphafold-inference-pipeline/src
%pip install .
%cd /home/jupyter/vertex-ai-alphafold-inference-pipeline


In [None]:
! pip install py3dmol
! pip install dm-tree
! pip install matplotlib
! pip install biopython
! pip install ipywidgets
! pip install jax
! pip install jaxlib
! pip install dm-haiku

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import pickle
import numpy as np
import matplotlib.pyplot as plt
import py3Dmol
from IPython import display
from ipywidgets import GridspecLayout
from ipywidgets import Output
from google.cloud import aiplatform_v1 as vertex_ai
from google.cloud import storage
from pathlib import Path

from src.analysis import notebook_utils
from src.analysis import parsers
from src.analysis import utils

## Pipeline metadata exploration

In this section of the notebook, you will use the Vertex AI SDK to retrieve information about pipeline runs, including the artifacts and metadata tracked by a run.

### Define API clients to interact with Vertex AI Pipelines metadata

Change the variables in the next cell to match your environment.

In [None]:
PROJECT_ID = 'YOUR PROJECT ID'    # Replace with your project ID
REGION = 'YOUR REGION'  # Replace with your region
API_ENDPOINT = "{}-aiplatform.googleapis.com".format(REGION)

# Create Pipelines and Metadata service clients
client_pipeline = vertex_ai.PipelineServiceClient(client_options={"api_endpoint": API_ENDPOINT})
client_metadata = vertex_ai.MetadataServiceClient(client_options={"api_endpoint": API_ENDPOINT})

### Listing all running pipelines

You can list all pipelines in a given state, for example all pipelines that are still running.  
The following cell creates a request to list all the running pipelines.

In [None]:
FILTER = 'state="PIPELINE_STATE_RUNNING"'

list_pipelines_request = vertex_ai.ListPipelineJobsRequest(
    parent=f'projects/{PROJECT_ID}/locations/{REGION}',
    filter=FILTER
)

list_pipelines = list(client_pipeline.list_pipeline_jobs(list_pipelines_request))

if list_pipelines:
    print(f'There are {len(list_pipelines)} pipeline(s) running.')
    for pipeline in list_pipelines:
        print(f'ID: {pipeline.name} - State: {pipeline.state.name}')
else:
    print('No pipelines running.')

### Listing all pipelines with a specific label

You can search for all pipelines annotated with a specific label. For example, you can list all pipelines that were grouped under a given experiment.

Set the `experiment_id` variable to the label you used to annotate a pipeline run of the monomer optimized pipeline.

In [None]:
experiment_id = 'YOUR LABEL'  # Replace with your label (lower case)
FILTER = f'labels.experiment_id="{experiment_id}"'

list_pipelines_request = vertex_ai.ListPipelineJobsRequest(
    parent=f'projects/{PROJECT_ID}/locations/{REGION}',
    filter=FILTER
)

list_pipelines = list(client_pipeline.list_pipeline_jobs(list_pipelines_request))

print(f'Number of Pipeline(s) found: {len(list_pipelines)}')

if list_pipelines:
    print(f'Printing pipeline(s) with label `experiment_id` equals to `{experiment_id}`.')
    for pipeline in list_pipelines:
        print(f'Id: {pipeline.name} - State: {pipeline.state.name}')
else:
    print('No pipelines found.')
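
The state filter and the label filter used above can also be combined into one expression. The following is a minimal sketch, assuming the list filter accepts clauses joined with `AND` (check the `ListPipelineJobsRequest` documentation for your SDK version). `build_pipeline_filter` is a hypothetical helper, not part of the Vertex AI SDK, and the label value is a placeholder.

```python
def build_pipeline_filter(state=None, labels=None):
    """Builds a filter string for a pipeline list request (hypothetical helper)."""
    clauses = []
    if state:
        clauses.append(f'state="{state}"')
    # Sort the labels so the resulting string is deterministic.
    for key, value in sorted((labels or {}).items()):
        clauses.append(f'labels.{key}="{value}"')
    # Assumption: the service accepts clauses combined with AND.
    return ' AND '.join(clauses)


combined_filter = build_pipeline_filter(
    state='PIPELINE_STATE_SUCCEEDED',
    labels={'experiment_id': 'my-experiment'},  # placeholder label value
)
print(combined_filter)
```

The resulting string can be passed as the `filter` argument of `vertex_ai.ListPipelineJobsRequest`, in the same way as `FILTER` in the cells above.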

### Retrieving artifact metadata

You can retrieve metadata for artifacts generated by a pipeline. For example, you can retrieve ranking confidence from a raw prediction artifact generated by a model predict step.

You will construct a data structure containing all information generated by a pipeline.

You need to use a full ID of the pipeline in the format `projects/<PROJECT NUMBER>/locations/<REGION>/pipelineJobs/<PIPELINE NAME>`. Set the `pipeline_id` variable to the ID of a successful run returned by the previous cell.

**IMPORTANT**: Please change the variables `pipeline_id` and `is_monomer` according to the definitions of the pipeline you chose.
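
Before calling the API, it can help to check that the ID has the expected shape. The following is a minimal sketch that validates the format described above with a regular expression. `parse_pipeline_id` is a hypothetical helper, and the sample ID is made up for illustration.

```python
import re

# Matches projects/<PROJECT NUMBER>/locations/<REGION>/pipelineJobs/<PIPELINE NAME>
PIPELINE_ID_PATTERN = re.compile(
    r'^projects/(?P<project>[^/]+)'
    r'/locations/(?P<region>[^/]+)'
    r'/pipelineJobs/(?P<job>[^/]+)$'
)


def parse_pipeline_id(pipeline_id):
    """Splits a full pipeline ID into its parts, or raises ValueError."""
    match = PIPELINE_ID_PATTERN.match(pipeline_id)
    if match is None:
        raise ValueError(f'Not a full pipeline ID: {pipeline_id!r}')
    return match.groupdict()


# Made-up ID used only to illustrate the format.
parts = parse_pipeline_id(
    'projects/123456789/locations/us-central1/pipelineJobs/example-run'
)
print(parts)
```

A short name such as `example-run` alone raises `ValueError`, which surfaces a malformed ID before any request is sent.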

In [None]:
# Change to your pipeline ID
pipeline_id = 'YOUR PIPELINE ID'

# Change to False if multimer
is_monomer = True

In [None]:
if is_monomer:
  model_type_to_use = notebook_utils.ModelType.MONOMER
else:
  model_type_to_use = notebook_utils.ModelType.MULTIMER

get_request = vertex_ai.GetPipelineJobRequest(
    name=pipeline_id
)
pipeline_job = client_pipeline.get_pipeline_job(get_request)

Once you have the data structure, you can browse it and retrieve the properties you need. In the following example, you will retrieve the `ranking_confidence` values tracked by the model prediction steps.

In [None]:
# Retrieve all tasks with name = predict
predict_tasks = [i for i in pipeline_job.job_detail.task_details if i.task_name == 'predict']
formated_predict_tasks = []

for t in predict_tasks:
    task_id = t.task_id
    parent_task_id = t.parent_task_id
    model_name = t.execution.metadata['input:model_name']
    ranking_confidence = t.outputs['raw_prediction'].artifacts[0].metadata['ranking_confidence']
    uri = t.outputs['raw_prediction'].artifacts[0].uri

    formated_predict_tasks.append(
        {
            'task_id': task_id,
            'parent_task_id': parent_task_id,
            'model_name': model_name,
            'ranking_confidence': ranking_confidence,
            'uri': uri
        }
    )

# Sort predict tasks by ranking confidence
sorted_predict = sorted(formated_predict_tasks, key=lambda x: x['ranking_confidence'], reverse=True)

In [None]:
print('Predictions ranking:')
for prediction in sorted_predict:
    print(prediction['model_name'], '=>', 'Ranking confidence:', prediction['ranking_confidence'])

Artifacts generated by AlphaFold pipelines have a metadata property called `category`, which helps with artifact discovery and metadata retrieval.  
Next, find the artifacts that capture information about the MSAs generated during the feature engineering phase. The cell below selects them by `display_name` (`msas`) within the context of the pipeline run.

In [None]:
pipeline_ctx = pipeline_job.job_detail.pipeline_run_context.name
parent = f'projects/{PROJECT_ID}/locations/{REGION}/metadataStores/default'

# Filter which artifacts to present
FILTER = f'in_context("{pipeline_ctx}") AND display_name="msas"'

list_artifacts_request = vertex_ai.ListArtifactsRequest(
    parent=parent,
    filter=FILTER
)

msas = list(client_metadata.list_artifacts(request=list_artifacts_request))

You can now locate the files behind an MSA artifact.  
The next cell lists the files stored under the artifact's URI and keeps the last one with `bfd` in its name.

In [None]:
bucket_name = msas[0].uri.replace('gs://', '').split(sep='/')[0]
blob_prefix = '/'.join(msas[0].uri.replace('gs://', '').split(sep='/')[1:])

storage_client = storage.Client()
blobs = storage_client.list_blobs(bucket_name, prefix=blob_prefix+'/')

for b in blobs:
    if 'bfd' in b.name:
        msa_uri = 'gs://' + bucket_name + '/' + b.name
        msa_filename = Path(msa_uri)
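
The cell above splits the artifact URI into a bucket name and a blob prefix with chained string operations. The same step can be sketched as a small reusable function. `split_gcs_uri` is a hypothetical helper, and the sample URI is made up for illustration.

```python
def split_gcs_uri(uri):
    """Splits 'gs://bucket/path/to/blob' into ('bucket', 'path/to/blob')."""
    scheme = 'gs://'
    if not uri.startswith(scheme):
        raise ValueError(f'Not a Cloud Storage URI: {uri!r}')
    # partition() splits on the first '/' only, so the blob path stays intact.
    bucket, _, blob = uri[len(scheme):].partition('/')
    return bucket, blob


# Made-up URI used only to illustrate the format.
bucket, prefix = split_gcs_uri('gs://example-bucket/pipeline_root/run-1/msas')
print(bucket)
print(prefix)
```

The returned values correspond to `bucket_name` and `blob_prefix` in the cell above.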

Since an artifact entry in Vertex ML Metadata contains a link to the location of the artifact file(s) in Cloud Storage, you can retrieve the artifact for further local analysis.

For example, you can download an MSA file and use AlphaFold notebook widgets to visualize it.

In [None]:
! gsutil cp {msa_uri} .

The next cell can take up to 30 seconds to execute.

In [None]:
with open(msa_filename.name, 'r') as fp:
    msa_file = fp.read()

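if msa_filename.suffix == '.a3m':
    msa_parsed = parsers.parse_a3m(msa_file)
elif msa_filename.suffix == '.sto':
    msa_parsed = parsers.parse_stockholm(msa_file)
elif msa_filename.suffix == '.hhr':
    msa_parsed = parsers.parse_hhr(msa_file)
else:
    # Fail early instead of hitting a NameError on msa_parsed below
    raise ValueError(f'Unsupported MSA file format: {msa_filename.suffix}')
notebook_utils.show_msa_info([msa_parsed], 0)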

## Visualize prediction

Next, you will visualize the predicted structure and its confidence.
You need to download two files for this step:
 - Prediction pickle file
 - Relaxed protein PDB file

If the multimer model was used, the notebook also shows the structure colored by chain.
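
The visualization cells that follow color each residue by the pLDDT band it falls into. The following is a minimal pure-Python sketch of that banding step, using the same band boundaries as the notebook. `plddt_band_index` is a hypothetical helper, and the sample pLDDT values are made up for illustration.

```python
from bisect import bisect_left

# Same (min, max, color) bands as the visualization cell in this notebook.
PLDDT_BANDS = [(0, 50, '#FF7D45'),
               (50, 70, '#FFDB13'),
               (70, 90, '#65CBF3'),
               (90, 100, '#0053D6')]


def plddt_band_index(plddt):
    """Returns the index of the first band that contains the pLDDT value."""
    # Upper bounds of all bands except the last one: [50, 70, 90].
    upper_bounds = [band[1] for band in PLDDT_BANDS[:-1]]
    # bisect_left places a boundary value (for example 50) in the lower band,
    # which matches taking the first band with min <= plddt <= max.
    return bisect_left(upper_bounds, plddt)


sample_plddt = [12.0, 50.0, 65.5, 70.0, 88.0, 97.3]  # made-up values
band_indices = [plddt_band_index(value) for value in sample_plddt]
print(band_indices)  # [0, 0, 1, 1, 2, 3]
```

Each index selects a color from `PLDDT_BANDS`, which is how the band number written into the B-factor column is later mapped to a color.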

In [None]:
# Find the `Relax` task to get its URI
for t in pipeline_job.job_detail.task_details:
    if t.parent_task_id == sorted_predict[0]['parent_task_id'] and 'condition' in t.task_name:
        condition_relax_parent_id = t.task_id
        break

for t in pipeline_job.job_detail.task_details:
    if t.parent_task_id == condition_relax_parent_id:
        top_relax_predict = t.outputs['relaxed_protein'].artifacts[0].uri
        break
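
The two loops above walk the task hierarchy: first to the `condition` task that shares a parent with the top-ranked `predict` task, then to that condition's child. The following sketch shows the same lookup on plain Python objects so that a missing task yields `None` instead of an undefined variable. The `Task` class, the `find_relax_task` helper, and the task names and IDs are all hypothetical stand-ins for the pipeline's task details.

```python
from dataclasses import dataclass


@dataclass
class Task:
    """Hypothetical stand-in for one entry of job_detail.task_details."""
    task_id: int
    parent_task_id: int
    task_name: str


def find_relax_task(tasks, predict_parent_id):
    """Finds the child of the condition task that is a sibling of predict."""
    condition = next(
        (t for t in tasks
         if t.parent_task_id == predict_parent_id and 'condition' in t.task_name),
        None,
    )
    if condition is None:
        return None
    # The relax step runs inside the condition, so it is the condition's child.
    return next((t for t in tasks if t.parent_task_id == condition.task_id), None)


# Made-up task tree: a loop iteration holding predict and a conditional relax.
tasks = [
    Task(task_id=1, parent_task_id=0, task_name='for-loop-1'),
    Task(task_id=2, parent_task_id=1, task_name='predict'),
    Task(task_id=3, parent_task_id=1, task_name='condition-relax'),
    Task(task_id=4, parent_task_id=3, task_name='relax'),
]
relax_task = find_relax_task(tasks, predict_parent_id=1)
print(relax_task.task_name)
```

Returning `None` lets the caller decide how to handle a run in which the relaxation step was skipped.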

In [None]:
! gsutil cp {sorted_predict[0]["uri"]} .
! gsutil cp {top_relax_predict} .

In [None]:
with open('raw_prediction.pkl', 'rb') as fp:
  raw_predictions = pickle.load(fp)

with open('relaxed_protein.pdb', 'r') as fp:
  relaxed_pdb = fp.read()

# Color bands for visualizing plddt
PLDDT_BANDS = [(0, 50, '#FF7D45'),
               (50, 70, '#FFDB13'),
               (70, 90, '#65CBF3'),
               (90, 100, '#0053D6')]

banded_b_factors = []
final_atom_mask = raw_predictions['structure_module']['final_atom_mask']

for plddt in raw_predictions['plddt']:
  for idx, (min_val, max_val, _) in enumerate(PLDDT_BANDS):
    if plddt >= min_val and plddt <= max_val:
      banded_b_factors.append(idx)
      break
banded_b_factors = np.array(banded_b_factors)[:, None] * final_atom_mask
to_visualize_pdb = utils.overwrite_b_factors(relaxed_pdb, banded_b_factors)

show_sidechains = True
def plot_plddt_legend():
  """Plots the legend for pLDDT."""
  thresh = ['Very low (pLDDT < 50)',
            'Low (70 > pLDDT > 50)',
            'Confident (90 > pLDDT > 70)',
            'Very high (pLDDT > 90)']

  colors = [x[2] for x in PLDDT_BANDS]

  plt.figure(figsize=(2, 2))
  for c in colors:
    plt.bar(0, 0, color=c)
  plt.legend(thresh, frameon=False, loc='center', fontsize=20)
  plt.xticks([])
  plt.yticks([])
  ax = plt.gca()
  ax.spines['right'].set_visible(False)
  ax.spines['top'].set_visible(False)
  ax.spines['left'].set_visible(False)
  ax.spines['bottom'].set_visible(False)
  plt.title('Model Confidence', fontsize=20, pad=20)
  return plt

# Show the structure coloured by chain if the multimer model has been used.
if model_type_to_use == notebook_utils.ModelType.MULTIMER:
  multichain_view = py3Dmol.view(width=800, height=600)
  multichain_view.addModelsAsFrames(to_visualize_pdb)
  multichain_style = {'cartoon': {'colorscheme': 'chain'}}
  multichain_view.setStyle({'model': -1}, multichain_style)
  multichain_view.zoomTo()
  multichain_view.show()

# Color the structure by per-residue pLDDT
color_map = {i: bands[2] for i, bands in enumerate(PLDDT_BANDS)}
view = py3Dmol.view(width=800, height=600)
view.addModelsAsFrames(to_visualize_pdb)
style = {'cartoon': {'colorscheme': {'prop': 'b', 'map': color_map}}}
if show_sidechains:
  style['stick'] = {}
view.setStyle({'model': -1}, style)
view.zoomTo()

grid = GridspecLayout(1, 2)
out = Output()
with out:
  view.show()
grid[0, 0] = out

out = Output()
with out:
  plot_plddt_legend().show()
grid[0, 1] = out

display.display(grid)

if 'predicted_aligned_error' in raw_predictions:
  num_plots = 2
  pae = raw_predictions['predicted_aligned_error']
  max_pae = raw_predictions['max_predicted_aligned_error']
else:
  num_plots = 1

plt.figure(figsize=[8 * num_plots, 6])
plt.subplot(1, num_plots, 1)
plt.plot(raw_predictions['plddt'])
plt.title('Predicted LDDT')
plt.xlabel('Residue')
plt.ylabel('pLDDT')

if num_plots == 2:
  plt.subplot(1, 2, 2)
  plt.imshow(pae, vmin=0., vmax=max_pae, cmap='Greens_r')
  plt.colorbar(fraction=0.046, pad=0.04)
  plt.title('Predicted Aligned Error')
  plt.xlabel('Scored residue')
  plt.ylabel('Aligned residue')