In [None]:
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# <center>MySQL to Cloud Spanner Migration (or Bulk Load)</center>
<table align="left">
<td>
    <a href="https://colab.research.google.com/github/GoogleCloudPlatform/dataproc-templates/blob/main/notebooks/mysql2spanner/MySqlToSpanner_notebook.ipynb">
        <img src="../images/colab-logo-32px.png" alt="Colab logo" />Run in Colab
    </a>
</td>
<td>
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fdataproc-templates%2Fmain%2Fnotebooks%2Fmysql2spanner%2FMySqlToSpanner_notebook.ipynb">
        <img src="../images/colab-enterprise-logo-32px.png" alt="GCP Colab Enterprise logo" />Run in Colab Enterprise
    </a>
</td>
<td>
    <a href="https://github.com/GoogleCloudPlatform/dataproc-templates/blob/main/notebooks/mysql2spanner/MySqlToSpanner_notebook.ipynb">
        <img src="../images/github-logo-32px.png" alt="GitHub logo" />View on GitHub
    </a>
</td>
<td>
    <a href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/dataproc-templates/main/notebooks/mysql2spanner/MySqlToSpanner_notebook.ipynb">
        <img src="../images/vertexai.jpg" alt="Vertex AI logo" />Open in Vertex AI Workbench
    </a>
</td>
</table>

#### References

- [DataprocPySparkBatchOp reference](https://google-cloud-pipeline-components.readthedocs.io/en/google-cloud-pipeline-components-1.0.0/google_cloud_pipeline_components.experimental.dataproc.html)
- [Kubeflow SDK Overview](https://www.kubeflow.org/docs/components/pipelines/sdk/sdk-overview/)
- [Dataproc Serverless in Vertex AI Pipelines tutorial](https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/ml_ops/stage3/get_started_with_dataproc_serverless_pipeline_components.ipynb)
- [Build a Vertex AI Pipeline](https://cloud.google.com/vertex-ai/docs/pipelines/build-pipeline)

This notebook is built to run on a Vertex AI User-Managed Notebook using the default Compute Engine Service Account.  
Check the Dataproc Serverless in Vertex AI Pipelines tutorial linked above to learn how to set up a different Service Account.  

#### Permissions

Make sure that the service account used to run the notebook has the following roles:

- roles/aiplatform.serviceAgent
- roles/aiplatform.customCodeServiceAgent
- roles/storage.objectCreator
- roles/storage.objectViewer
- roles/dataproc.editor
- roles/dataproc.worker

## Step 1: Install Libraries
#### Run Step 1 one time for each new notebook instance

In [None]:
# MySQL driver and SQLAlchemy, used by the notebook to connect to the source database
!pip3 install pymysql SQLAlchemy
# Vertex AI pipeline components and the Kubeflow Pipelines SDK (upgraded, per-user, quiet)
!pip3 install --upgrade google-cloud-pipeline-components kfp --user -q
# Cloud Spanner client, used by the row-count validation step
!pip3 install google-cloud-spanner
# Cloud Storage client library
!pip3 install --upgrade google-cloud-storage

In [None]:
# Refresh the package index, then install a JDK and Maven; Maven is used to build the template JAR in Step 8
!sudo apt-get update -y
!sudo apt-get install default-jdk -y
!sudo apt-get install maven -y

#### Once you've installed the additional packages, you may need to restart the notebook kernel so it can find the packages.

Uncomment and run this cell if you installed anything with the commands above

In [None]:
# import os
# import IPython
# if not os.getenv("IS_TESTING"):
#     app = IPython.Application.instance()
#     app.kernel.do_shutdown(True)

Uncomment & Run this cell if you are running from Colab

In [None]:
# !git clone https://github.com/GoogleCloudPlatform/dataproc-templates.git
# !mv /content/dataproc-templates/notebooks/util /content/
# !mv /content/dataproc-templates/java/ /content/

## Step 2: Import Libraries

In [None]:
from datetime import datetime
import os
from pathlib import Path
import sys
import time

import google.cloud.aiplatform as aiplatform
from kfp import dsl
from kfp import compiler

try:
    from google_cloud_pipeline_components.experimental.dataproc import DataprocSparkBatchOp
except ModuleNotFoundError:
    from google_cloud_pipeline_components.v1.dataproc import DataprocSparkBatchOp
    
import pandas as pd
import pymysql
import sqlalchemy

module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

from util.jdbc.jdbc_input_manager import JDBCInputManager
from util.jdbc import jdbc_input_manager_interface
from util import notebook_functions

## Step 3: Assign Parameters

### Step 3.1 Common Parameters
 
- PROJECT : GCP project-id
- REGION : GCP region (us-central1)
- GCS_STAGING_LOCATION : Cloud Storage staging location to be used for this notebook to store artifacts 
- SUBNET : VPC subnet
- JARS : List of JARs. This notebook requires the MySQL connector and Avro JARs in addition to the Dataproc template JAR
- MAX_PARALLELISM : Maximum number of jobs (tables) to run in parallel; the cell below sets it to 5
- DATAPROC_SERVICE_ACCOUNT : Service account which will run serverless dataproc batch job

In [None]:
# When False, the parameter cells below assign their values inline. When True, those
# assignments are skipped, so the values must already be defined — presumably injected
# by an external runner; TODO confirm.
IS_PARAMETERIZED = False

In [None]:
# These values are only assigned here when the notebook is not parameterized.
if not IS_PARAMETERIZED:
    PROJECT = ""
    REGION = "" # eg: us-central1 (any valid GCP region)
    GCS_STAGING_LOCATION = "" # eg: gs://my-staging-bucket/sub-folder
    SUBNET = "projects/{project}/regions/{region}/subnetworks/{subnet}"
    MAX_PARALLELISM = 5 # max number of tables that will be migrated in parallel
    DATAPROC_SERVICE_ACCOUNT = "" # eg: test@project_id.iam.gserviceaccount.com

# Do not change this parameter unless you want to reference the JARS below from a new location.
# The MySQL connector JAR is read from the staging location (it is uploaded in Step 8).
JARS = [GCS_STAGING_LOCATION + "/jars/mysql-connector-java-8.0.29.jar","file:///usr/lib/spark/external/spark-avro.jar"]

### Step 3.2 MySQL to Spanner Parameters
- MYSQL_HOST: MySQL instance ip address
- MYSQL_PORT: MySQL instance port
- MYSQL_USERNAME: MySQL username
- MYSQL_PASSWORD: MySQL password
- MYSQL_DATABASE: Name of database that you want to migrate
- MYSQL_TABLE_LIST: List of tables you want to migrate, e.g.: ['table1','table2']; provide an empty list to migrate the whole database, e.g.: [] 
- MYSQL_READ_PARTITION_COLUMNS: Dictionary of custom read partition columns, e.g.: {'table2': 'secondary_id'}
- MYSQL_OUTPUT_SPANNER_MODE: Output mode for MySQL data one of (overwrite|append). Use append if schema already exists in Spanner
- SPANNER_INSTANCE: Cloud Spanner instance name
- SPANNER_DATABASE: Cloud Spanner database name

Spanner requires primary key for each table
- SPANNER_TABLE_PRIMARY_KEYS: Dictionary of format {"table_name": "primary_key_column1,primary_key_column2"} for tables which do not have primary key in MySQL

In [None]:
# These values are only assigned here when the notebook is not parameterized.
if not IS_PARAMETERIZED:
    MYSQL_HOST = ""
    MYSQL_PORT = "3306"
    MYSQL_USERNAME = ""
    # NOTE(review): avoid saving a real password in the notebook; consider reading it
    # from an environment variable or getpass instead.
    MYSQL_PASSWORD = ""
    MYSQL_DATABASE = ""
    MYSQL_TABLE_LIST = [] # Leave the list empty to migrate the complete database, else provide tables as ['table1','table2']
    MYSQL_READ_PARTITION_COLUMNS = {} # Leave empty for default read partition columns
    MYSQL_OUTPUT_SPANNER_MODE = "overwrite" # one of overwrite|append (Use append when schema already exists in Spanner)

    SPANNER_INSTANCE = ""
    SPANNER_DATABASE = ""
    SPANNER_TABLE_PRIMARY_KEYS = {} # Provide tables which do not have a PK in MySQL {"table_name":"primary_key_column1,primary_key_column2"}

### Step 3.3 Notebook Configuration Parameters
The variables below should not be changed unless required

In [None]:
cur_path = Path(os.getcwd())

# Locate the repository's `java` folder relative to the current working directory:
# one level up when parameterized, two levels up otherwise.
if IS_PARAMETERIZED:
    WORKING_DIRECTORY = os.path.join(cur_path.parent ,'java')
else:
    WORKING_DIRECTORY = os.path.join(cur_path.parent.parent ,'java')

# If the above code doesn't fetch the correct path, please
# provide the complete path to the java folder in the dataproc
# templates repo which you cloned

# WORKING_DIRECTORY = "/home/jupyter/dataproc-templates/java/"
print(WORKING_DIRECTORY)

In [None]:
# SQLAlchemy driver name used by the notebook's own connection to MySQL
PYMYSQL_DRIVER = "mysql+pymysql"
# JDBC driver class and connection URL handed to the Spark template
JDBC_DRIVER = "com.mysql.cj.jdbc.Driver"
JDBC_URL = "jdbc:mysql://{}:{}/{}?user={}&password={}".format(MYSQL_HOST,MYSQL_PORT,MYSQL_DATABASE,MYSQL_USERNAME,MYSQL_PASSWORD)
# Entry point and file name of the Dataproc template JAR built in Step 8
MAIN_CLASS = "com.google.cloud.dataproc.templates.main.DataProcTemplate"
JAR_FILE = "dataproc-templates-1.0-SNAPSHOT.jar"
# Location (relative to WORKING_DIRECTORY) and name of the log4j properties file uploaded in Step 8
LOG4J_PROPERTIES_PATH = "./src/test/resources"
LOG4J_PROPERTIES = "log4j-spark-driver-template.properties"
PIPELINE_ROOT = GCS_STAGING_LOCATION + "/pipeline_root/dataproc_pyspark"

# Adding Dataproc template JAR.
# Guarded so that re-running this cell does not append the same JAR a second time.
TEMPLATE_JAR_URI = GCS_STAGING_LOCATION + "/" + JAR_FILE
if TEMPLATE_JAR_URI not in JARS:
    JARS.append(TEMPLATE_JAR_URI)

## Step 4: Generate MySQL Table List
This step creates the list of tables for migration. If MYSQL_TABLE_LIST is empty, all the tables in MYSQL_DATABASE are listed for migration; otherwise the provided list is used

In [None]:
# Describe the connection to the source MySQL database, then open an engine on it.
mysql_url = sqlalchemy.engine.url.URL.create(
    drivername=PYMYSQL_DRIVER,
    username=MYSQL_USERNAME,
    password=MYSQL_PASSWORD,
    database=MYSQL_DATABASE,
    host=MYSQL_HOST,
    port=MYSQL_PORT,
)
DB = sqlalchemy.create_engine(mysql_url)
input_mgr = JDBCInputManager.create("mysql", DB)

# Retrieve list of tables from database.
MYSQL_TABLE_LIST = input_mgr.build_table_list(
    schema_filter=MYSQL_DATABASE,
    table_filter=MYSQL_TABLE_LIST,
)
table_count = len(MYSQL_TABLE_LIST)
print(f"Total tables to migrate from schema {MYSQL_DATABASE}:", table_count)

print("List of tables for migration:", MYSQL_TABLE_LIST, sep="\n")

## Step 5: Get Primary Keys for Tables Not Present in SPANNER_TABLE_PRIMARY_KEYS
For tables which do not have primary key provided in dictionary SPANNER_TABLE_PRIMARY_KEYS this step fetches primary key from MYSQL_DATABASE

In [None]:
# Feed every (table, primary-key columns) pair reported by the input manager into
# SPANNER_TABLE_PRIMARY_KEYS.
# NOTE(review): presumably entries already supplied by the user are kept as-is —
# confirm in notebook_functions.update_spanner_primary_keys.
for table_name, pk_columns in input_mgr.get_primary_keys().items():
    notebook_functions.update_spanner_primary_keys(SPANNER_TABLE_PRIMARY_KEYS, table_name, pk_columns)

# NOTE(review): presumably drops entries for tables that are not in MYSQL_TABLE_LIST —
# confirm in notebook_functions.remove_unexpected_spanner_primary_keys.
notebook_functions.remove_unexpected_spanner_primary_keys(SPANNER_TABLE_PRIMARY_KEYS, MYSQL_TABLE_LIST)

In [None]:
# Look up the primary key of every table to migrate (None when no key is known).
primary_keys = list(map(SPANNER_TABLE_PRIMARY_KEYS.get, MYSQL_TABLE_LIST))
pkDF = pd.DataFrame(
    {
        "table": MYSQL_TABLE_LIST,
        "primary_keys": primary_keys,
    }
)
print("Below are identified primary keys for migrating MySQL table to Spanner:")
pkDF

## Step 6 Identify Read Partition Columns
This step uses the PARTITION_THRESHOLD parameter (set to 200,000 rows in the cell below): any table with more rows than PARTITION_THRESHOLD is read with a partitioned read based on its primary keys
 - PARTITION_OPTIONS: List will have table and its partitioned column and Spark SQL settings if exceeds the threshold

In [None]:
PARTITION_THRESHOLD = 200000 # Number of rows fetched per spark executor
# Ask the input manager for read-partitioning options, passing any user-supplied partition columns.
# The result is later looked up by table name, and each entry by the
# jdbc_input_manager_interface.SPARK_* keys, when the Spark arguments are built.
PARTITION_OPTIONS = input_mgr.define_read_partitioning(
    PARTITION_THRESHOLD, custom_partition_columns=MYSQL_READ_PARTITION_COLUMNS
)
# Last expression of the cell, so its value is shown as the cell output
input_mgr.read_partitioning_df(PARTITION_OPTIONS)

## Step 7: Calculate Parallel Jobs for MySQL to Cloud Spanner
This step uses MAX_PARALLELISM parameter to calculate number of parallel jobs to run

In [None]:
# Calculate parallel jobs:
JOB_LIST = notebook_functions.split_list(input_mgr.get_table_list(), MAX_PARALLELISM)
print("List of tables for execution:")
print(JOB_LIST)

## Step 8: Create JAR files and Upload to Cloud Storage
#### Run Step 8 one time for each new notebook instance

In [None]:
# Switch to the Java templates directory; the build and copy commands below use paths relative to it
%cd $WORKING_DIRECTORY

#### Downloading the MySQL connector and executing the Maven build

In [None]:
# Download and unpack the MySQL JDBC connector; its JAR is uploaded to the staging location in the next cell
!wget https://downloads.mysql.com/archives/get/p/3/file/mysql-connector-java-8.0.29.tar.gz
!tar -xf mysql-connector-java-8.0.29.tar.gz
# Build the Dataproc templates JAR with Maven, skipping the tests
!mvn clean spotless:apply install -DskipTests 

#### Copying JARS Files to GCS_STAGING_LOCATION

In [None]:
# Upload the built template JAR, the log4j properties file and the MySQL connector JAR
# to the staging location, at the paths referenced by JARS and the pipeline's file URIs
!gsutil cp target/$JAR_FILE $GCS_STAGING_LOCATION/$JAR_FILE
!gsutil cp $LOG4J_PROPERTIES_PATH/$LOG4J_PROPERTIES $GCS_STAGING_LOCATION/$LOG4J_PROPERTIES
!gsutil cp mysql-connector-java-8.0.29/mysql-connector-java-8.0.29.jar $GCS_STAGING_LOCATION/jars/mysql-connector-java-8.0.29.jar

## Step 9: Execute Pipeline to Migrate Tables from MySQL to Spanner

In [None]:
# Collects the Dataproc batch id of every job declared by migrate_mysql_to_spanner.
# Re-running this cell clears the list.
mysql_to_spanner_jobs = []

In [None]:
def migrate_mysql_to_spanner(EXECUTION_LIST):
    """Compile and run one Vertex AI pipeline that migrates a group of MySQL tables.

    One Dataproc Spark batch using the JDBCTOSPANNER template is declared for
    every table in EXECUTION_LIST. Each generated batch id is appended to the
    module-level ``mysql_to_spanner_jobs`` list so that a later step can look
    up the status of the batch.

    Args:
        EXECUTION_LIST: Iterable of MySQL table names to migrate in this run.

    Raises:
        KeyError: If a table has no entry in SPANNER_TABLE_PRIMARY_KEYS.
    """
    aiplatform.init(project=PROJECT,staging_bucket=GCS_STAGING_LOCATION)

    @dsl.pipeline(
        name="java-mysql-to-spanner-spark",
        description="Pipeline to get data from MySQL to Cloud Spanner",
    )
    def pipeline(
        PROJECT_ID: str = PROJECT,
        LOCATION: str = REGION,
        MAIN_CLASS: str = MAIN_CLASS,
        JAR_FILE_URIS: list = JARS,
        SUBNETWORK_URI: str = SUBNET,
        FILE_URIS: list = [GCS_STAGING_LOCATION + "/" + LOG4J_PROPERTIES]
    ):
        for table in EXECUTION_LIST:
            # Epoch seconds via time.time(): strftime("%s") is a platform-specific
            # directive that is not available everywhere.
            BATCH_ID = "mysql2spanner-{}-{}".format(table, int(time.time())).replace('_','-').lower()
            mysql_to_spanner_jobs.append(BATCH_ID)

            # Arguments shared by every table, partitioned or not.
            TEMPLATE_SPARK_ARGS = [
                "--template=JDBCTOSPANNER",
                "--templateProperty", "project.id={}".format(PROJECT),
                "--templateProperty", "jdbctospanner.jdbc.url={}".format(JDBC_URL),
                "--templateProperty", "jdbctospanner.jdbc.driver.class.name={}".format(JDBC_DRIVER),
                "--templateProperty", "jdbctospanner.sql=select * from {}".format(table),
                "--templateProperty", "jdbctospanner.output.instance={}".format(SPANNER_INSTANCE),
                "--templateProperty", "jdbctospanner.output.database={}".format(SPANNER_DATABASE),
                "--templateProperty", "jdbctospanner.output.table={}".format(table),
                "--templateProperty", "jdbctospanner.output.saveMode={}".format(MYSQL_OUTPUT_SPANNER_MODE.capitalize()),
                "--templateProperty", "jdbctospanner.output.primaryKey={}".format(SPANNER_TABLE_PRIMARY_KEYS[table]),
                "--templateProperty", "jdbctospanner.output.batchInsertSize=200",
            ]

            # Tables with read-partitioning options get the extra partitioned-read arguments.
            if table in PARTITION_OPTIONS:
                partition_options = PARTITION_OPTIONS[table]
                TEMPLATE_SPARK_ARGS += [
                    "--templateProperty", "jdbctospanner.sql.partitionColumn={}".format(partition_options[jdbc_input_manager_interface.SPARK_PARTITION_COLUMN]),
                    "--templateProperty", "jdbctospanner.sql.lowerBound={}".format(partition_options[jdbc_input_manager_interface.SPARK_LOWER_BOUND]),
                    "--templateProperty", "jdbctospanner.sql.upperBound={}".format(partition_options[jdbc_input_manager_interface.SPARK_UPPER_BOUND]),
                    "--templateProperty", "jdbctospanner.sql.numPartitions={}".format(partition_options[jdbc_input_manager_interface.SPARK_NUM_PARTITIONS]),
                ]

            _ = DataprocSparkBatchOp(
                project=PROJECT_ID,
                location=LOCATION,
                batch_id=BATCH_ID,
                main_class=MAIN_CLASS,
                jar_file_uris=JAR_FILE_URIS,
                file_uris=FILE_URIS,
                subnetwork_uri=SUBNETWORK_URI,
                runtime_config_version="1.1", # issue 665
                service_account=DATAPROC_SERVICE_ACCOUNT,
                args=TEMPLATE_SPARK_ARGS
            )
            # NOTE(review): presumably spaces out the second-resolution timestamps
            # used in consecutive batch ids — confirm before removing.
            time.sleep(1)

    compiler.Compiler().compile(pipeline_func=pipeline, package_path="pipeline.json")

    # Named separately from the pipeline function above so the two are not confused.
    pipeline_job = aiplatform.PipelineJob(
        display_name="pipeline",
        template_path="pipeline.json",
        pipeline_root=PIPELINE_ROOT,
        enable_caching=False,
        location=REGION,
    )
    # run() method has an optional parameter `service_account` which you can pass if you want to run pipeline using
    # specific service account instead of default service account 
    # eg. pipeline_job.run(service_account='test@project_id.iam.gserviceaccount.com')
    pipeline_job.run()

In [None]:
# Run one pipeline per group of tables, printing the group before it starts
for execution_list in JOB_LIST:
    print(execution_list)
    migrate_mysql_to_spanner(execution_list)

## Step 10: Get status for tables migrated from MySQL to Spanner

In [None]:
def get_bearer_token():
    """Fetch an OAuth2 access token for the default Google credentials.

    Returns:
        tuple: ``(token, 200)`` on success, otherwise ``(error_message, 500)``.
        The result is always a 2-tuple, so callers can safely test ``result[1]``.
    """
    try:
        # Imported locally so the function does not depend on which module the
        # global name `requests` currently refers to.
        import google.auth
        from google.auth.transport.requests import Request

        #Defining Scope
        CREDENTIAL_SCOPES = ["https://www.googleapis.com/auth/cloud-platform"]

        #Assigning credentials (the project id returned alongside is not needed)
        credentials, _ = google.auth.default(scopes=CREDENTIAL_SCOPES)

        #Refreshing credentials data
        credentials.refresh(Request())

        #Get refreshed token
        token = credentials.token
        if token:
            return (token, 200)
        # Same shape as every other return path (previously a bare string)
        return ("Bearer token not generated", 500)
    except Exception as error:
        return ("Bearer token not generated. Error : {}".format(error), 500)

In [None]:
# google.auth's HTTP transport module (not the `requests` HTTP library)
from google.auth.transport import requests
import google
# On success get_bearer_token() returns (token, 200)
token = get_bearer_token()
if token[1] == 200:
    print("Bearer token generated")
else:
    # Show whatever was returned instead of a token
    print(token)

In [None]:
# Aliased so that the global name `requests` is not rebound to the HTTP library
import requests as http_requests

# Dataproc REST endpoint of a single batch: project, region, batch id
job_status_url = "https://dataproc.googleapis.com/v1/projects/{}/locations/{}/batches/{}"

# The same bearer token is used for every request, so the headers are built once.
headers = {
  'Content-Type': 'application/json; charset=UTF-8',
  'Authorization': "Bearer " + token[0]
}

mysql_to_spanner_status = []
for job in mysql_to_spanner_jobs:
    url = job_status_url.format(PROJECT,REGION,job)
    # A timeout keeps one unresponsive call from hanging the notebook.
    response = http_requests.get(url, headers=headers, timeout=60)
    if response.ok:
        state = response.json().get('state', 'STATE_UNSPECIFIED')
    else:
        # Error responses carry no 'state'; record the HTTP status instead of raising
        # so that one failed lookup does not hide the status of the other jobs.
        state = "HTTP_{}".format(response.status_code)
    mysql_to_spanner_status.append(state)

In [None]:
# One row per table: the table name, its Dataproc batch id and the batch state.
status_columns = {
    "table": MYSQL_TABLE_LIST,
    "mysql_to_spanner_job": mysql_to_spanner_jobs,
    "mysql_to_spanner_status": mysql_to_spanner_status,
}
statusDF = pd.DataFrame(status_columns)
statusDF

## Step 11: Validate Row Counts of Migrated Tables from MySQL to Cloud Spanner

In [None]:
# Get MySQL table counts
mysql_row_count = input_mgr.get_table_list_with_counts()

In [None]:
# Get Cloud Spanner table counts
spanner_row_count = []
from google.cloud import spanner

spanner_client = spanner.Client()
instance = spanner_client.instance(SPANNER_INSTANCE)
database = instance.database(SPANNER_DATABASE)

for table in MYSQL_TABLE_LIST:
    with database.snapshot() as snapshot:
        qry = "@{{USE_ADDITIONAL_PARALLELISM=true}} select count(1) from {}".format(table)
        results = snapshot.execute_sql(qry)
        for row in results:
            spanner_row_count.append(row[0])

In [None]:
# Attach the source and target row counts as two new columns of the status table.
row_count_columns = (
    ("mysql_row_count", mysql_row_count),
    ("spanner_row_count", spanner_row_count),
)
for column_name, counts in row_count_columns:
    statusDF[column_name] = counts
statusDF

## Post data loading activities
- You may create relationships (FKs), constraints and indexes (as needed).
- You may configure continuous replication with [Datastream](https://cloud.google.com/datastream/docs/configure-your-source-mysql-database) or any other 3rd party tools.