retail/recommendation-system/bqml-scann/ann02_run_pipeline.ipynb (1,464 lines of code) (raw):
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Low-latency item-to-item recommendation system - Orchestrating with TFX\n",
"\n",
"## Overview\n",
"\n",
"This notebook is a part of the series that describes the process of implementing a [**Low-latency item-to-item recommendation system**](https://github.com/jarokaz/analytics-componentized-patterns/tree/master/retail/recommendation-system/bqml-scann).\n",
"\n",
"This notebook demonstrates how to use TFX and AI Platform Pipelines (Unified) to operationalize the workflow that creates embeddings and builds and deploys an ANN Service index. \n",
"\n",
"In the notebook you go through the following steps.\n",
"\n",
"1. Creating TFX custom components that encapsulate operations on BQ, BQML and ANN Service.\n",
"2. Creating a TFX pipeline that automates the processes of creating embeddings and deploying an ANN Index \n",
"3. Testing the pipeline locally using Beam runner.\n",
"4. Compiling the pipeline to the TFX IR format for execution on AI Platform Pipelines (Unified).\n",
"5. Submitting pipeline runs.\n",
"\n",
"This notebook was designed to run on [AI Platform Notebooks](https://cloud.google.com/ai-platform-notebooks). Before running the notebook make sure that you have completed the setup steps as described in the [README file](README.md).\n",
"\n",
"### TFX Pipeline Design\n",
"\n",
"The below diagram depicts the TFX pipeline that you will implement in this notebook. Each step of the pipeline is implemented as a [TFX Custom Python function component](https://www.tensorflow.org/tfx/guide/custom_function_component). The components track the relevant metadata in AI Platform (Unfied) ML Metadata using both standard and custom metadata types. \n",
"\n",
"\n",
"\n",
"1. The first step of the pipeline is to compute item co-occurence. This is done by calling the `sp_ComputePMI` stored procedure created in the preceeding notebooks. \n",
"2. Next, the BQML Matrix Factorization model is created. The model training code is encapsulated in the `sp_TrainItemMatchingModel` stored procedure.\n",
"3. Item embeddings are extracted from the trained model weights and stored in a BQ table. The component calls the `sp_ExtractEmbeddings` stored procedure that implements the extraction logic.\n",
"4. The embeddings are exported in the JSONL format to the GCS location using the BigQuery extract job.\n",
"5. The embeddings in the JSONL format are used to create an ANN index by calling the ANN Service Control Plane REST API.\n",
"6. Finally, the ANN index is deployed to an ANN endpoint.\n",
"\n",
"All steps and their inputs and outputs are tracked in the AI Platform (Unified) ML Metadata service.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setting up the notebook's environment"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Install AI Platform Pipelines client library\n",
"\n",
"For AI Platform Pipelines (Unified), which is in the Experimental stage, you need to download and install the AI Platform client library on top of the KFP and TFX SDKs that were installed as part of the initial environment setup."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"AIP_CLIENT_WHEEL = 'aiplatform_pipelines_client-0.1.0.caip20201123-py3-none-any.whl'\n",
"AIP_CLIENT_WHEEL_GCS_LOCATION = f'gs://cloud-aiplatform-pipelines/releases/20201123/{AIP_CLIENT_WHEEL}'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!gsutil cp {AIP_CLIENT_WHEEL_GCS_LOCATION} {AIP_CLIENT_WHEEL}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install {AIP_CLIENT_WHEEL}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Restart the kernel."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import IPython\n",
"app = IPython.Application.instance()\n",
"app.kernel.do_shutdown(True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Import notebook dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import logging\n",
"import tfx\n",
"import tensorflow as tf\n",
"\n",
"from aiplatform.pipelines import client\n",
"from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner\n",
"\n",
"print('TFX Version: ', tfx.__version__)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Configure GCP environment"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"-----------------\n",
"\n",
"**If you're on AI Platform Notebooks**, authenticate with Google Cloud before running the next section, by running\n",
"```sh\n",
"gcloud auth login\n",
"```\n",
"**in the Terminal window** (which you can open via **File** > **New** in the menu). You only need to do this once per notebook instance."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Set the following constants to the values reflecting your environment:\n",
"\n",
"* `PROJECT_ID` - your GCP project ID\n",
"* `PROJECT_NUMBER` - your GCP project number\n",
"* `BUCKET_NAME` - a name of the GCS bucket that will be used to host artifacts created by the pipeline\n",
"* `PIPELINE_NAME_SUFFIX` - a suffix appended to the standard pipeline name. You can change to differentiate between pipelines from different users in a classroom environment\n",
"* `API_KEY` - a GCP API key\n",
"* `VPC_NAME` - a name of the GCP VPC to use for the index deployments. \n",
"* `REGION` - a compute region. Don't change the default - `us-central` - while the ANN Service is in the experimental stage\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"PROJECT_ID = '' # <---CHANGE THIS\n",
"PROJECT_NUMBER = '' # <---CHANGE THIS\n",
"API_KEY = '' # <---CHANGE THIS\n",
"USER = 'user' # <---CHANGE THIS\n",
"BUCKET_NAME = 'jk-ann-staging' # <---CHANGE THIS\n",
"VPC_NAME = 'default' # <---CHANGE THIS IF USING A DIFFERENT VPC\n",
"\n",
"REGION = 'us-central1'\n",
"PIPELINE_NAME = \"ann-pipeline-{}\".format(USER)\n",
"PIPELINE_ROOT = 'gs://{}/pipeline_root/{}'.format(BUCKET_NAME, PIPELINE_NAME)\n",
"PATH=%env PATH\n",
"%env PATH={PATH}:/home/jupyter/.local/bin\n",
" \n",
"print('PIPELINE_ROOT: {}'.format(PIPELINE_ROOT))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Defining custom components\n",
"\n",
"In this section of the notebook you define a set of custom TFX components that encapsulate BQ, BQML and ANN Service calls. The components are [TFX Custom Python function components](https://www.tensorflow.org/tfx/guide/custom_function_component). \n",
"\n",
"Each component is created as a separate Python module. You also create a couple of helper modules that encapsulate Python functions and classess used across the custom components. \n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Remove files created in the previous executions of the notebook"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"component_folder = 'bq_components'\n",
"\n",
"if tf.io.gfile.exists(component_folder):\n",
" print('Removing older file')\n",
" tf.io.gfile.rmtree(component_folder)\n",
"print('Creating component folder')\n",
"tf.io.gfile.mkdir(component_folder)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%cd {component_folder}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define custom types for ANN service artifacts\n",
"\n",
"This module defines a couple of custom TFX artifacts to track ANN Service indexes and index deployments."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%writefile ann_types.py\n",
"# Copyright 2020 Google LLC\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# http://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License.\n",
"\"\"\"Custom types for managing ANN artifacts.\"\"\"\n",
"\n",
"from tfx.types import artifact\n",
"\n",
"class ANNIndex(artifact.Artifact):\n",
" TYPE_NAME = 'ANNIndex'\n",
" \n",
"class DeployedANNIndex(artifact.Artifact):\n",
" TYPE_NAME = 'DeployedANNIndex'\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create a wrapper around ANN Service REST API\n",
"\n",
"This module provides a convenience wrapper around ANN Service REST API. In the experimental stage, the ANN Service does not have an \"official\" Python client SDK nor it is supported by the Google Discovery API."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%writefile ann_service.py\n",
"# Copyright 2020 Google LLC\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# http://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License.\n",
"\"\"\"Helper classes encapsulating ANN Service REST API.\"\"\"\n",
"\n",
"import datetime\n",
"import logging\n",
"import json\n",
"import time\n",
"\n",
"import google.auth\n",
"\n",
"class ANNClient(object):\n",
" \"\"\"Base ANN Service client.\"\"\"\n",
" \n",
" def __init__(self, project_id, project_number, region):\n",
" credentials, _ = google.auth.default()\n",
" self.authed_session = google.auth.transport.requests.AuthorizedSession(credentials)\n",
" self.ann_endpoint = f'{region}-aiplatform.googleapis.com'\n",
" self.ann_parent = f'https://{self.ann_endpoint}/v1alpha1/projects/{project_id}/locations/{region}'\n",
" self.project_id = project_id\n",
" self.project_number = project_number\n",
" self.region = region\n",
" \n",
" def wait_for_completion(self, operation_id, message, sleep_time):\n",
" \"\"\"Waits for a completion of a long running operation.\"\"\"\n",
" \n",
" api_url = f'{self.ann_parent}/operations/{operation_id}'\n",
"\n",
" start_time = datetime.datetime.utcnow()\n",
" while True:\n",
" response = self.authed_session.get(api_url)\n",
" if response.status_code != 200:\n",
" raise RuntimeError(response.json())\n",
" if 'done' in response.json().keys():\n",
" logging.info('Operation completed!')\n",
" break\n",
" elapsed_time = datetime.datetime.utcnow() - start_time\n",
" logging.info('{}. Elapsed time since start: {}.'.format(\n",
" message, str(elapsed_time)))\n",
" time.sleep(sleep_time)\n",
" \n",
" return response.json()['response']\n",
"\n",
"\n",
"class IndexClient(ANNClient):\n",
" \"\"\"Encapsulates a subset of control plane APIs \n",
" that manage ANN indexes.\"\"\"\n",
"\n",
" def __init__(self, project_id, project_number, region):\n",
" super().__init__(project_id, project_number, region)\n",
"\n",
" def create_index(self, display_name, description, metadata):\n",
" \"\"\"Creates an ANN Index.\"\"\"\n",
" \n",
" api_url = f'{self.ann_parent}/indexes'\n",
" \n",
" request_body = {\n",
" 'display_name': display_name,\n",
" 'description': description,\n",
" 'metadata': metadata\n",
" }\n",
" \n",
" response = self.authed_session.post(api_url, data=json.dumps(request_body))\n",
" if response.status_code != 200:\n",
" raise RuntimeError(response.text)\n",
" operation_id = response.json()['name'].split('/')[-1]\n",
" \n",
" return operation_id\n",
"\n",
" def list_indexes(self, display_name=None):\n",
" \"\"\"Lists all indexes with a given display name or\n",
" all indexes if the display_name is not provided.\"\"\"\n",
" \n",
" if display_name:\n",
" api_url = f'{self.ann_parent}/indexes?filter=display_name=\"{display_name}\"'\n",
" else:\n",
" api_url = f'{self.ann_parent}/indexes'\n",
"\n",
" response = self.authed_session.get(api_url).json()\n",
"\n",
" return response['indexes'] if response else []\n",
" \n",
" def delete_index(self, index_id):\n",
" \"\"\"Deletes an ANN index.\"\"\"\n",
" \n",
" api_url = f'{self.ann_parent}/indexes/{index_id}'\n",
" response = self.authed_session.delete(api_url)\n",
" if response.status_code != 200:\n",
" raise RuntimeError(response.text)\n",
"\n",
"\n",
"class IndexDeploymentClient(ANNClient):\n",
" \"\"\"Encapsulates a subset of control plane APIs \n",
" that manage ANN endpoints and deployments.\"\"\"\n",
" \n",
" def __init__(self, project_id, project_number, region):\n",
" super().__init__(project_id, project_number, region)\n",
"\n",
" def create_endpoint(self, display_name, vpc_name):\n",
" \"\"\"Creates an ANN endpoint.\"\"\"\n",
" \n",
" api_url = f'{self.ann_parent}/indexEndpoints'\n",
" network_name = f'projects/{self.project_number}/global/networks/{vpc_name}'\n",
"\n",
" request_body = {\n",
" 'display_name': display_name,\n",
" 'network': network_name\n",
" }\n",
"\n",
" response = self.authed_session.post(api_url, data=json.dumps(request_body))\n",
" if response.status_code != 200:\n",
" raise RuntimeError(response.text)\n",
" operation_id = response.json()['name'].split('/')[-1]\n",
" \n",
" return operation_id\n",
" \n",
" def list_endpoints(self, display_name=None):\n",
" \"\"\"Lists all ANN endpoints with a given display name or\n",
" all endpoints in the project if the display_name is not provided.\"\"\"\n",
" \n",
" if display_name:\n",
" api_url = f'{self.ann_parent}/indexEndpoints?filter=display_name=\"{display_name}\"'\n",
" else:\n",
" api_url = f'{self.ann_parent}/indexEndpoints'\n",
"\n",
" response = self.authed_session.get(api_url).json()\n",
" \n",
" return response['indexEndpoints'] if response else []\n",
" \n",
" def delete_endpoint(self, endpoint_id):\n",
" \"\"\"Deletes an ANN endpoint.\"\"\"\n",
" \n",
" api_url = f'{self.ann_parent}/indexEndpoints/{endpoint_id}'\n",
" \n",
" response = self.authed_session.delete(api_url)\n",
" if response.status_code != 200:\n",
" raise RuntimeError(response.text)\n",
" \n",
" return response.json()\n",
" \n",
" def create_deployment(self, display_name, deployment_id, endpoint_id, index_id):\n",
" \"\"\"Deploys an ANN index to an endpoint.\"\"\"\n",
" \n",
" api_url = f'{self.ann_parent}/indexEndpoints/{endpoint_id}:deployIndex'\n",
" index_name = f'projects/{self.project_number}/locations/{self.region}/indexes/{index_id}'\n",
"\n",
" request_body = {\n",
" 'deployed_index': {\n",
" 'id': deployment_id,\n",
" 'index': index_name,\n",
" 'display_name': display_name\n",
" }\n",
" }\n",
"\n",
" response = self.authed_session.post(api_url, data=json.dumps(request_body))\n",
" if response.status_code != 200:\n",
" raise RuntimeError(response.text)\n",
" operation_id = response.json()['name'].split('/')[-1]\n",
" \n",
" return operation_id\n",
" \n",
" def get_deployment_grpc_ip(self, endpoint_id, deployment_id):\n",
" \"\"\"Returns a private IP address for a gRPC interface to \n",
" an Index deployment.\"\"\"\n",
" \n",
" api_url = f'{self.ann_parent}/indexEndpoints/{endpoint_id}'\n",
"\n",
" response = self.authed_session.get(api_url)\n",
" if response.status_code != 200:\n",
" raise RuntimeError(response.text)\n",
" \n",
" endpoint_ip = None\n",
" if 'deployedIndexes' in response.json().keys():\n",
" for deployment in response.json()['deployedIndexes']:\n",
" if deployment['id'] == deployment_id:\n",
" endpoint_ip = deployment['privateEndpoints']['matchGrpcAddress']\n",
" \n",
" return endpoint_ip\n",
"\n",
" \n",
" def delete_deployment(self, endpoint_id, deployment_id):\n",
" \"\"\"Undeployes an index from an endpoint.\"\"\"\n",
" \n",
" api_url = f'{self.ann_parent}/indexEndpoints/{endpoint_id}:undeployIndex'\n",
" \n",
" request_body = {\n",
" 'deployed_index_id': deployment_id\n",
" }\n",
" \n",
" response = self.authed_session.post(api_url, data=json.dumps(request_body))\n",
" if response.status_code != 200:\n",
" raise RuntimeError(response.text)\n",
" \n",
" return response\n",
" "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create Compute PMI component\n",
"\n",
"This component encapsulates a call to the BigQuery stored procedure that calculates item cooccurence. Refer to the preceeding notebooks for more details about item coocurrent calculations.\n",
"\n",
"The component tracks the output `item_cooc` table created by the stored procedure using the TFX (simple) Dataset artifact."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%writefile compute_pmi.py\n",
"# Copyright 2020 Google LLC\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# http://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License.\n",
"\"\"\"BigQuery compute PMI component.\"\"\"\n",
"\n",
"import logging\n",
"\n",
"from google.cloud import bigquery\n",
"\n",
"import tfx\n",
"import tensorflow as tf\n",
"\n",
"from tfx.dsl.component.experimental.decorators import component\n",
"from tfx.dsl.component.experimental.annotations import InputArtifact, OutputArtifact, Parameter\n",
"\n",
"from tfx.types.experimental.simple_artifacts import Dataset as BQDataset\n",
"\n",
"\n",
"@component\n",
"def compute_pmi(\n",
" project_id: Parameter[str],\n",
" bq_dataset: Parameter[str],\n",
" min_item_frequency: Parameter[int],\n",
" max_group_size: Parameter[int],\n",
" item_cooc: OutputArtifact[BQDataset]):\n",
" \n",
" stored_proc = f'{bq_dataset}.sp_ComputePMI'\n",
" query = f'''\n",
" DECLARE min_item_frequency INT64;\n",
" DECLARE max_group_size INT64;\n",
"\n",
" SET min_item_frequency = {min_item_frequency};\n",
" SET max_group_size = {max_group_size};\n",
"\n",
" CALL {stored_proc}(min_item_frequency, max_group_size);\n",
" '''\n",
" result_table = 'item_cooc'\n",
"\n",
" logging.info(f'Starting computing PMI...')\n",
" \n",
" client = bigquery.Client(project=project_id)\n",
" query_job = client.query(query)\n",
" query_job.result() # Wait for the job to complete\n",
" \n",
" logging.info(f'Items PMI computation completed. Output in {bq_dataset}.{result_table}.')\n",
" \n",
" # Write the location of the output table to metadata. \n",
" item_cooc.set_string_custom_property('table_name',\n",
" f'{project_id}:{bq_dataset}.{result_table}')\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create Train Item Matching Model component\n",
"\n",
"This component encapsulates a call to the BigQuery stored procedure that trains the BQML Matrix Factorization model. Refer to the preceeding notebooks for more details about model training.\n",
"\n",
"The component tracks the output `item_matching_model` BQML model created by the stored procedure using the TFX (simple) Model artifact."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%writefile train_item_matching.py\n",
"# Copyright 2020 Google LLC\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# http://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License.\n",
"\"\"\"BigQuery compute PMI component.\"\"\"\n",
"\n",
"import logging\n",
"\n",
"from google.cloud import bigquery\n",
"\n",
"import tfx\n",
"import tensorflow as tf\n",
"\n",
"from tfx.dsl.component.experimental.decorators import component\n",
"from tfx.dsl.component.experimental.annotations import InputArtifact, OutputArtifact, Parameter\n",
"\n",
"from tfx.types.experimental.simple_artifacts import Dataset as BQDataset\n",
"from tfx.types.standard_artifacts import Model as BQModel\n",
"\n",
"\n",
"@component\n",
"def train_item_matching_model(\n",
" project_id: Parameter[str],\n",
" bq_dataset: Parameter[str],\n",
" dimensions: Parameter[int],\n",
" item_cooc: InputArtifact[BQDataset],\n",
" bq_model: OutputArtifact[BQModel]):\n",
" \n",
" item_cooc_table = item_cooc.get_string_custom_property('table_name')\n",
" stored_proc = f'{bq_dataset}.sp_TrainItemMatchingModel'\n",
" query = f'''\n",
" DECLARE dimensions INT64 DEFAULT {dimensions};\n",
" CALL {stored_proc}(dimensions);\n",
" '''\n",
" model_name = 'item_matching_model'\n",
" \n",
" logging.info(f'Using item co-occurrence table: item_cooc_table')\n",
" logging.info(f'Starting training of the model...')\n",
" \n",
" client = bigquery.Client(project=project_id)\n",
" query_job = client.query(query)\n",
" query_job.result()\n",
" \n",
" logging.info(f'Model training completed. Output in {bq_dataset}.{model_name}.')\n",
" \n",
" # Write the location of the model to metadata. \n",
" bq_model.set_string_custom_property('model_name',\n",
" f'{project_id}:{bq_dataset}.{model_name}')\n",
" \n",
" "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create Extract Embeddings component\n",
"\n",
"This component encapsulates a call to the BigQuery stored procedure that extracts embdeddings from the model to the staging table. Refer to the preceeding notebooks for more details about embeddings extraction.\n",
"\n",
"The component tracks the output `item_embeddings` table created by the stored procedure using the TFX (simple) Dataset artifact."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%writefile extract_embeddings.py\n",
"# Copyright 2020 Google LLC\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# http://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License.\n",
"\"\"\"Extracts embeddings to a BQ table.\"\"\"\n",
"\n",
"import logging\n",
"\n",
"from google.cloud import bigquery\n",
"\n",
"import tfx\n",
"import tensorflow as tf\n",
"\n",
"from tfx.dsl.component.experimental.decorators import component\n",
"from tfx.dsl.component.experimental.annotations import InputArtifact, OutputArtifact, Parameter\n",
"\n",
"from tfx.types.experimental.simple_artifacts import Dataset as BQDataset \n",
"from tfx.types.standard_artifacts import Model as BQModel\n",
"\n",
"\n",
"@component\n",
"def extract_embeddings(\n",
" project_id: Parameter[str],\n",
" bq_dataset: Parameter[str],\n",
" bq_model: InputArtifact[BQModel],\n",
" item_embeddings: OutputArtifact[BQDataset]):\n",
" \n",
" embedding_model_name = bq_model.get_string_custom_property('model_name')\n",
" stored_proc = f'{bq_dataset}.sp_ExractEmbeddings'\n",
" query = f'''\n",
" CALL {stored_proc}();\n",
" '''\n",
" embeddings_table = 'item_embeddings'\n",
"\n",
" logging.info(f'Extracting item embeddings from: {embedding_model_name}')\n",
" \n",
" client = bigquery.Client(project=project_id)\n",
" query_job = client.query(query)\n",
" query_job.result() # Wait for the job to complete\n",
" \n",
" logging.info(f'Embeddings extraction completed. Output in {bq_dataset}.{embeddings_table}')\n",
" \n",
" # Write the location of the output table to metadata.\n",
" item_embeddings.set_string_custom_property('table_name', \n",
" f'{project_id}:{bq_dataset}.{embeddings_table}')\n",
" \n",
"\n",
" "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create Export Embeddings component\n",
"\n",
"This component encapsulates a BigQuery table extraction job that extracts the `item_embeddings` table to a GCS location as files in the JSONL format. The format of the extracted files is compatible with the ingestion schema for the ANN Service.\n",
"\n",
"The component tracks the output files location in the TFX (simple) Dataset artifact."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%writefile export_embeddings.py\n",
"# Copyright 2020 Google LLC\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# http://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License.\n",
"\"\"\"Exports embeddings from a BQ table to a GCS location.\"\"\"\n",
"\n",
"import logging\n",
"\n",
"from google.cloud import bigquery\n",
"\n",
"import tfx\n",
"import tensorflow as tf\n",
"\n",
"from tfx.dsl.component.experimental.decorators import component\n",
"from tfx.dsl.component.experimental.annotations import InputArtifact, OutputArtifact, Parameter\n",
"\n",
"from tfx.types.experimental.simple_artifacts import Dataset \n",
"\n",
"BQDataset = Dataset\n",
"\n",
"@component\n",
"def export_embeddings(\n",
" project_id: Parameter[str],\n",
" gcs_location: Parameter[str],\n",
" item_embeddings_bq: InputArtifact[BQDataset],\n",
" item_embeddings_gcs: OutputArtifact[Dataset]):\n",
" \n",
" filename_pattern = 'embedding-*.json'\n",
" gcs_location = gcs_location.rstrip('/')\n",
" destination_uri = f'{gcs_location}/{filename_pattern}'\n",
" \n",
" _, table_name = item_embeddings_bq.get_string_custom_property('table_name').split(':')\n",
" \n",
" logging.info(f'Exporting item embeddings from: {table_name}')\n",
" \n",
" bq_dataset, table_id = table_name.split('.')\n",
" client = bigquery.Client(project=project_id)\n",
" dataset_ref = bigquery.DatasetReference(project_id, bq_dataset)\n",
" table_ref = dataset_ref.table(table_id)\n",
" job_config = bigquery.job.ExtractJobConfig()\n",
" job_config.destination_format = bigquery.DestinationFormat.NEWLINE_DELIMITED_JSON\n",
"\n",
" extract_job = client.extract_table(\n",
" table_ref,\n",
" destination_uris=destination_uri,\n",
" job_config=job_config\n",
" ) \n",
" extract_job.result() # Wait for resuls\n",
" \n",
" logging.info(f'Embeddings export completed. Output in {gcs_location}')\n",
" \n",
" # Write the location of the embeddings to metadata.\n",
" item_embeddings_gcs.uri = gcs_location\n",
"\n",
" "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create ANN index component\n",
"\n",
"This component encapsulats the calls to the ANN Service to create an ANN Index. \n",
"\n",
"The component tracks the created index int the TFX custom `ANNIndex` artifact."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%writefile create_index.py\n",
"# Copyright 2020 Google LLC\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# http://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License.\n",
"\"\"\"Creates an ANN index.\"\"\"\n",
"\n",
"import logging\n",
"\n",
"import google.auth\n",
"import numpy as np\n",
"import tfx\n",
"import tensorflow as tf\n",
"\n",
"from google.cloud import bigquery\n",
"from tfx.dsl.component.experimental.decorators import component\n",
"from tfx.dsl.component.experimental.annotations import InputArtifact, OutputArtifact, Parameter\n",
"from tfx.types.experimental.simple_artifacts import Dataset \n",
"\n",
"from ann_service import IndexClient\n",
"from ann_types import ANNIndex\n",
"\n",
"NUM_NEIGHBOURS = 10\n",
"MAX_LEAVES_TO_SEARCH = 200\n",
"METRIC = 'DOT_PRODUCT_DISTANCE'\n",
"FEATURE_NORM_TYPE = 'UNIT_L2_NORM'\n",
"CHILD_NODE_COUNT = 1000\n",
"APPROXIMATE_NEIGHBORS_COUNT = 50\n",
"\n",
"@component\n",
"def create_index(\n",
" project_id: Parameter[str],\n",
" project_number: Parameter[str],\n",
" region: Parameter[str],\n",
" display_name: Parameter[str],\n",
" dimensions: Parameter[int],\n",
" item_embeddings: InputArtifact[Dataset],\n",
" ann_index: OutputArtifact[ANNIndex]):\n",
" \n",
" index_client = IndexClient(project_id, project_number, region)\n",
" \n",
" logging.info('Creating index:')\n",
" logging.info(f' Index display name: {display_name}')\n",
" logging.info(f' Embeddings location: {item_embeddings.uri}')\n",
" \n",
" index_description = display_name\n",
" index_metadata = {\n",
" 'contents_delta_uri': item_embeddings.uri,\n",
" 'config': {\n",
" 'dimensions': dimensions,\n",
" 'approximate_neighbors_count': APPROXIMATE_NEIGHBORS_COUNT,\n",
" 'distance_measure_type': METRIC,\n",
" 'feature_norm_type': FEATURE_NORM_TYPE,\n",
" 'tree_ah_config': {\n",
" 'child_node_count': CHILD_NODE_COUNT,\n",
" 'max_leaves_to_search': MAX_LEAVES_TO_SEARCH\n",
" }\n",
" }\n",
" }\n",
" \n",
" operation_id = index_client.create_index(display_name, \n",
" index_description,\n",
" index_metadata)\n",
" response = index_client.wait_for_completion(operation_id, 'Waiting for ANN index', 45)\n",
" index_name = response['name']\n",
" \n",
" logging.info('Index {} created.'.format(index_name))\n",
" \n",
" # Write the index name to metadata.\n",
" ann_index.set_string_custom_property('index_name', \n",
" index_name)\n",
" ann_index.set_string_custom_property('index_display_name', \n",
" display_name)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Deploy ANN index component\n",
"\n",
"This component deploys an ANN index to an ANN Endpoint. \n",
"The componet tracks the deployed index in the TFX custom `DeployedANNIndex` artifact."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%writefile deploy_index.py\n",
"# Copyright 2020 Google LLC\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# http://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License.\n",
"\"\"\"Deploys an ANN index.\"\"\"\n",
"\n",
"import logging\n",
"\n",
"import numpy as np\n",
"import uuid\n",
"import tfx\n",
"import tensorflow as tf\n",
"\n",
"from google.cloud import bigquery\n",
"from tfx.dsl.component.experimental.decorators import component\n",
"from tfx.dsl.component.experimental.annotations import InputArtifact, OutputArtifact, Parameter\n",
"from tfx.types.experimental.simple_artifacts import Dataset \n",
"\n",
"from ann_service import IndexDeploymentClient\n",
"from ann_types import ANNIndex\n",
"from ann_types import DeployedANNIndex\n",
"\n",
"\n",
"@component\n",
"def deploy_index(\n",
" project_id: Parameter[str],\n",
" project_number: Parameter[str],\n",
" region: Parameter[str],\n",
" vpc_name: Parameter[str],\n",
" deployed_index_id_prefix: Parameter[str],\n",
" ann_index: InputArtifact[ANNIndex],\n",
" deployed_ann_index: OutputArtifact[DeployedANNIndex]\n",
" ):\n",
" \n",
" deployment_client = IndexDeploymentClient(project_id, \n",
" project_number,\n",
" region)\n",
" \n",
" index_name = ann_index.get_string_custom_property('index_name')\n",
" index_display_name = ann_index.get_string_custom_property('index_display_name')\n",
" endpoint_display_name = f'Endpoint for {index_display_name}'\n",
" \n",
" logging.info(f'Creating endpoint: {endpoint_display_name}')\n",
" operation_id = deployment_client.create_endpoint(endpoint_display_name, vpc_name)\n",
" response = deployment_client.wait_for_completion(operation_id, 'Waiting for endpoint', 30)\n",
" endpoint_name = response['name']\n",
" logging.info(f'Endpoint created: {endpoint_name}')\n",
" \n",
" endpoint_id = endpoint_name.split('/')[-1]\n",
" index_id = index_name.split('/')[-1]\n",
" deployed_index_display_name = f'Deployed {index_display_name}'\n",
" deployed_index_id = deployed_index_id_prefix + str(uuid.uuid4())\n",
" \n",
" logging.info(f'Creating deployed index: {deployed_index_id}')\n",
" logging.info(f' from: {index_name}')\n",
" operation_id = deployment_client.create_deployment(\n",
" deployed_index_display_name, \n",
" deployed_index_id,\n",
" endpoint_id,\n",
" index_id)\n",
" response = deployment_client.wait_for_completion(operation_id, 'Waiting for deployment', 60)\n",
" logging.info('Index deployed!')\n",
" \n",
" deployed_index_ip = deployment_client.get_deployment_grpc_ip(\n",
" endpoint_id, deployed_index_id\n",
" )\n",
" # Write the deployed index properties to metadata.\n",
" deployed_ann_index.set_string_custom_property('endpoint_name', \n",
" endpoint_name)\n",
" deployed_ann_index.set_string_custom_property('deployed_index_id', \n",
" deployed_index_id)\n",
" deployed_ann_index.set_string_custom_property('index_name', \n",
" index_name)\n",
" deployed_ann_index.set_string_custom_property('deployed_index_grpc_ip', \n",
" deployed_index_ip)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Creating a TFX pipeline\n",
"\n",
"The pipeline automates the process of preparing item embeddings (in BigQuery), training a Matrix Factorization model (in BQML), and creating and deploying an ANN Service index.\n",
"\n",
"The pipeline has a simple sequential flow: each component consumes the output artifact of the previous one. The pipeline function accepts a set of parameters that define the GCP environment settings and control how the embeddings and the index are built."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"# Only required for local run.\n",
"from tfx.orchestration.metadata import sqlite_metadata_connection_config\n",
"\n",
"from tfx.orchestration.pipeline import Pipeline\n",
"from tfx.orchestration.kubeflow.v2 import kubeflow_v2_dag_runner\n",
"\n",
"from compute_pmi import compute_pmi\n",
"from export_embeddings import export_embeddings\n",
"from extract_embeddings import extract_embeddings\n",
"from train_item_matching import train_item_matching_model\n",
"from create_index import create_index\n",
"from deploy_index import deploy_index\n",
"\n",
"def ann_pipeline(\n",
" pipeline_name,\n",
" pipeline_root,\n",
" metadata_connection_config,\n",
" project_id,\n",
" project_number,\n",
" region,\n",
" vpc_name,\n",
" bq_dataset_name,\n",
" min_item_frequency,\n",
" max_group_size,\n",
" dimensions,\n",
" embeddings_gcs_location,\n",
" index_display_name,\n",
" deployed_index_id_prefix) -> Pipeline:\n",
" \"\"\"Implements the SCANN training pipeline.\"\"\"\n",
" \n",
" pmi_computer = compute_pmi(\n",
" project_id=project_id,\n",
" bq_dataset=bq_dataset_name,\n",
" min_item_frequency=min_item_frequency,\n",
" max_group_size=max_group_size\n",
" )\n",
" \n",
" bqml_trainer = train_item_matching_model(\n",
" project_id=project_id,\n",
" bq_dataset=bq_dataset_name,\n",
" item_cooc=pmi_computer.outputs.item_cooc,\n",
" dimensions=dimensions,\n",
" )\n",
" \n",
" embeddings_extractor = extract_embeddings(\n",
" project_id=project_id,\n",
" bq_dataset=bq_dataset_name,\n",
" bq_model=bqml_trainer.outputs.bq_model\n",
" )\n",
" \n",
" embeddings_exporter = export_embeddings(\n",
" project_id=project_id,\n",
" gcs_location=embeddings_gcs_location,\n",
" item_embeddings_bq=embeddings_extractor.outputs.item_embeddings\n",
" )\n",
" \n",
" index_constructor = create_index(\n",
" project_id=project_id,\n",
" project_number=project_number,\n",
" region=region,\n",
" display_name=index_display_name,\n",
" dimensions=dimensions,\n",
" item_embeddings=embeddings_exporter.outputs.item_embeddings_gcs\n",
" )\n",
" \n",
" index_deployer = deploy_index(\n",
" project_id=project_id,\n",
" project_number=project_number,\n",
" region=region,\n",
" vpc_name=vpc_name,\n",
" deployed_index_id_prefix=deployed_index_id_prefix,\n",
" ann_index=index_constructor.outputs.ann_index\n",
" )\n",
"\n",
" components = [\n",
" pmi_computer,\n",
" bqml_trainer,\n",
" embeddings_extractor,\n",
" embeddings_exporter,\n",
" index_constructor,\n",
" index_deployer\n",
" ]\n",
" \n",
" return Pipeline(\n",
" pipeline_name=pipeline_name,\n",
" pipeline_root=pipeline_root,\n",
" # Only needed for local runs.\n",
" metadata_connection_config=metadata_connection_config,\n",
" components=components)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Testing the pipeline locally\n",
"\n",
"You will first run the pipeline locally using the Beam runner."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Clean the metadata and artifacts from the previous runs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pipeline_root = f'/tmp/{PIPELINE_NAME}'\n",
"local_mlmd_folder = '/tmp/mlmd'\n",
"\n",
"if tf.io.gfile.exists(pipeline_root):\n",
" print(\"Removing previous artifacts...\")\n",
" tf.io.gfile.rmtree(pipeline_root)\n",
"if tf.io.gfile.exists(local_mlmd_folder):\n",
" print(\"Removing local mlmd SQLite...\")\n",
" tf.io.gfile.rmtree(local_mlmd_folder)\n",
"print(\"Creating mlmd directory: \", local_mlmd_folder)\n",
"tf.io.gfile.mkdir(local_mlmd_folder)\n",
"print(\"Creating pipeline root folder: \", pipeline_root)\n",
"tf.io.gfile.mkdir(pipeline_root)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Set pipeline parameters and create the pipeline"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"bq_dataset_name = 'song_embeddings'\n",
"index_display_name = 'Song embeddings'\n",
"deployed_index_id_prefix = 'deployed_song_embeddings_'\n",
"min_item_frequency = 15\n",
"max_group_size = 100\n",
"dimensions = 50\n",
"embeddings_gcs_location = f'gs://{BUCKET_NAME}/embeddings'\n",
"\n",
"metadata_connection_config = sqlite_metadata_connection_config(\n",
" os.path.join(local_mlmd_folder, 'metadata.sqlite'))\n",
"\n",
"pipeline = ann_pipeline(\n",
" pipeline_name=PIPELINE_NAME,\n",
" pipeline_root=pipeline_root,\n",
" metadata_connection_config=metadata_connection_config,\n",
" project_id=PROJECT_ID,\n",
" project_number=PROJECT_NUMBER,\n",
" region=REGION,\n",
" vpc_name=VPC_NAME,\n",
" bq_dataset_name=bq_dataset_name,\n",
" index_display_name=index_display_name,\n",
" deployed_index_id_prefix=deployed_index_id_prefix,\n",
" min_item_frequency=min_item_frequency,\n",
" max_group_size=max_group_size,\n",
" dimensions=dimensions,\n",
" embeddings_gcs_location=embeddings_gcs_location\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Start the run"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"logging.getLogger().setLevel(logging.INFO)\n",
"\n",
"BeamDagRunner().run(pipeline)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Inspect produced metadata\n",
"\n",
"During the execution of the pipeline, the inputs and outputs of each component were tracked in ML Metadata. The following cell connects to the local SQLite metadata store and lists the recorded artifacts."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from ml_metadata import metadata_store\n",
"from ml_metadata.proto import metadata_store_pb2\n",
"\n",
"connection_config = metadata_store_pb2.ConnectionConfig()\n",
"connection_config.sqlite.filename_uri = os.path.join(local_mlmd_folder, 'metadata.sqlite')\n",
"connection_config.sqlite.connection_mode = 3 # READWRITE_OPENCREATE\n",
"store = metadata_store.MetadataStore(connection_config)\n",
"store.get_artifacts()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# NOTICE. The following code does not work with ANN Service Experimental. It will be finalized when the service moves to the Preview stage."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Running the pipeline on AI Platform Pipelines\n",
"\n",
"You will now run the pipeline on AI Platform Pipelines (Unified)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Package custom components into a container\n",
"\n",
"The modules containing the custom components must first be packaged as a Docker container image derived from the standard TFX image."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Create a Dockerfile"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%writefile Dockerfile\n",
"FROM gcr.io/tfx-oss-public/tfx:0.25.0\n",
"WORKDIR /pipeline\n",
"COPY ./ ./\n",
"ENV PYTHONPATH=\"/pipeline:${PYTHONPATH}\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Build and push the Docker image to Container Registry"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!gcloud builds submit --tag gcr.io/{PROJECT_ID}/caip-tfx-custom:{USER} ."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Create AI Platform Pipelines client"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from aiplatform.pipelines import client\n",
"\n",
"aipp_client = client.Client(\n",
" project_id=PROJECT_ID,\n",
" region=REGION,\n",
" api_key=API_KEY\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Set the parameters for AIPP execution and create the pipeline"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"metadata_connection_config = None\n",
"pipeline_root = PIPELINE_ROOT\n",
"\n",
"pipeline = ann_pipeline(\n",
" pipeline_name=PIPELINE_NAME,\n",
" pipeline_root=pipeline_root,\n",
" metadata_connection_config=metadata_connection_config,\n",
" project_id=PROJECT_ID,\n",
" project_number=PROJECT_NUMBER,\n",
" region=REGION,\n",
" vpc_name=VPC_NAME,\n",
" bq_dataset_name=bq_dataset_name,\n",
" index_display_name=index_display_name,\n",
" deployed_index_id_prefix=deployed_index_id_prefix,\n",
" min_item_frequency=min_item_frequency,\n",
" max_group_size=max_group_size,\n",
" dimensions=dimensions,\n",
" embeddings_gcs_location=embeddings_gcs_location\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Compile the pipeline"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"config = kubeflow_v2_dag_runner.KubeflowV2DagRunnerConfig(\n",
" project_id=PROJECT_ID,\n",
" display_name=PIPELINE_NAME,\n",
" default_image='gcr.io/{}/caip-tfx-custom:{}'.format(PROJECT_ID, USER))\n",
"runner = kubeflow_v2_dag_runner.KubeflowV2DagRunner(\n",
" config=config,\n",
" output_filename='pipeline.json')\n",
"runner.compile(\n",
" pipeline,\n",
" write_out=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Submit the pipeline run"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"aipp_client.create_run_from_job_spec('pipeline.json')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## License\n",
"\n",
"Copyright 2020 Google LLC\n",
"\n",
"Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"you may not use this file except in compliance with the License. You may obtain a copy of the License at: http://www.apache.org/licenses/LICENSE-2.0\n",
"\n",
"Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"\n",
"See the License for the specific language governing permissions and limitations under the License.\n",
"\n",
"**This is not an official Google product but sample code provided for educational purposes.**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"environment": {
"name": "tf2-gpu.2-4.m61",
"type": "gcloud",
"uri": "gcr.io/deeplearning-platform-release/tf2-gpu.2-4:m61"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}