in 10_mlops/train_on_vertexai.py [0:0]
def main():
    """Train a model on Vertex AI and deploy it to an endpoint.

    Steps:
      1. Create a Vertex AI TabularDataset from every
         ``gs://BUCKET/ch9/data/all*.csv`` file.
      2. Train with AutoML (AUTOML), with hyperparameter tuning
         (NUM_HPARAM_TRIALS > 1), or as a single custom training job.
      3. Reuse the most recently created endpoint whose display name is
         ENDPOINT_NAME, or create one if none exists.
      4. Deploy the trained model to that endpoint.

    Configuration comes from module-level settings (PROJECT, REGION, BUCKET,
    ENDPOINT_NAME, TF_VERSION, AUTOML, NUM_HPARAM_TRIALS, TIMESTAMP,
    DEVELOP_MODE, CPU_ONLY_MODE). Endpoint creation and deployment are called
    with ``sync=DEVELOP_MODE``.

    Raises:
        FileNotFoundError: if no training CSV files match the GCS glob.
    """
    aiplatform.init(project=PROJECT, location=REGION,
                    staging_bucket='gs://{}'.format(BUCKET))

    # create data set
    all_files = tf.io.gfile.glob('gs://{}/ch9/data/all*.csv'.format(BUCKET))
    if not all_files:
        # Fail early with a clear message rather than handing an empty
        # source list to the dataset-creation API.
        raise FileNotFoundError(
            'No training files match gs://{}/ch9/data/all*.csv'.format(BUCKET))
    logging.info("Training on %s", all_files)
    data_set = aiplatform.TabularDataset.create(
        display_name='data-{}'.format(ENDPOINT_NAME),
        gcs_source=all_files
    )

    # Dash-separated TF version tag (e.g. "2-12") handed to the custom
    # training / tuning helpers.
    if TF_VERSION is not None:
        tf_version = TF_VERSION.replace(".", "-")
    else:
        # Use the installed major.minor. Slicing one character
        # (tf.__version__[2:3]) mis-reads two-digit minors: "2.12.0" -> "2-1".
        tf_version = '-'.join(tf.__version__.split('.')[:2])

    # train
    if AUTOML:
        model = train_automl_model(data_set, TIMESTAMP, DEVELOP_MODE)
    elif NUM_HPARAM_TRIALS > 1:
        model = do_hyperparameter_tuning(data_set, TIMESTAMP, DEVELOP_MODE,
                                         CPU_ONLY_MODE, tf_version)
    else:
        model = train_custom_model(data_set, TIMESTAMP, DEVELOP_MODE,
                                   CPU_ONLY_MODE, tf_version)

    # create endpoint if it doesn't already exist
    endpoints = aiplatform.Endpoint.list(
        filter='display_name="{}"'.format(ENDPOINT_NAME),
        order_by='create_time desc',
        project=PROJECT, location=REGION,
    )
    if endpoints:
        endpoint = endpoints[0]  # most recently created (list is newest-first)
    else:
        endpoint = aiplatform.Endpoint.create(
            display_name=ENDPOINT_NAME, project=PROJECT, location=REGION,
            sync=DEVELOP_MODE
        )

    # deploy: key "0" in traffic_split refers to the model deployed by this
    # call, so it receives 100% of the endpoint's traffic.
    model.deploy(
        endpoint=endpoint,
        traffic_split={"0": 100},
        machine_type='n1-standard-2',
        min_replica_count=1,
        max_replica_count=1,
        sync=DEVELOP_MODE
    )
    if DEVELOP_MODE:
        model.wait()