# scripts/run.py
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A utility to submit a Vertex Training T5X job."""

import os
from datetime import datetime
from importlib import import_module
from typing import Any, Dict, List, Optional, Union

import fsspec
import gcsfs
from absl import app, flags, logging
from google.cloud import aiplatform as vertex_ai
from google.cloud.aiplatform import CustomJob

# TPUs are provisioned through the dedicated `cloud-tpu` machine type.
MACHINE_TYPE = 'cloud-tpu'

flags.DEFINE_string('project_id', None, 'GCP Project')
flags.DEFINE_string('region', None, 'Vertex Pipelines region')
flags.DEFINE_string('staging_bucket', None, 'Staging bucket')
flags.DEFINE_string('training_sa', None, 'Training SA')
flags.DEFINE_string('image_uri', None, 'Training image')
flags.DEFINE_string('job_name_prefix', 't5x_job', 'Job name prefix')
flags.DEFINE_list('gin_files', None, 'Gin configuration files')
flags.DEFINE_list('gin_overwrites', None, 'Gin overwrites')
flags.DEFINE_list('gin_search_paths', None, 'Gin search paths')
flags.DEFINE_string('accelerator_type', 'TPU_V2', 'Accelerator type')
flags.DEFINE_integer('accelerator_count', 8, 'Number of cores')
flags.DEFINE_string('run_mode', 'train', 'Run mode')
flags.DEFINE_string('tfds_data_dir', None, 'TFDS data dir')
flags.DEFINE_bool('sync', True, 'Execute synchronously')

flags.mark_flag_as_required('project_id')
flags.mark_flag_as_required('region')
flags.mark_flag_as_required('staging_bucket')
flags.mark_flag_as_required('gin_files')
flags.mark_flag_as_required('image_uri')

FLAGS = flags.FLAGS


def _create_t5x_custom_job(
    display_name: str,
    machine_type: str,
    accelerator_type: str,
    accelerator_count: int,
    image_uri: str,
    run_mode: str,
    gin_files: List[str],
    model_dir: str,
    gin_search_paths: Optional[List[str]] = None,
    tfds_data_dir: Optional[str] = None,
    replica_count: int = 1,
    gin_overwrites: Optional[List[str]] = None,
    base_output_dir: Optional[str] = None,
) -> CustomJob:
    """Creates a Vertex AI custom T5X training job.

    It copies the configuration files (.gin) to GCS, creates a
    worker_pool_spec and returns an aiplatform.CustomJob.

    Args:
      display_name (str):
        Required. User defined display name for the Vertex AI custom T5X job.
      machine_type (str):
        Required. The type of machine for running the custom training job on
        dedicated resources. For TPUs, use `cloud-tpu`.
      accelerator_type (str):
        Required. The type of accelerator(s) that may be attached to the
        machine as per `accelerator_count`. Only used if `machine_type` is
        set. Options: `TPU_V2` or `TPU_V3`.
      accelerator_count (int):
        Required. The number of accelerators to attach to the `machine_type`.
        Only used if `machine_type` is set. For TPUs, this is the number of
        cores to be provisioned. Example: 8, 128, 512, etc.
      image_uri (str):
        Required. Full image path to be used as the execution environment of
        the custom T5X training job.
        Example: 'gcr.io/{PROJECT_ID}/{IMAGE_NAME}'
      run_mode (str):
        Required. The mode to run T5X under. Options: `train`, `eval`,
        `infer`.
      gin_files (List[str]):
        Required. Full path to gin configuration file on local filesystem.
        Multiple paths may be passed and will be imported in the given order,
        with later configurations overriding earlier ones.
      model_dir (str):
        Required. Path on Google Cloud Storage to store all the artifacts
        generated by the custom T5X training job. The path must be in this
        format: `gs://{bucket name}/{your folder}/...`.
        Example: gs://my_bucket/experiments/model1/
      gin_search_paths (Optional[List[str]]):
        Optional. List of gin config path prefixes to be prepended to gin
        suffixes in gin includes and gin_files.
      tfds_data_dir (Optional[str]):
        Optional. If set, this directory will be used to store datasets
        prepared by TensorFlow Datasets that are not available in the public
        TFDS GCS bucket. Note that this flag overrides the `tfds_data_dir`
        attribute of all `Task`s. This path must be a valid GCS path.
        Example: gs://my_bucket/datasets/my_dataset/
      replica_count (int):
        Optional. The number of worker replicas. If replica count = 1 then
        one chief replica will be provisioned. For TPUs this must be set
        to 1.
      gin_overwrites (Optional[List[str]]):
        Optional. List of arguments to overwrite gin configurations.
        Example: TRAIN_PATH=\"gs://my_bucket/folder\"
      base_output_dir (Optional[str]):
        Optional. GCS output directory of the job.

    Returns:
      (aiplatform.CustomJob):
        Return an instance of a Vertex AI training CustomJob.

    Raises:
      FileNotFoundError: If any of `gin_files` does not exist locally.
      RuntimeError: If the gin files could not be copied to GCS.
    """
    local_fs = fsspec.filesystem('file')
    gcs_fs = gcsfs.GCSFileSystem()

    # Check that every gin file exists on the local filesystem before
    # attempting any upload.
    if not gin_files or not all(local_fs.isfile(f) for f in gin_files):
        raise FileNotFoundError('Provide a list of valid gin files.')

    # Copy the gin files next to the model artifacts on GCS. The container
    # mounts the bucket through the `/gcs/` FUSE prefix, hence the rewrite.
    try:
        gcs_gin_files = []
        for gin_file in gin_files:
            gcs_path = os.path.join(model_dir, os.path.basename(gin_file))
            gcs_fs.put(gin_file, gcs_path)
            gcs_gin_files.append(gcs_path.replace('gs://', '/gcs/'))
    except Exception as e:
        # Chain the original error so the root cause is not lost.
        raise RuntimeError('Could not copy gin files to GCS.') from e

    container_spec = {"image_uri": image_uri}

    # Temporary mitigation to address issues with t5x/main.py
    # and inference on tfrecord files: call infer.py directly.
    if run_mode == 'infer':
        args = [f'--gin.MODEL_DIR="{model_dir}"']
        container_spec['command'] = ["python", "./t5x/t5x/infer.py"]
    else:
        args = [
            f'--run_mode={run_mode}',
            f'--gin.MODEL_DIR="{model_dir}"',
        ]

    # Only forward tfds_data_dir when it was actually provided; otherwise a
    # literal "None" string would be passed as a path to the container.
    if tfds_data_dir:
        args.append(f'--tfds_data_dir={tfds_data_dir}')

    if gin_search_paths:
        args.append(f'--gin_search_paths={",".join(gin_search_paths)}')

    args += [f'--gin_file={gcs_path}' for gcs_path in gcs_gin_files]

    if gin_overwrites:
        args += [f'--gin.{overwrite}' for overwrite in gin_overwrites]

    container_spec["args"] = args

    worker_pool_specs = [
        {
            "machine_spec": {
                "machine_type": machine_type,
                "accelerator_type": accelerator_type,
                "accelerator_count": accelerator_count,
            },
            "replica_count": replica_count,
            "container_spec": container_spec,
        }
    ]

    job = vertex_ai.CustomJob(
        display_name=display_name,
        worker_pool_specs=worker_pool_specs,
        base_output_dir=base_output_dir,
    )

    return job


def _main(argv):
    """Initializes Vertex AI and submits the custom T5X job."""
    vertex_ai.init(
        project=FLAGS.project_id,
        location=FLAGS.region,
        staging_bucket=FLAGS.staging_bucket,
    )

    # Timestamp suffix makes every submitted job name (and job dir) unique.
    job_name = f'{FLAGS.job_name_prefix}_{datetime.now().strftime("%Y%m%d%H%M%S")}'
    job_dir = f'{FLAGS.staging_bucket}/t5x_jobs/{job_name}'

    job = _create_t5x_custom_job(
        display_name=job_name,
        machine_type=MACHINE_TYPE,
        accelerator_type=FLAGS.accelerator_type,
        accelerator_count=FLAGS.accelerator_count,
        image_uri=FLAGS.image_uri,
        run_mode=FLAGS.run_mode,
        gin_files=FLAGS.gin_files,
        gin_overwrites=FLAGS.gin_overwrites,
        gin_search_paths=FLAGS.gin_search_paths,
        tfds_data_dir=FLAGS.tfds_data_dir,
        model_dir=job_dir,
    )

    # Lazy %-formatting keeps the message cheap when logging is disabled.
    logging.info('Starting job: %s', job_name)
    job.run(sync=FLAGS.sync)


if __name__ == "__main__":
    app.run(_main)