scripts/run.py (126 lines of code) (raw):
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A utility to submit a Vertex Training T5X job."""
import gcsfs
import fsspec
import os
from absl import flags
from absl import app
from absl import logging
from datetime import datetime
from importlib import import_module
from google.cloud import aiplatform as vertex_ai
from google.cloud.aiplatform import CustomJob
from typing import List
from typing import Dict
from typing import Any
from typing import Union
from typing import Optional
# Vertex AI machine type used for every worker; TPU cores are requested via
# --accelerator_type / --accelerator_count.
MACHINE_TYPE = 'cloud-tpu'

# GCP / Vertex AI environment.
flags.DEFINE_string(name='project_id', default=None, help='GCP Project')
flags.DEFINE_string(name='region', default=None, help='Vertex Pipelines region')
flags.DEFINE_string(name='staging_bucket', default=None, help='Staging bucket')
flags.DEFINE_string(name='training_sa', default=None, help='Training SA')
flags.DEFINE_string(name='image_uri', default=None, help='Training image')
flags.DEFINE_string(
    name='job_name_prefix', default='t5x_job', help='Job name prefix')

# Gin configuration.
flags.DEFINE_list(name='gin_files', default=None, help='Gin configuration files')
flags.DEFINE_list(name='gin_overwrites', default=None, help='Gin overwrites')
flags.DEFINE_list(
    name='gin_search_paths', default=None, help='Gin search paths')

# Hardware and run options.
flags.DEFINE_string(
    name='accelerator_type', default='TPU_V2', help='Accelerator type')
flags.DEFINE_integer(
    name='accelerator_count', default=8, help='Number of cores')
flags.DEFINE_string(name='run_mode', default='train', help='Run mode')
flags.DEFINE_string(name='tfds_data_dir', default=None, help='TFDS data dir')
flags.DEFINE_bool(name='sync', default=True, help='Execute synchronously')

# absl refuses to start if any of these are missing on the command line.
flags.mark_flags_as_required(
    ['project_id', 'region', 'staging_bucket', 'gin_files', 'image_uri'])

FLAGS = flags.FLAGS
def _create_t5x_custom_job(
    display_name: str,
    machine_type: str,
    accelerator_type: str,
    accelerator_count: int,
    image_uri: str,
    run_mode: str,
    gin_files: List[str],
    model_dir: str,
    gin_search_paths: Optional[List[str]] = None,
    tfds_data_dir: Optional[str] = None,
    replica_count: int = 1,
    gin_overwrites: Optional[List[str]] = None,
    base_output_dir: Optional[str] = None,
) -> CustomJob:
    """Creates a Vertex AI custom T5X training job.

    Copies the local gin configuration files (.gin) into `model_dir` on GCS,
    builds a single worker_pool_spec running the T5X container, and returns
    the (not yet started) aiplatform.CustomJob.

    Args:
      display_name (str):
        Required. User defined display name for the Vertex AI custom T5X job.
      machine_type (str):
        Required. The type of machine for running the custom training job on
        dedicated resources. For TPUs, use `cloud-tpu`.
      accelerator_type (str):
        Required. The type of accelerator(s) attached to the machine as per
        `accelerator_count`. Options: `TPU_V2` or `TPU_V3`.
      accelerator_count (int):
        Required. The number of accelerators to attach to the `machine_type`.
        For TPUs, this is the number of cores to be provisioned.
        Example: 8, 128, 512, etc.
      image_uri (str):
        Required. Full image path to be used as the execution environment of
        the custom T5X training job.
        Example: 'gcr.io/{PROJECT_ID}/{IMAGE_NAME}'
      run_mode (str):
        Required. The mode to run T5X under. Options: `train`, `eval`, `infer`.
        For `infer` the container command is overridden to run
        `./t5x/t5x/infer.py`; for any other mode `--run_mode` is passed to the
        container's default command.
      gin_files (List[str]):
        Required. Paths to gin configuration files on the local filesystem.
        Each file is uploaded into `model_dir` under its base name, so base
        names must be unique. Files are passed to T5X in the given order, with
        later configurations overriding earlier ones.
      model_dir (str):
        Required. Path on Google Cloud Storage to store all the artifacts
        generated by the custom T5X job, in the form
        `gs://{bucket name}/{your folder}/...`. It is passed to T5X as the
        `MODEL_DIR` gin macro, and the uploaded gin files are referenced
        through the `/gcs/` path obtained by replacing `gs://`.
        Example: gs://my_bucket/experiments/model1/
      gin_search_paths (Optional[List[str]] = None):
        Optional. Gin config path prefixes to be prepended to gin suffixes in
        gin includes and gin_files. Passed comma-joined as
        `--gin_search_paths`.
      tfds_data_dir (Optional[str] = None):
        Optional. If set, this directory will be used to store datasets
        prepared by TensorFlow Datasets that are not available in the public
        TFDS GCS bucket. It overrides the `tfds_data_dir` attribute of all
        Tasks and must be a valid GCS path. Not forwarded when unset.
        Example: gs://my_bucket/datasets/my_dataset/
      replica_count (int = 1):
        Optional. The number of worker replicas. If replica count = 1 then one
        chief replica will be provisioned. For TPUs this must be set to 1.
      gin_overwrites (Optional[List[str]] = None):
        Optional. Gin bindings that override the gin configuration. This
        function prefixes each entry with `--gin.`, so pass them WITHOUT that
        prefix; string values must be enclosed in quotes.
        Example: TRAIN_PATH="gs://my_bucket/folder"
      base_output_dir (Optional[str] = None):
        Optional. Forwarded unchanged as `base_output_dir` to
        aiplatform.CustomJob.

    Returns:
      (aiplatform.CustomJob):
        An instance of a Vertex AI training CustomJob.

    Raises:
      FileNotFoundError: If `gin_files` is empty or any entry is not an
        existing local file.
      ValueError: If two entries of `gin_files` share the same base name
        (they would overwrite each other in `model_dir`).
      RuntimeError: If copying the gin files to GCS fails; the underlying
        error is chained as `__cause__`.
    """
    local_fs = fsspec.filesystem('file')
    gcs_fs = gcsfs.GCSFileSystem()

    # Every gin file must exist locally before anything is uploaded.
    if not gin_files or not all(local_fs.isfile(f) for f in gin_files):
        raise FileNotFoundError('Provide a list of valid gin files.')

    # Gin files are flattened into model_dir by base name; duplicates would
    # silently overwrite each other and the survivor would be passed twice.
    gin_basenames = [os.path.basename(f) for f in gin_files]
    if len(set(gin_basenames)) != len(gin_basenames):
        raise ValueError(
            f'Gin files must have unique file names, got: {gin_basenames}')

    gcs_gin_files = []
    try:
        for gin_file, basename in zip(gin_files, gin_basenames):
            gcs_path = os.path.join(model_dir, basename)
            gcs_fs.put(gin_file, gcs_path)
            # The container reads the files via the /gcs/ path, not gs://.
            gcs_gin_files.append(gcs_path.replace('gs://', '/gcs/'))
    except Exception as e:
        raise RuntimeError('Could not copy gin files to GCS.') from e

    container_spec = {'image_uri': image_uri}
    if run_mode == 'infer':
        # Temporary mitigation to address issues with t5x/main.py
        # and inference on tfrecord files: call infer.py directly.
        container_spec['command'] = ['python', './t5x/t5x/infer.py']
        args = []
    else:
        args = [f'--run_mode={run_mode}']
    args.append(f'--gin.MODEL_DIR="{model_dir}"')
    # Only forward tfds_data_dir when set; formatting None would send the
    # literal string 'None' to the container.
    if tfds_data_dir:
        args.append(f'--tfds_data_dir={tfds_data_dir}')
    if gin_search_paths:
        args.append(f'--gin_search_paths={",".join(gin_search_paths)}')
    args += [f'--gin_file={path}' for path in gcs_gin_files]
    if gin_overwrites:
        args += [f'--gin.{overwrite}' for overwrite in gin_overwrites]
    container_spec['args'] = args

    worker_pool_specs = [
        {
            'machine_spec': {
                'machine_type': machine_type,
                'accelerator_type': accelerator_type,
                'accelerator_count': accelerator_count,
            },
            'replica_count': replica_count,
            'container_spec': container_spec,
        }
    ]
    return vertex_ai.CustomJob(
        display_name=display_name,
        worker_pool_specs=worker_pool_specs,
        base_output_dir=base_output_dir,
    )
def _main(argv):
    """Submits a T5X job to Vertex AI using the parsed command-line flags.

    Initializes the Vertex AI SDK, derives a timestamped job name and a job
    directory under `--staging_bucket`, creates the custom job, and runs it
    (synchronously unless `--nosync` is given).

    Args:
      argv: Positional arguments left over by absl flag parsing; unused.
    """
    del argv  # Unused.

    vertex_ai.init(
        project=FLAGS.project_id,
        location=FLAGS.region,
        staging_bucket=FLAGS.staging_bucket,
    )
    job_name = (
        f'{FLAGS.job_name_prefix}_{datetime.now().strftime("%Y%m%d%H%M%S")}')
    # Strip a trailing slash so the job dir never contains '//'.
    job_dir = f'{FLAGS.staging_bucket.rstrip("/")}/t5x_jobs/{job_name}'
    job = _create_t5x_custom_job(
        display_name=job_name,
        machine_type=MACHINE_TYPE,
        accelerator_type=FLAGS.accelerator_type,
        accelerator_count=FLAGS.accelerator_count,
        image_uri=FLAGS.image_uri,
        run_mode=FLAGS.run_mode,
        gin_files=FLAGS.gin_files,
        gin_overwrites=FLAGS.gin_overwrites,
        gin_search_paths=FLAGS.gin_search_paths,
        tfds_data_dir=FLAGS.tfds_data_dir,
        model_dir=job_dir,
    )
    logging.info('Starting job: %s', job_name)
    # --training_sa was previously defined but ignored; forward it as the
    # job's run-as service account (None leaves the SDK default in place).
    job.run(service_account=FLAGS.training_sa, sync=FLAGS.sync)
if __name__ == '__main__':
    # absl parses FLAGS (enforcing the required ones) before calling _main.
    app.run(main=_main)