def main()

in 10_mlops/train_on_vertexai.py [0:0]


def main():
    """Train a model on Vertex AI and deploy it to an endpoint.

    The whole pipeline is driven by module-level settings (PROJECT, REGION,
    BUCKET, ENDPOINT_NAME, TF_VERSION, AUTOML, NUM_HPARAM_TRIALS, TIMESTAMP,
    DEVELOP_MODE, CPU_ONLY_MODE):

    1. Create a Vertex AI TabularDataset from the CSV files matching
       ``gs://{BUCKET}/ch9/data/all*.csv``.
    2. Train with AutoML when ``AUTOML`` is set, with hyperparameter tuning
       when ``NUM_HPARAM_TRIALS > 1``, otherwise as a single custom job.
    3. Reuse the most recently created endpoint whose display name is
       ``ENDPOINT_NAME``, creating one if none exists.
    4. Deploy the trained model to that endpoint with 100% of traffic on a
       single ``n1-standard-2`` replica.

    Endpoint creation and deployment are called with ``sync=DEVELOP_MODE``,
    i.e. blocking in develop mode and asynchronous otherwise.

    Raises:
        FileNotFoundError: if no training CSV files match the glob.
    """
    aiplatform.init(project=PROJECT, location=REGION, staging_bucket='gs://{}'.format(BUCKET))

    # create data set
    data_pattern = 'gs://{}/ch9/data/all*.csv'.format(BUCKET)
    all_files = tf.io.gfile.glob(data_pattern)
    if not all_files:
        # Fail fast with a clear message instead of handing an empty source
        # list to dataset creation.
        raise FileNotFoundError('No training files match {}'.format(data_pattern))
    logging.info("Training on %s", all_files)
    data_set = aiplatform.TabularDataset.create(
        display_name='data-{}'.format(ENDPOINT_NAME),
        gcs_source=all_files
    )

    # TensorFlow version tag in "<major>-<minor>" form (e.g. "2-12").
    # NOTE(review): presumably consumed by the training helpers to pick a
    # prebuilt container image; verify in train_custom_model.
    if TF_VERSION is not None:
        tf_version = TF_VERSION.replace(".", "-")
    else:
        # Parse the version instead of slicing one character, which
        # truncated two-digit minor versions ("2.12.0" -> "2-1").
        major, minor = tf.__version__.split('.')[:2]
        tf_version = '{}-{}'.format(major, minor)

    # train
    if AUTOML:
        model = train_automl_model(data_set, TIMESTAMP, DEVELOP_MODE)
    elif NUM_HPARAM_TRIALS > 1:
        model = do_hyperparameter_tuning(data_set, TIMESTAMP, DEVELOP_MODE, CPU_ONLY_MODE, tf_version)
    else:
        model = train_custom_model(data_set, TIMESTAMP, DEVELOP_MODE, CPU_ONLY_MODE, tf_version)

    # create endpoint if it doesn't already exist
    endpoints = aiplatform.Endpoint.list(
        filter='display_name="{}"'.format(ENDPOINT_NAME),
        order_by='create_time desc',
        project=PROJECT, location=REGION,
    )
    if endpoints:
        endpoint = endpoints[0]  # most recently created
    else:
        endpoint = aiplatform.Endpoint.create(
            display_name=ENDPOINT_NAME, project=PROJECT, location=REGION,
            sync=DEVELOP_MODE
        )

    # deploy: route all traffic to the newly deployed model
    model.deploy(
        endpoint=endpoint,
        traffic_split={"0": 100},
        machine_type='n1-standard-2',
        min_replica_count=1,
        max_replica_count=1,
        sync=DEVELOP_MODE
    )

    # NOTE(review): deploy() above already ran with sync=True in develop
    # mode, so this wait looks redundant there; confirm whether the wait was
    # meant for the asynchronous (non-develop) path instead.
    if DEVELOP_MODE:
        model.wait()