# Low-latency item-to-item recommendation system - Orchestrating with TFX

## Overview

This notebook is part of a series that describes the process of implementing a [**Low-latency item-to-item recommendation system**](https://github.com/jarokaz/analytics-componentized-patterns/tree/master/retail/recommendation-system/bqml-scann).

This notebook demonstrates how to use TFX and AI Platform Pipelines (Unified) to operationalize the workflow that creates embeddings and builds and deploys an ANN Service index. 

In the notebook you go through the following steps.

1. Creating TFX custom components that encapsulate operations on BQ, BQML and ANN Service.
2. Creating a TFX pipeline that automates the processes of creating embeddings and deploying an ANN Index.
3. Testing the pipeline locally using Beam runner.
4. Compiling the pipeline to the TFX IR format for execution on AI Platform Pipelines (Unified).
5. Submitting pipeline runs.

This notebook was designed to run on [AI Platform Notebooks](https://cloud.google.com/ai-platform-notebooks). Before running the notebook make sure that you have completed the setup steps as described in the [README file](README.md).

### TFX Pipeline Design

The diagram below depicts the TFX pipeline that you will implement in this notebook. Each step of the pipeline is implemented as a [TFX Custom Python function component](https://www.tensorflow.org/tfx/guide/custom_function_component). The components track the relevant metadata in AI Platform (Unified) ML Metadata using both standard and custom metadata types.

![TFX pipeline](figures/ann-tfx.png)

1. The first step of the pipeline is to compute item co-occurrence. This is done by calling the `sp_ComputePMI` stored procedure created in the preceding notebooks.
2. Next, the BQML Matrix Factorization model is created. The model training code is encapsulated in the `sp_TrainItemMatchingModel` stored procedure.
3. Item embeddings are extracted from the trained model weights and stored in a BQ table. The component calls the `sp_ExtractEmbeddings` stored procedure that implements the extraction logic.
4. The embeddings are exported in the JSONL format to the GCS location using the BigQuery extract job.
5. The embeddings in the JSONL format are used to create an ANN index by calling the ANN Service Control Plane REST API.
6. Finally, the ANN index is deployed to an ANN endpoint.

All steps and their inputs and outputs are tracked in the AI Platform (Unified) ML Metadata service.


In [None]:
%load_ext autoreload
%autoreload 2

## Setting up the notebook's environment

### Install AI Platform Pipelines client library

For AI Platform Pipelines (Unified), which is in the Experimental stage, you need to download and install the AI Platform client library on top of the KFP and TFX SDKs that were installed as part of the initial environment setup.

In [None]:
AIP_CLIENT_WHEEL = 'aiplatform_pipelines_client-0.1.0.caip20201123-py3-none-any.whl'
AIP_CLIENT_WHEEL_GCS_LOCATION = f'gs://cloud-aiplatform-pipelines/releases/20201123/{AIP_CLIENT_WHEEL}'

In [None]:
!gsutil cp {AIP_CLIENT_WHEEL_GCS_LOCATION} {AIP_CLIENT_WHEEL}

In [None]:
%pip install {AIP_CLIENT_WHEEL}

Restart the kernel.

In [None]:
import IPython
app = IPython.Application.instance()
app.kernel.do_shutdown(True)

### Import notebook dependencies

In [None]:
import logging
import tfx
import tensorflow as tf

from aiplatform.pipelines import client
from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner

print('TFX Version: ', tfx.__version__)

### Configure GCP environment

-----------------

**If you're on AI Platform Notebooks**, authenticate with Google Cloud before running the next section, by running
```sh
gcloud auth login
```
**in the Terminal window** (which you can open via **File** > **New** in the menu). You only need to do this once per notebook instance.

Set the following constants to the values reflecting your environment:

* `PROJECT_ID` - your GCP project ID
* `PROJECT_NUMBER` - your GCP project number
* `BUCKET_NAME` - a name of the GCS bucket that will be used to host artifacts created by the pipeline
* `USER` - a suffix appended to the standard pipeline name (`ann-pipeline-`). Change it to differentiate between pipelines from different users in a classroom environment
* `API_KEY` - a GCP API key
* `VPC_NAME` - a name of the GCP VPC to use for the index deployments.  
* `REGION` - a compute region. Don't change the default - `us-central1` - while the ANN Service is in the experimental stage
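
Because several of these constants start out as empty strings, it can help to check them before building the pipeline. The following is a minimal sketch of such a check; the `missing_settings` helper and the sample values are hypothetical and are not part of the original notebook.

```python
def missing_settings(settings):
    """Return the names of settings whose value is empty or only whitespace."""
    return [name for name, value in settings.items() if not str(value).strip()]


# Sample values for illustration only; replace them with your own environment.
settings = {
    'PROJECT_ID': 'my-project',
    'PROJECT_NUMBER': '',   # left unset on purpose
    'API_KEY': '',          # left unset on purpose
    'BUCKET_NAME': 'my-bucket',
    'VPC_NAME': 'default',
    'REGION': 'us-central1',
}

unset = missing_settings(settings)
if unset:
    print('Set these constants before continuing:', ', '.join(unset))
```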


In [None]:
PROJECT_ID = ''  # <---CHANGE THIS
PROJECT_NUMBER = ''  # <---CHANGE THIS
API_KEY = ''  # <---CHANGE THIS
USER = 'user' # <---CHANGE THIS
BUCKET_NAME = 'jk-ann-staging'  # <---CHANGE THIS
VPC_NAME = 'default' # <---CHANGE THIS IF USING A DIFFERENT VPC

REGION = 'us-central1'
PIPELINE_NAME = "ann-pipeline-{}".format(USER)
PIPELINE_ROOT = 'gs://{}/pipeline_root/{}'.format(BUCKET_NAME, PIPELINE_NAME)
PATH=%env PATH
%env PATH={PATH}:/home/jupyter/.local/bin
    
print('PIPELINE_ROOT: {}'.format(PIPELINE_ROOT))

## Defining custom components

In this section of the notebook you define a set of custom TFX components that encapsulate BQ, BQML and ANN Service calls. The components are [TFX Custom Python function components](https://www.tensorflow.org/tfx/guide/custom_function_component). 

Each component is created as a separate Python module. You also create a couple of helper modules that encapsulate Python functions and classes used across the custom components.


### Remove files created in the previous executions of the notebook

In [None]:
component_folder = 'bq_components'

if tf.io.gfile.exists(component_folder):
    print('Removing older file')
    tf.io.gfile.rmtree(component_folder)
print('Creating component folder')
tf.io.gfile.mkdir(component_folder)

In [None]:
%cd {component_folder}

### Define custom types for ANN service artifacts

This module defines a couple of custom TFX artifacts to track ANN Service indexes and index deployments.
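
The components in this notebook hand resource names to each other through string custom properties on these artifacts. The sketch below illustrates that handoff with a hypothetical `StandInArtifact` class. It only mimics the two accessor methods used in this notebook and is not the TFX `Artifact` class; the resource name is made up.

```python
class StandInArtifact:
    """Hypothetical stand-in that mimics only the two accessors used here."""

    def __init__(self, type_name):
        self.type_name = type_name
        self._custom_properties = {}

    def set_string_custom_property(self, key, value):
        # Only string values are accepted, matching the method's name.
        if not isinstance(value, str):
            raise TypeError(f'{key} must be a string, got {type(value).__name__}')
        self._custom_properties[key] = value

    def get_string_custom_property(self, key):
        return self._custom_properties[key]


# Producer side: record the (made-up) resource name of a created index.
ann_index = StandInArtifact('ANNIndex')
ann_index.set_string_custom_property(
    'index_name', 'projects/123/locations/us-central1/indexes/456')

# Consumer side: a downstream step reads the name back from the artifact.
index_name = ann_index.get_string_custom_property('index_name')
print(index_name)
```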

In [None]:
%%writefile ann_types.py
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom types for managing ANN artifacts."""

from tfx.types import artifact

class ANNIndex(artifact.Artifact):
    TYPE_NAME = 'ANNIndex'
    
class DeployedANNIndex(artifact.Artifact):
    TYPE_NAME = 'DeployedANNIndex'


### Create a wrapper around ANN Service REST API

This module provides a convenience wrapper around the ANN Service REST API. In the experimental stage, the ANN Service does not have an "official" Python client SDK, nor is it supported by the Google Discovery API.
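
Index creation and deployment are long-running operations, so the wrapper polls an operation resource until it reports `done`. The sketch below shows that polling pattern in isolation. `FakeOperations` and `wait_for_operation` are hypothetical names, the fake returns canned responses instead of calling the service, and the `max_polls` limit is an addition that the notebook's wrapper does not have.

```python
import time


class FakeOperations:
    """Hypothetical stand-in for the REST API: reports 'done' on the third poll."""

    def __init__(self, polls_until_done=3):
        self.calls = 0
        self.polls_until_done = polls_until_done

    def get(self, operation_id):
        self.calls += 1
        if self.calls < self.polls_until_done:
            return {'name': operation_id}
        return {'name': operation_id, 'done': True,
                'response': {'name': 'indexes/123'}}


def wait_for_operation(get_operation, operation_id, sleep_time=0.0, max_polls=10):
    """Polls until the operation is done, then returns its response payload."""
    for _ in range(max_polls):
        operation = get_operation(operation_id)
        if operation.get('done'):
            # A finished operation carries either 'response' or 'error'.
            if 'error' in operation:
                raise RuntimeError(operation['error'])
            return operation['response']
        time.sleep(sleep_time)
    raise TimeoutError(f'Operation {operation_id} not done after {max_polls} polls')


fake = FakeOperations()
result = wait_for_operation(fake.get, 'op-1')
print(result, fake.calls)
```

Bounding the number of polls is a design choice for the sketch: it keeps a stuck operation from blocking a pipeline step indefinitely.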

In [None]:
%%writefile ann_service.py
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper classes encapsulating ANN Service REST API."""

import datetime
import logging
import json
import time

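import google.auth
# AuthorizedSession lives in a submodule that must be imported explicitly.
import google.auth.transport.requests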

class ANNClient(object):
    """Base ANN Service client."""
    
    def __init__(self, project_id, project_number, region):
        credentials, _ = google.auth.default()
        self.authed_session = google.auth.transport.requests.AuthorizedSession(credentials)
        self.ann_endpoint = f'{region}-aiplatform.googleapis.com'
        self.ann_parent = f'https://{self.ann_endpoint}/v1alpha1/projects/{project_id}/locations/{region}'
        self.project_id = project_id
        self.project_number = project_number
        self.region = region
        
    def wait_for_completion(self, operation_id, message, sleep_time):
        """Waits for a completion of a long running operation."""
        
        api_url = f'{self.ann_parent}/operations/{operation_id}'

        start_time = datetime.datetime.utcnow()
        while True:
            response = self.authed_session.get(api_url)
            if response.status_code != 200:
                raise RuntimeError(response.json())
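            if 'done' in response.json().keys():
                # A finished operation carries either 'response' or 'error'.
                if 'error' in response.json():
                    raise RuntimeError(response.json()['error'])
                logging.info('Operation completed!')
                break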
            elapsed_time = datetime.datetime.utcnow() - start_time
            logging.info('{}. Elapsed time since start: {}.'.format(
                message, str(elapsed_time)))
            time.sleep(sleep_time)
    
        return response.json()['response']


class IndexClient(ANNClient):
    """Encapsulates a subset of control plane APIs 
    that manage ANN indexes."""

    def __init__(self, project_id, project_number, region):
        super().__init__(project_id, project_number, region)

    def create_index(self, display_name, description, metadata):
        """Creates an ANN Index."""
    
        api_url = f'{self.ann_parent}/indexes'
    
        request_body = {
            'display_name': display_name,
            'description': description,
            'metadata': metadata
        }
    
        response = self.authed_session.post(api_url, data=json.dumps(request_body))
        if response.status_code != 200:
            raise RuntimeError(response.text)
        operation_id = response.json()['name'].split('/')[-1]
        
        return operation_id

    def list_indexes(self, display_name=None):
        """Lists all indexes with a given display name or
        all indexes if the display_name is not provided."""
    
        if display_name:
            api_url = f'{self.ann_parent}/indexes?filter=display_name="{display_name}"'
        else:
            api_url = f'{self.ann_parent}/indexes'

        response = self.authed_session.get(api_url).json()

        return response['indexes'] if response else []
    
    def delete_index(self, index_id):
        """Deletes an ANN index."""
        
        api_url = f'{self.ann_parent}/indexes/{index_id}'
        response = self.authed_session.delete(api_url)
        if response.status_code != 200:
            raise RuntimeError(response.text)


class IndexDeploymentClient(ANNClient):
    """Encapsulates a subset of control plane APIs 
    that manage ANN endpoints and deployments."""
    
    def __init__(self, project_id, project_number, region):
        super().__init__(project_id, project_number, region)

    def create_endpoint(self, display_name, vpc_name):
        """Creates an ANN endpoint."""
    
        api_url = f'{self.ann_parent}/indexEndpoints'
        network_name = f'projects/{self.project_number}/global/networks/{vpc_name}'

        request_body = {
            'display_name': display_name,
            'network': network_name
        }

        response = self.authed_session.post(api_url, data=json.dumps(request_body))
        if response.status_code != 200:
            raise RuntimeError(response.text)
        operation_id = response.json()['name'].split('/')[-1]
    
        return operation_id
    
    def list_endpoints(self, display_name=None):
        """Lists all ANN endpoints with a given display name or
        all endpoints in the project if the display_name is not provided."""
        
        if display_name:
            api_url = f'{self.ann_parent}/indexEndpoints?filter=display_name="{display_name}"'
        else:
            api_url = f'{self.ann_parent}/indexEndpoints'

        response = self.authed_session.get(api_url).json()
 
        return response['indexEndpoints'] if response else []
    
    def delete_endpoint(self, endpoint_id):
        """Deletes an ANN endpoint."""
        
        api_url = f'{self.ann_parent}/indexEndpoints/{endpoint_id}'
        
        response = self.authed_session.delete(api_url)
        if response.status_code != 200:
            raise RuntimeError(response.text)
        
        return response.json()
    
    def create_deployment(self, display_name, deployment_id, endpoint_id, index_id):
        """Deploys an ANN index to an endpoint."""
    
        api_url = f'{self.ann_parent}/indexEndpoints/{endpoint_id}:deployIndex'
        index_name = f'projects/{self.project_number}/locations/{self.region}/indexes/{index_id}'

        request_body = {
            'deployed_index': {
                'id': deployment_id,
                'index': index_name,
                'display_name': display_name
            }
        }

        response = self.authed_session.post(api_url, data=json.dumps(request_body))
        if response.status_code != 200:
            raise RuntimeError(response.text)
        operation_id = response.json()['name'].split('/')[-1]
        
        return operation_id
    
    def get_deployment_grpc_ip(self, endpoint_id, deployment_id):
        """Returns a private IP address for a gRPC interface to 
        an Index deployment."""
  
        api_url = f'{self.ann_parent}/indexEndpoints/{endpoint_id}'

        response = self.authed_session.get(api_url)
        if response.status_code != 200:
            raise RuntimeError(response.text)
            
        endpoint_ip = None
        if 'deployedIndexes' in response.json().keys():
            for deployment in response.json()['deployedIndexes']:
                if deployment['id'] == deployment_id:
                    endpoint_ip = deployment['privateEndpoints']['matchGrpcAddress']
                    
        return endpoint_ip

    
    def delete_deployment(self, endpoint_id, deployment_id):
        """Undeploys an index from an endpoint."""
        
        api_url = f'{self.ann_parent}/indexEndpoints/{endpoint_id}:undeployIndex'
        
        request_body = {
            'deployed_index_id': deployment_id
        }
    
        response = self.authed_session.post(api_url, data=json.dumps(request_body))
        if response.status_code != 200:
            raise RuntimeError(response.text)
        
        return response
    

### Create Compute PMI component

This component encapsulates a call to the BigQuery stored procedure that calculates item co-occurrence. Refer to the preceding notebooks for more details about item co-occurrence calculations.

The component tracks the output `item_cooc` table created by the stored procedure using the TFX (simple) Dataset artifact.
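
For readers unfamiliar with PMI (pointwise mutual information), the sketch below computes it from toy co-occurrence counts using the textbook definition, PMI(a, b) = log2(p(a, b) / (p(a) p(b))). The groups are made up, and the `sp_ComputePMI` stored procedure may filter or normalize differently (for example through `min_item_frequency` and `max_group_size`), so treat this as an illustration of the measure rather than of the procedure.

```python
import math
from collections import Counter
from itertools import combinations

# Toy data: each inner list is one group of items that occurred together.
groups = [
    ['a', 'b', 'c'],
    ['a', 'b'],
    ['a', 'c'],
    ['b', 'd'],
]

n = len(groups)
# Number of groups containing each item, and each unordered item pair.
item_counts = Counter(item for group in groups for item in set(group))
pair_counts = Counter(
    pair for group in groups for pair in combinations(sorted(set(group)), 2))


def pmi(item1, item2):
    """Textbook PMI in bits; -inf when the pair never co-occurs."""
    pair = tuple(sorted((item1, item2)))
    if pair_counts[pair] == 0:
        return float('-inf')
    p_pair = pair_counts[pair] / n
    p1 = item_counts[item1] / n
    p2 = item_counts[item2] / n
    return math.log2(p_pair / (p1 * p2))


print(pmi('b', 'd'), pmi('a', 'b'), pmi('a', 'd'))
```

A positive value means the pair co-occurs more often than independence would predict; a negative value means less often.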

In [None]:
%%writefile compute_pmi.py
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BigQuery compute PMI component."""

import logging

from google.cloud import bigquery

import tfx
import tensorflow as tf

from tfx.dsl.component.experimental.decorators import component
from tfx.dsl.component.experimental.annotations import InputArtifact, OutputArtifact, Parameter

from tfx.types.experimental.simple_artifacts import Dataset as BQDataset


@component
def compute_pmi(
    project_id: Parameter[str],
    bq_dataset: Parameter[str],
    min_item_frequency: Parameter[int],
    max_group_size: Parameter[int],
    item_cooc: OutputArtifact[BQDataset]):
    
    stored_proc = f'{bq_dataset}.sp_ComputePMI'
    query = f'''
        DECLARE min_item_frequency INT64;
        DECLARE max_group_size INT64;

        SET min_item_frequency = {min_item_frequency};
        SET max_group_size = {max_group_size};

        CALL {stored_proc}(min_item_frequency, max_group_size);
    '''
    result_table = 'item_cooc'

    logging.info(f'Starting computing PMI...')
  
    client = bigquery.Client(project=project_id)
    query_job = client.query(query)
    query_job.result() # Wait for the job to complete
  
    logging.info(f'Items PMI computation completed. Output in {bq_dataset}.{result_table}.')
  
    # Write the location of the output table to metadata.  
    item_cooc.set_string_custom_property('table_name',
                                         f'{project_id}:{bq_dataset}.{result_table}')


### Create Train Item Matching Model component

This component encapsulates a call to the BigQuery stored procedure that trains the BQML Matrix Factorization model. Refer to the preceding notebooks for more details about model training.

The component tracks the output `item_matching_model` BQML model created by the stored procedure using the TFX (simple) Model artifact.

In [None]:
%%writefile train_item_matching.py
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BigQuery train item matching model component."""

import logging

from google.cloud import bigquery

import tfx
import tensorflow as tf

from tfx.dsl.component.experimental.decorators import component
from tfx.dsl.component.experimental.annotations import InputArtifact, OutputArtifact, Parameter

from tfx.types.experimental.simple_artifacts import Dataset as BQDataset
from tfx.types.standard_artifacts import Model as BQModel


@component
def train_item_matching_model(
    project_id: Parameter[str],
    bq_dataset: Parameter[str],
    dimensions: Parameter[int],
    item_cooc: InputArtifact[BQDataset],
    bq_model: OutputArtifact[BQModel]):
    
    item_cooc_table = item_cooc.get_string_custom_property('table_name')
    stored_proc = f'{bq_dataset}.sp_TrainItemMatchingModel'
    query = f'''
        DECLARE dimensions INT64 DEFAULT {dimensions};
        CALL {stored_proc}(dimensions);
    '''
    model_name = 'item_matching_model'
  
    logging.info(f'Using item co-occurrence table: {item_cooc_table}')
    logging.info(f'Starting training of the model...')
    
    client = bigquery.Client(project=project_id)
    query_job = client.query(query)
    query_job.result()
  
    logging.info(f'Model training completed. Output in {bq_dataset}.{model_name}.')
  
    # Write the location of the model to metadata. 
    bq_model.set_string_custom_property('model_name',
                                         f'{project_id}:{bq_dataset}.{model_name}')
   
  

### Create Extract Embeddings component

This component encapsulates a call to the BigQuery stored procedure that extracts embeddings from the model to the staging table. Refer to the preceding notebooks for more details about embeddings extraction.

The component tracks the output `item_embeddings` table created by the stored procedure using the TFX (simple) Dataset artifact.

In [None]:
%%writefile extract_embeddings.py
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Extracts embeddings to a BQ table."""

import logging

from google.cloud import bigquery

import tfx
import tensorflow as tf

from tfx.dsl.component.experimental.decorators import component
from tfx.dsl.component.experimental.annotations import InputArtifact, OutputArtifact, Parameter

from tfx.types.experimental.simple_artifacts import Dataset as BQDataset 
from tfx.types.standard_artifacts import Model as BQModel


@component
def extract_embeddings(
    project_id: Parameter[str],
    bq_dataset: Parameter[str],
    bq_model: InputArtifact[BQModel],
    item_embeddings: OutputArtifact[BQDataset]):
  
    embedding_model_name = bq_model.get_string_custom_property('model_name')
    stored_proc = f'{bq_dataset}.sp_ExractEmbeddings'
    query = f'''
        CALL {stored_proc}();
    '''
    embeddings_table = 'item_embeddings'

    logging.info(f'Extracting item embeddings from: {embedding_model_name}')
    
    client = bigquery.Client(project=project_id)
    query_job = client.query(query)
    query_job.result() # Wait for the job to complete
  
    logging.info(f'Embeddings extraction completed. Output in {bq_dataset}.{embeddings_table}')
  
    # Write the location of the output table to metadata.
    item_embeddings.set_string_custom_property('table_name', 
                                                f'{project_id}:{bq_dataset}.{embeddings_table}')
    

 

### Create Export Embeddings component

This component encapsulates a BigQuery table extraction job that extracts the `item_embeddings` table to a GCS location as files in the JSONL format. The format of the extracted files is compatible with the ingestion schema for the ANN Service.

The component tracks the output files location in the TFX (simple) Dataset artifact.
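
In the JSONL (newline-delimited JSON) format, every line of a file is one self-contained JSON object. The sketch below writes and re-reads two records using only the standard library. The field names `id` and `embedding` and the values are assumptions made for illustration; the actual columns come from the `item_embeddings` table.

```python
import io
import json

# Made-up records; the field names are assumptions, not taken from the table.
rows = [
    {'id': 'item-1', 'embedding': [0.1, 0.2, 0.3]},
    {'id': 'item-2', 'embedding': [0.4, 0.5, 0.6]},
]

# Write one JSON object per line into an in-memory file.
buffer = io.StringIO()
for row in rows:
    buffer.write(json.dumps(row) + '\n')

# Read it back line by line; each line parses independently of the others.
parsed = [json.loads(line) for line in buffer.getvalue().splitlines() if line]
print(len(parsed), parsed[0]['id'])
```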

In [None]:
%%writefile export_embeddings.py
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Exports embeddings from a BQ table to a GCS location."""

import logging

from google.cloud import bigquery

import tfx
import tensorflow as tf

from tfx.dsl.component.experimental.decorators import component
from tfx.dsl.component.experimental.annotations import InputArtifact, OutputArtifact, Parameter

from tfx.types.experimental.simple_artifacts import Dataset 

BQDataset = Dataset

@component
def export_embeddings(
    project_id: Parameter[str],
    gcs_location: Parameter[str],
    item_embeddings_bq: InputArtifact[BQDataset],
    item_embeddings_gcs: OutputArtifact[Dataset]):
    
    filename_pattern = 'embedding-*.json'
    gcs_location = gcs_location.rstrip('/')
    destination_uri = f'{gcs_location}/{filename_pattern}'
    
    _, table_name = item_embeddings_bq.get_string_custom_property('table_name').split(':')
  
    logging.info(f'Exporting item embeddings from: {table_name}')
  
    bq_dataset, table_id = table_name.split('.')
    client = bigquery.Client(project=project_id)
    dataset_ref = bigquery.DatasetReference(project_id, bq_dataset)
    table_ref = dataset_ref.table(table_id)
    job_config = bigquery.job.ExtractJobConfig()
    job_config.destination_format = bigquery.DestinationFormat.NEWLINE_DELIMITED_JSON

    extract_job = client.extract_table(
        table_ref,
        destination_uris=destination_uri,
        job_config=job_config
    )  
    extract_job.result() # Wait for results
    
    logging.info(f'Embeddings export completed. Output in {gcs_location}')
  
    # Write the location of the embeddings to metadata.
    item_embeddings_gcs.uri = gcs_location

 

### Create ANN index component

This component encapsulates the calls to the ANN Service to create an ANN Index. 

The component tracks the created index in the TFX custom `ANNIndex` artifact.
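
The index configuration in this module pairs the `DOT_PRODUCT_DISTANCE` metric with `UNIT_L2_NORM` feature normalization. The sketch below shows why that pairing is common: once vectors are scaled to unit length, their dot product equals their cosine similarity. How the ANN Service applies the normalization internally is not documented here, so this only illustrates the underlying math with made-up vectors.

```python
import math


def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def l2_normalize(v):
    """Scales a vector to unit Euclidean length."""
    norm = math.sqrt(dot(v, v))
    if norm == 0:
        raise ValueError('Cannot normalize a zero vector')
    return [x / norm for x in v]


def cosine_similarity(u, v):
    return dot(u, v) / (math.sqrt(dot(u, u)) * math.sqrt(dot(v, v)))


u = [3.0, 4.0]
v = [4.0, 3.0]

# After unit normalization, the plain dot product matches cosine similarity.
normalized_dot = dot(l2_normalize(u), l2_normalize(v))
print(normalized_dot, cosine_similarity(u, v))
```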

In [None]:
%%writefile create_index.py
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Creates an ANN index."""

import logging

import google.auth
import numpy as np
import tfx
import tensorflow as tf

from google.cloud import bigquery
from tfx.dsl.component.experimental.decorators import component
from tfx.dsl.component.experimental.annotations import InputArtifact, OutputArtifact, Parameter
from tfx.types.experimental.simple_artifacts import Dataset 

from ann_service import IndexClient
from ann_types import ANNIndex

NUM_NEIGHBOURS = 10
MAX_LEAVES_TO_SEARCH = 200
METRIC = 'DOT_PRODUCT_DISTANCE'
FEATURE_NORM_TYPE = 'UNIT_L2_NORM'
CHILD_NODE_COUNT = 1000
APPROXIMATE_NEIGHBORS_COUNT = 50

@component
def create_index(
    project_id: Parameter[str],
    project_number: Parameter[str],
    region: Parameter[str],
    display_name: Parameter[str],
    dimensions: Parameter[int],
    item_embeddings: InputArtifact[Dataset],
    ann_index: OutputArtifact[ANNIndex]):
    
    index_client = IndexClient(project_id, project_number, region)
    
    logging.info('Creating index:')
    logging.info(f'    Index display name: {display_name}')
    logging.info(f'    Embeddings location: {item_embeddings.uri}')
    
    index_description = display_name
    index_metadata = {
        'contents_delta_uri': item_embeddings.uri,
        'config': {
            'dimensions': dimensions,
            'approximate_neighbors_count': APPROXIMATE_NEIGHBORS_COUNT,
            'distance_measure_type': METRIC,
            'feature_norm_type': FEATURE_NORM_TYPE,
            'tree_ah_config': {
                'child_node_count': CHILD_NODE_COUNT,
                'max_leaves_to_search': MAX_LEAVES_TO_SEARCH
            }
        }
    }
    
    operation_id = index_client.create_index(display_name, 
                                             index_description,
                                             index_metadata)
    response = index_client.wait_for_completion(operation_id, 'Waiting for ANN index', 45)
    index_name = response['name']
    
    logging.info('Index {} created.'.format(index_name))
  
    # Write the index name to metadata.
    ann_index.set_string_custom_property('index_name', 
                                         index_name)
    ann_index.set_string_custom_property('index_display_name', 
                                         display_name)


### Deploy ANN index component

This component deploys an ANN index to an ANN Endpoint. 
The component tracks the deployed index in the TFX custom `DeployedANNIndex` artifact.
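
The deployment step works with short IDs derived from full resource names and builds a unique deployed index ID from a prefix and a UUID. The following minimal sketch shows both operations; the `resource_id` helper, the resource names, and the prefix are hypothetical values chosen for illustration.

```python
import uuid


def resource_id(resource_name):
    """Returns the last path segment of a slash-separated resource name."""
    return resource_name.rstrip('/').split('/')[-1]


# Made-up resource names in the projects/.../locations/... style.
endpoint_name = 'projects/123/locations/us-central1/indexEndpoints/456'
index_name = 'projects/123/locations/us-central1/indexes/789'

endpoint_id = resource_id(endpoint_name)
index_id = resource_id(index_name)

# A random UUID suffix keeps repeated pipeline runs from reusing the same ID.
deployed_index_id = 'deployed_' + str(uuid.uuid4())
print(endpoint_id, index_id, deployed_index_id)
```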

In [None]:
%%writefile deploy_index.py
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Deploys an ANN index."""

import logging

import numpy as np
import uuid
import tfx
import tensorflow as tf

from google.cloud import bigquery
from tfx.dsl.component.experimental.decorators import component
from tfx.dsl.component.experimental.annotations import InputArtifact, OutputArtifact, Parameter
from tfx.types.experimental.simple_artifacts import Dataset 

from ann_service import IndexDeploymentClient
from ann_types import ANNIndex
from ann_types import DeployedANNIndex


@component
def deploy_index(
    project_id: Parameter[str],
    project_number: Parameter[str],
    region: Parameter[str],
    vpc_name: Parameter[str],
    deployed_index_id_prefix: Parameter[str],
    ann_index: InputArtifact[ANNIndex],
    deployed_ann_index: OutputArtifact[DeployedANNIndex]
    ):
    
    deployment_client = IndexDeploymentClient(project_id, 
                                              project_number,
                                              region)
    
    index_name = ann_index.get_string_custom_property('index_name')
    index_display_name = ann_index.get_string_custom_property('index_display_name')
    endpoint_display_name = f'Endpoint for {index_display_name}'
    
    logging.info(f'Creating endpoint: {endpoint_display_name}')
    operation_id = deployment_client.create_endpoint(endpoint_display_name, vpc_name)
    response = deployment_client.wait_for_completion(operation_id, 'Waiting for endpoint', 30)
    endpoint_name = response['name']
    logging.info(f'Endpoint created: {endpoint_name}')
  
    endpoint_id = endpoint_name.split('/')[-1]
    index_id = index_name.split('/')[-1]
    deployed_index_display_name = f'Deployed {index_display_name}'
    deployed_index_id = deployed_index_id_prefix + str(uuid.uuid4())
    
    logging.info(f'Creating deployed index: {deployed_index_id}')
    logging.info(f'                  from: {index_name}')
    operation_id = deployment_client.create_deployment(
        deployed_index_display_name, 
        deployed_index_id,
        endpoint_id,
        index_id)
    response = deployment_client.wait_for_completion(operation_id, 'Waiting for deployment', 60)
    logging.info('Index deployed!')
  
    deployed_index_ip = deployment_client.get_deployment_grpc_ip(
        endpoint_id, deployed_index_id
    )
    # Write the deployed index properties to metadata.
    deployed_ann_index.set_string_custom_property('endpoint_name', 
                                                  endpoint_name)
    deployed_ann_index.set_string_custom_property('deployed_index_id', 
                                                  deployed_index_id)
    deployed_ann_index.set_string_custom_property('index_name', 
                                                  index_name)
    deployed_ann_index.set_string_custom_property('deployed_index_grpc_ip', 
                                                  deployed_index_ip)


## Creating a TFX pipeline

The pipeline automates the process of preparing item embeddings (in BigQuery), training a Matrix Factorization model (in BQML), and creating and deploying an ANN Service index.

The pipeline has a simple sequential flow. It accepts a set of runtime parameters that define the GCP environment settings, the embeddings parameters, and the index assembly parameters.
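
The sequential flow follows from how each component's output artifact feeds the next component's input. The sketch below derives that order from the artifact names used by the components in this notebook. The `execution_order` function is a hypothetical illustration of dependency resolution and is not how TFX itself schedules components.

```python
# Output artifact produced by each component (names as used in this notebook).
produces = {
    'compute_pmi': 'item_cooc',
    'train_item_matching_model': 'bq_model',
    'extract_embeddings': 'item_embeddings',
    'export_embeddings': 'item_embeddings_gcs',
    'create_index': 'ann_index',
    'deploy_index': 'deployed_ann_index',
}
# Input artifacts consumed by each component.
consumes = {
    'deploy_index': ['ann_index'],
    'create_index': ['item_embeddings_gcs'],
    'export_embeddings': ['item_embeddings'],
    'extract_embeddings': ['bq_model'],
    'train_item_matching_model': ['item_cooc'],
    'compute_pmi': [],
}

producer_of = {artifact: name for name, artifact in produces.items()}
dependencies = {name: {producer_of[a] for a in inputs}
                for name, inputs in consumes.items()}


def execution_order(dependencies):
    """Orders components so that each runs after everything it depends on."""
    order, done, pending = [], set(), dict(dependencies)
    while pending:
        ready = sorted(n for n, deps in pending.items() if deps <= done)
        if not ready:
            raise ValueError('Dependency cycle detected')
        for name in ready:
            order.append(name)
            done.add(name)
            del pending[name]
    return order


order = execution_order(dependencies)
print(' -> '.join(order))
```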

In [None]:
import os

# Only required for local run.
from tfx.orchestration.metadata import sqlite_metadata_connection_config

from tfx.orchestration.pipeline import Pipeline
from tfx.orchestration.kubeflow.v2 import kubeflow_v2_dag_runner

from compute_pmi import compute_pmi
from export_embeddings import export_embeddings
from extract_embeddings import extract_embeddings
from train_item_matching import train_item_matching_model
from create_index import create_index
from deploy_index import deploy_index

def ann_pipeline(
    pipeline_name,
    pipeline_root,
    metadata_connection_config,
    project_id,
    project_number,
    region,
    vpc_name,
    bq_dataset_name,
    min_item_frequency,
    max_group_size,
    dimensions,
    embeddings_gcs_location,
    index_display_name,
    deployed_index_id_prefix) -> Pipeline:
    """Implements the SCANN training pipeline."""
 
    pmi_computer = compute_pmi(
        project_id=project_id,
        bq_dataset=bq_dataset_name,
        min_item_frequency=min_item_frequency,
        max_group_size=max_group_size
    )
    
    bqml_trainer = train_item_matching_model(
        project_id=project_id,
        bq_dataset=bq_dataset_name,
        item_cooc=pmi_computer.outputs.item_cooc,
        dimensions=dimensions,
    )
    
    embeddings_extractor = extract_embeddings(
        project_id=project_id,
        bq_dataset=bq_dataset_name,
        bq_model=bqml_trainer.outputs.bq_model
    )
    
    embeddings_exporter = export_embeddings(
        project_id=project_id,
        gcs_location=embeddings_gcs_location,
        item_embeddings_bq=embeddings_extractor.outputs.item_embeddings
    )
    
    index_constructor = create_index(
        project_id=project_id,
        project_number=project_number,
        region=region,
        display_name=index_display_name,
        dimensions=dimensions,
        item_embeddings=embeddings_exporter.outputs.item_embeddings_gcs
    )
    
    index_deployer = deploy_index(
        project_id=project_id,
        project_number=project_number,
        region=region,
        vpc_name=vpc_name,
        deployed_index_id_prefix=deployed_index_id_prefix,
        ann_index=index_constructor.outputs.ann_index
    )

    components = [
        pmi_computer,
        bqml_trainer,
        embeddings_extractor,
        embeddings_exporter,
        index_constructor,
        index_deployer
    ]
    
    return Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        # Only needed for local runs.
        metadata_connection_config=metadata_connection_config,
        components=components)
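
The sequential flow comes from the data dependencies rather than from the order of the `components` list: each component consumes an artifact produced by the previous one. The sketch below is a toy illustration of that idea using only the standard library (Python 3.9+). It is not how TFX resolves dependencies internally, and the dictionary simply restates the wiring in `ann_pipeline`.

```python
from graphlib import TopologicalSorter

# Maps each component to the upstream components whose outputs it consumes.
# The comments name the output that creates each dependency in ann_pipeline.
consumes = {
    'compute_pmi': [],
    'train_item_matching_model': ['compute_pmi'],          # item_cooc
    'extract_embeddings': ['train_item_matching_model'],   # bq_model
    'export_embeddings': ['extract_embeddings'],           # item_embeddings
    'create_index': ['export_embeddings'],                 # item_embeddings_gcs
    'deploy_index': ['create_index'],                      # ann_index
}

# A topological sort yields an order in which every component runs
# only after the components it depends on.
execution_order = list(TopologicalSorter(consumes).static_order())
print(execution_order)
```

Because the dependencies form a single chain, only one valid execution order exists, which is why the pipeline runs strictly one step at a time.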

## Testing the pipeline locally

You will first run the pipeline locally using the Beam runner.

### Clean the metadata and artifacts from the previous runs

In [None]:
pipeline_root = f'/tmp/{PIPELINE_NAME}'
local_mlmd_folder = '/tmp/mlmd'

if tf.io.gfile.exists(pipeline_root):
    print("Removing previous artifacts...")
    tf.io.gfile.rmtree(pipeline_root)
if tf.io.gfile.exists(local_mlmd_folder):
    print("Removing local mlmd SQLite...")
    tf.io.gfile.rmtree(local_mlmd_folder)
print("Creating mlmd directory: ", local_mlmd_folder)
tf.io.gfile.mkdir(local_mlmd_folder)
print("Creating pipeline root folder: ", pipeline_root)
tf.io.gfile.mkdir(pipeline_root)
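
The cleanup cell repeats the same remove-then-create pattern for two folders, and a small helper can capture it. The sketch below is a local-filesystem-only variant built on the standard library. `reset_dir` is a hypothetical name, not part of the notebook, and the demo path is a throwaway temporary folder.

```python
import os
import shutil
import tempfile


def reset_dir(path):
    """Deletes `path` if it exists, then recreates it as an empty directory."""
    if os.path.isdir(path):
        shutil.rmtree(path)
    os.makedirs(path)
    return path


# Demo: create a folder, leave a stale file in it, then reset it again.
demo_root = os.path.join(tempfile.mkdtemp(), 'pipeline_root')
reset_dir(demo_root)
open(os.path.join(demo_root, 'stale.txt'), 'w').close()
reset_dir(demo_root)
print(os.listdir(demo_root))  # the stale file is gone
```

Replacing the `os` and `shutil` calls with the `tf.io.gfile` functions used in the cell above would keep support for `gs://` paths.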

### Set pipeline parameters and create the pipeline

In [None]:
bq_dataset_name = 'song_embeddings'
index_display_name = 'Song embeddings'
deployed_index_id_prefix = 'deployed_song_embeddings_'
min_item_frequency = 15
max_group_size = 100
dimensions = 50
embeddings_gcs_location = f'gs://{BUCKET_NAME}/embeddings'

metadata_connection_config = sqlite_metadata_connection_config(
    os.path.join(local_mlmd_folder, 'metadata.sqlite'))

pipeline = ann_pipeline(
    pipeline_name=PIPELINE_NAME,
    pipeline_root=pipeline_root,
    metadata_connection_config=metadata_connection_config,
    project_id=PROJECT_ID,
    project_number=PROJECT_NUMBER,
    region=REGION,
    vpc_name=VPC_NAME,
    bq_dataset_name=bq_dataset_name,
    index_display_name=index_display_name,
    deployed_index_id_prefix=deployed_index_id_prefix,
    min_item_frequency=min_item_frequency,
    max_group_size=max_group_size,
    dimensions=dimensions,
    embeddings_gcs_location=embeddings_gcs_location
)

### Start the run

In [None]:
logging.getLogger().setLevel(logging.INFO)

BeamDagRunner().run(pipeline)

### Inspect produced metadata

During the execution of the pipeline, the inputs and outputs of each component have been tracked in ML Metadata. 

In [None]:
from ml_metadata import metadata_store
from ml_metadata.proto import metadata_store_pb2

connection_config = metadata_store_pb2.ConnectionConfig()
connection_config.sqlite.filename_uri = os.path.join(local_mlmd_folder, 'metadata.sqlite')
connection_config.sqlite.connection_mode = 3 # READWRITE_OPENCREATE
store = metadata_store.MetadataStore(connection_config)
store.get_artifacts()
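
`store.get_artifacts()` returns every artifact as a raw proto, which quickly becomes hard to read. One way to get an overview is to count artifacts per type by joining `get_artifacts()` with `get_artifact_types()` on the type ID. The helper below is a sketch. `count_artifacts_by_type` is a hypothetical name, and the stand-in store with made-up type names exists only so that the example runs without a metadata database.

```python
from collections import Counter
from types import SimpleNamespace


def count_artifacts_by_type(store):
    """Returns {artifact type name: number of artifacts} for a metadata store."""
    type_names = {t.id: t.name for t in store.get_artifact_types()}
    counts = Counter(
        type_names.get(a.type_id, 'unknown') for a in store.get_artifacts())
    return dict(counts)


# Stand-in for a MetadataStore; the type names are invented for the demo.
fake_store = SimpleNamespace(
    get_artifact_types=lambda: [
        SimpleNamespace(id=1, name='TypeA'),
        SimpleNamespace(id=2, name='TypeB'),
    ],
    get_artifacts=lambda: [
        SimpleNamespace(type_id=1),
        SimpleNamespace(type_id=1),
        SimpleNamespace(type_id=2),
    ],
)

counts = count_artifacts_by_type(fake_store)
print(counts)
```

With the store opened above, `count_artifacts_by_type(store)` should give one entry per artifact type registered by the pipeline's components.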

# NOTICE. The following code does not work with ANN Service Experimental. It will be finalized when the service moves to the Preview stage.

## Running the pipeline on AI Platform Pipelines

You will now run the pipeline on AI Platform Pipelines (Unified).

### Package custom components into a container

The modules containing the custom components must first be packaged as a Docker container image that is derived from the standard TFX image.

#### Create a Dockerfile

In [None]:
%%writefile Dockerfile
FROM gcr.io/tfx-oss-public/tfx:0.25.0
WORKDIR /pipeline
COPY ./ ./
ENV PYTHONPATH="/pipeline:${PYTHONPATH}"

#### Build and push the Docker image to Container Registry

In [None]:
!gcloud builds submit --tag gcr.io/{PROJECT_ID}/caip-tfx-custom:{USER} .
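
The image name built here must match the `default_image` passed to the runner configuration later in the notebook. The sketch below derives the URI from a single helper so that the two places cannot drift apart. `custom_image_uri` and the sample project and tag values are illustrative assumptions, not part of the notebook.

```python
def custom_image_uri(project_id, tag, name='caip-tfx-custom'):
    """Builds the Container Registry URI of the custom TFX image."""
    return f'gcr.io/{project_id}/{name}:{tag}'


# Placeholder values; in the notebook these would be PROJECT_ID and USER.
image_uri = custom_image_uri('my-project', 'alice')
print(image_uri)
```

The same `image_uri` variable could then be used both in the build command, as `--tag {image_uri}`, and as the `default_image` argument of the runner configuration.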

#### Create AI Platform Pipelines client

In [None]:
from aiplatform.pipelines import client

aipp_client = client.Client(
    project_id=PROJECT_ID,
    region=REGION,
    api_key=API_KEY
)

#### Set the parameters for AIPP execution and create the pipeline

In [None]:
metadata_connection_config = None
pipeline_root = PIPELINE_ROOT

pipeline = ann_pipeline(
    pipeline_name=PIPELINE_NAME,
    pipeline_root=pipeline_root,
    metadata_connection_config=metadata_connection_config,
    project_id=PROJECT_ID,
    project_number=PROJECT_NUMBER,
    region=REGION,
    vpc_name=VPC_NAME,
    bq_dataset_name=bq_dataset_name,
    index_display_name=index_display_name,
    deployed_index_id_prefix=deployed_index_id_prefix,
    min_item_frequency=min_item_frequency,
    max_group_size=max_group_size,
    dimensions=dimensions,
    embeddings_gcs_location=embeddings_gcs_location
)
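
The local and AI Platform Pipelines runs call `ann_pipeline` with the same arguments except for `pipeline_root` and `metadata_connection_config`. One option is to keep the shared arguments in a dictionary and add the two run-specific values per environment. The sketch below uses placeholder values and a hypothetical helper, `pipeline_args`. It only builds the keyword arguments, so it runs without TFX.

```python
def pipeline_args(shared, pipeline_root, metadata_connection_config=None):
    """Returns a copy of `shared` extended with the run-specific arguments."""
    args = dict(shared)  # copy, so the shared settings stay untouched
    args['pipeline_root'] = pipeline_root
    args['metadata_connection_config'] = metadata_connection_config
    return args


# Placeholder settings; the notebook takes these from its own variables.
shared = dict(
    pipeline_name='ann-pipeline',
    project_id='my-project',
    project_number='123456789',
    region='us-central1',
    vpc_name='default',
    bq_dataset_name='song_embeddings',
    index_display_name='Song embeddings',
    deployed_index_id_prefix='deployed_song_embeddings_',
    min_item_frequency=15,
    max_group_size=100,
    dimensions=50,
    embeddings_gcs_location='gs://my-bucket/embeddings',
)

# Local run: a metadata connection config is supplied (placeholder here).
local_args = pipeline_args(shared, '/tmp/ann-pipeline', 'sqlite-config')
# Managed run: no metadata connection config is passed.
remote_args = pipeline_args(shared, 'gs://my-bucket/pipeline_root')
```

Each pipeline would then be created with `ann_pipeline(**local_args)` or `ann_pipeline(**remote_args)`.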

#### Compile the pipeline

In [None]:
config = kubeflow_v2_dag_runner.KubeflowV2DagRunnerConfig(
    project_id=PROJECT_ID,
    display_name=PIPELINE_NAME,
    default_image='gcr.io/{}/caip-tfx-custom:{}'.format(PROJECT_ID, USER))
runner = kubeflow_v2_dag_runner.KubeflowV2DagRunner(
    config=config,
    output_filename='pipeline.json')
runner.compile(
    pipeline,
    write_out=True)

#### Submit the pipeline run

In [None]:
aipp_client.create_run_from_job_spec('pipeline.json')

## License

Copyright 2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. You may obtain a copy of the License at: http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

See the License for the specific language governing permissions and limitations under the License.

**This is not an official Google product but sample code provided for educational purposes.**