in community-content/vertex_model_garden/model_oss/peft/train/vmg/train_entrypoint.py [0:0]
def main(unused_argv: Sequence[str]) -> None:
  """Builds the command(s) for the requested task and runs them in order.

  The task is selected with `--task`. Command-line arguments that this
  parser does not recognize are forwarded to the launched script. When the
  selected task reports directories to sync, a GCS rsync process is started
  before the commands run and cleaned up after they all succeed.

  Args:
    unused_argv: Positional arguments handed over by the caller; ignored.
      Arguments are read by `argparse` (from `sys.argv`) instead.

  Raises:
    ValueError: If `--task` does not name a supported task.
    subprocess.CalledProcessError: If a launched command exits with a
      non-zero status.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument('--config_file')
  parser.add_argument('--task')
  parser.add_argument('--gcs_rsync_interval_secs', type=int, default=60)
  # `unknown` holds every argument the parser above does not declare; those
  # are passed through to the task script.
  args, unknown = parser.parse_known_args()
  task = args.task
  dirs_to_sync = None

  if task in _TEXT_TO_IMAGE_TASKS_SCRIPTS:
    # Setup accelerate config before running trainer.
    config_gen_cmd = [
        'python',
        '-c',
        (
            'from accelerate.utils import write_basic_config;'
            ' write_basic_config(mixed_precision="fp16")'
        ),
    ]
    task_cmd = [
        'accelerate',
        'launch',
        _TEXT_TO_IMAGE_TASKS_SCRIPTS[task],
    ] + [dataset_validation_util.force_gcs_fuse_path(arg) for arg in unknown]
    commands = [config_gen_cmd, task_cmd]
  elif task == constants.INSTRUCT_LORA:
    commands, dirs_to_sync = _get_train_and_maybe_merge_cmd_and_dirs_to_sync(
        task_type=task, config_file=args.config_file, unknown=unknown
    )
  elif task == constants.MERGE_CAUSAL_LANGUAGE_MODEL_LORA:
    commands, dirs_to_sync = _get_merge_cmd_and_dirs_to_sync(
        task_type=task, config_file=args.config_file, unknown=unknown
    )
  else:
    # Explicit check instead of `assert`: assertions are removed when Python
    # runs with -O, which would turn a bad --task into an opaque KeyError.
    if task not in _TASK_TO_SCRIPT:
      raise ValueError(f'Unsupported task: {task!r}')
    cmd = launch_script_cmd(_TASK_TO_SCRIPT[task], args.config_file)
    cmd.extend(unknown)
    commands = [cmd]

  rsync_process = None
  mp_queue = multiprocessing.Queue(maxsize=1)
  if dirs_to_sync:
    rsync_process = gcs_syncer.setup_gcs_rsync(
        dirs_to_sync, mp_queue, args.gcs_rsync_interval_secs
    )

  try:
    for cmd in commands:
      logging.info(
          'launching task=%s with cmd: \n%s', task, ' \\\n'.join(cmd)
      )
      # Both absl logging and python's logging module writes to stderr by
      # default. Redirect output to stdout on purpose, such that log entries
      # do not get marked as `Error` in Cloud's Log Explorer.
      subprocess.run(cmd, stdout=sys.stdout, stderr=sys.stdout, check=True)
  except BaseException:
    # Stop the background rsync on *any* failure, not only on a non-zero exit
    # status: a missing executable (OSError) or an interrupt would otherwise
    # leave the process running. The exception is always re-raised.
    if rsync_process is not None and rsync_process.is_alive():
      logging.info('Terminating GCS rsync process.')
      rsync_process.terminate()
    raise

  # Reached only when every command succeeded.
  if rsync_process is not None:
    gcs_syncer.cleanup_gcs_rsync(rsync_process, mp_queue)