def main()

in community-content/vertex_model_garden/model_oss/peft/train/vmg/train_entrypoint.py [0:0]


def main(unused_argv: Sequence[str]) -> None:
  """Builds and runs the command(s) for the requested training task.

  The flags `--config_file`, `--task` and `--gcs_rsync_interval_secs` are
  parsed here. Every other command-line argument is forwarded to the launched
  command(s). If the selected task reports directories to sync, a GCS rsync
  process is started before the commands run. It is cleaned up after they all
  succeed, or terminated if anything fails.

  Args:
    unused_argv: Unused; flags are parsed by `argparse` from `sys.argv`.

  Raises:
    ValueError: If `--task` is not a supported task.
    subprocess.CalledProcessError: If any launched command exits non-zero.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument('--config_file')
  parser.add_argument('--task')
  parser.add_argument('--gcs_rsync_interval_secs', type=int, default=60)
  # Flags not declared above are collected in `unknown` and forwarded to the
  # task command(s).
  args, unknown = parser.parse_known_args()

  task = args.task
  dirs_to_sync = None

  if task in _TEXT_TO_IMAGE_TASKS_SCRIPTS:
    # Setup accelerate config before running trainer.
    config_gen_cmd = [
        'python',
        '-c',
        (
            'from accelerate.utils import write_basic_config;'
            ' write_basic_config(mixed_precision="fp16")'
        ),
    ]
    task_cmd = [
        'accelerate',
        'launch',
        _TEXT_TO_IMAGE_TASKS_SCRIPTS[task],
    ] + [dataset_validation_util.force_gcs_fuse_path(arg) for arg in unknown]
    commands = [config_gen_cmd, task_cmd]
  elif task == constants.INSTRUCT_LORA:
    commands, dirs_to_sync = _get_train_and_maybe_merge_cmd_and_dirs_to_sync(
        task_type=task, config_file=args.config_file, unknown=unknown
    )
  elif task == constants.MERGE_CAUSAL_LANGUAGE_MODEL_LORA:
    commands, dirs_to_sync = _get_merge_cmd_and_dirs_to_sync(
        task_type=task, config_file=args.config_file, unknown=unknown
    )
  elif task in _TASK_TO_SCRIPT:
    cmd = launch_script_cmd(_TASK_TO_SCRIPT[task], args.config_file)
    cmd.extend(unknown)
    commands = [cmd]
  else:
    # Raise explicitly rather than `assert`: asserts are stripped under
    # `python -O`, which would turn an unknown task into an opaque KeyError.
    raise ValueError(f'Unsupported --task: {task!r}')

  rsync_process = None
  mp_queue = None
  if dirs_to_sync:
    # The queue is only handed to the gcs_syncer helpers, so allocate it only
    # when syncing is actually requested.
    mp_queue = multiprocessing.Queue(maxsize=1)
    rsync_process = gcs_syncer.setup_gcs_rsync(
        dirs_to_sync, mp_queue, args.gcs_rsync_interval_secs
    )

  try:
    for cmd in commands:
      logging.info('launching task=%s with cmd: \n%s', task, ' \\\n'.join(cmd))
      # Both absl logging and python's logging module writes to stderr by
      # default. Redirect output to stdout on purpose, such that log entries
      # do not get marked as `Error` in Cloud's Log Explorer.
      subprocess.run(cmd, stdout=sys.stdout, stderr=sys.stdout, check=True)
  except BaseException:
    # Stop the background syncer on *any* failure: a non-zero exit
    # (CalledProcessError), a missing executable (FileNotFoundError), or
    # KeyboardInterrupt. Otherwise it is left running after a failed job.
    # A bare `raise` re-raises the original exception with its traceback.
    if rsync_process is not None and rsync_process.is_alive():
      logging.info('Terminating GCS rsync process.')
      rsync_process.terminate()
    raise

  if rsync_process is not None:
    gcs_syncer.cleanup_gcs_rsync(rsync_process, mp_queue)