def main()

in community-content/vertex_model_garden/model_oss/peft/train/vmg/instruct_lora.py


def main(unused_argv: Sequence[str]) -> None:
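  """Runs instruct LoRA fine-tuning as configured by the command-line flags."""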
  # This needs to be called before any other PartialState() calls.
  utils.init_partial_state(
      timeout=datetime.timedelta(seconds=_NCCL_TIMEOUT.value)
  )

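  # Caps this rank's share of GPU memory: the CUDA caching allocator raises
  # an out-of-memory error beyond the fraction instead of starving
  # co-located processes.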
  torch.cuda.set_per_process_memory_fraction(
      _MAX_GPU_MEMORY_FRACTION.value, device=PartialState().local_process_index
  )

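  # Logs dependency versions for debugging and applies the configured
  # warnings filter (e.g. "ignore") process-wide.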
  utils.print_library_versions()
  warnings.simplefilter(_WARNINGS_FILTER.value)

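  # Normalizes the base model path and, if it points at GCS, downloads the
  # weights to local disk. (A rough stand-in for these path helpers is
  # sketched after this listing.)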
  pretrained_model_name_or_path = fileutils.force_gcs_path(
      _PRETRAINED_MODEL_NAME_OR_PATH.value
  )
  if dataset_validation_util.is_gcs_path(pretrained_model_name_or_path):
    pretrained_model_name_or_path = (
        dataset_validation_util.download_gcs_uri_to_local(
            pretrained_model_name_or_path
        )
    )

  # GCS Fuse does not sync flushed files until they are closed. See b/361771727.
  logging_output_dir = fileutils.force_gcs_path(_LOGGING_OUTPUT_DIR.value)

  # Creates the evaluation config only if an eval dataset was provided.
  if _EVAL_DATASET.value:
    eval_config = eval_lib.EvalConfig(
        per_device_batch_size=_PER_DEVICE_EVAL_BATCH_SIZE.value,
        limit=_EVAL_LIMIT.value,
        metric_name=_EVAL_METRIC_NAME.value,
        steps=_EVAL_STEPS.value,
        dataset_path=dataset_validation_util.force_gcs_fuse_path(
            _EVAL_DATASET.value
        ),
        split=_EVAL_SPLIT.value,
        template=_EVAL_TEMPLATE.value,
        column=_EVAL_COLUMN.value,
        tokenize_dataset=False,
        metric_for_best_model=_METRIC_FOR_BEST_MODEL.value,
    )
  else:
    eval_config = None

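  # In non-interactive runs, wandb.login() reads credentials from the
  # WANDB_API_KEY environment variable.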
  if _REPORT_TO.value == constants.REPORT_TO_WANDB:
    wandb.login()

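  # Runs the fine-tuning job itself; every hyperparameter below is
  # flag-driven.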
  finetune_instruct(
      pretrained_model_name_or_path=pretrained_model_name_or_path,
      train_dataset=_TRAIN_DATASET.value,
      output_dir=_OUTPUT_DIR.value,
      logging_output_dir=logging_output_dir,
      precision_mode=_PRECISION_MODE.value,
      lora_rank=_LORA_RANK.value,
      lora_alpha=_LORA_ALPHA.value,
      lora_dropout=_LORA_DROPOUT.value,
      warmup_ratio=_WARMUP_RATIO.value,
      num_train_epochs=_NUM_TRAIN_EPOCHS.value,
      warmup_steps=_WARMUP_STEPS.value,
      max_steps=_MAX_STEPS.value,
      max_seq_length=_MAX_SEQ_LENGTH.value,
      learning_rate=_LEARNING_RATE.value,
      train_column=_TRAIN_COLUMN.value,
      per_device_train_batch_size=_PER_DEVICE_TRAIN_BATCH_SIZE.value,
      optim=_OPTIMIZER.value,
      weight_decay=_WEIGHT_DECAY.value,
      gradient_accumulation_steps=_GRADIENT_ACCUMULATION_STEPS.value,
      gradient_checkpointing=_GRADIENT_CHECKPOINTING.value,
      enable_peft=_ENABLE_PEFT.value,
      train_template=_TRAIN_TEMPLATE.value,
      lr_scheduler_type=_LR_SCHEDULER_TYPE.value,
      save_steps=_SAVE_STEPS.value,
      logging_steps=_LOGGING_STEPS.value,
      train_split=_TRAIN_SPLIT.value,
      eval_config=eval_config,
      report_to=_REPORT_TO.value,
      access_token=_HUGGINGFACE_ACCESS_TOKEN.value,
      train_precision=_TRAIN_PRECISION.value,
      example_packing=_EXAMPLE_PACKING.value,
      attn_implementation=_ATTN_IMPLEMENTATION.value,
      max_grad_norm=_MAX_GRAD_NORM.value,
      input_masking=_INPUT_MASKING.value,
      logger_level=_LOGGER_LEVEL.value,
      benchmark_out_file=_BENCHMARK_OUT_FILE.value,
      tuning_data_stats_file=_TUNING_DATA_STATS_FILE.value,
      target_modules=_TARGET_MODULES.value,
  )
  # Frees GPU memory held by the model.
  utils.force_gc()
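
For reference, a minimal sketch of what utils.init_partial_state might look like, under the assumption that it wraps Hugging Face Accelerate's PartialState singleton (the real helper lives in utils). Kwargs are only honored on the first construction, where they are forwarded to torch.distributed.init_process_group; that is why it must run before any other PartialState() call:

import datetime

from accelerate import PartialState


def init_partial_state(timeout: datetime.timedelta) -> PartialState:
  # PartialState is a process-wide singleton: kwargs passed to the first
  # construction are forwarded to init_process_group (e.g. the NCCL timeout).
  return PartialState(timeout=timeout)
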
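The fileutils and dataset_validation_util path helpers are project-internal. A hypothetical stand-in, assuming force_gcs_path rewrites a GCS Fuse mount as a gs:// URI (consistent with the is_gcs_path check and download that follow it above), might look like:

_GCS_FUSE_PREFIX = '/gcs/'  # Hypothetical mount prefix, for illustration only.


def force_gcs_path(path: str) -> str:
  # Hypothetical: maps a GCS Fuse mount (/gcs/bucket/obj) to gs://bucket/obj.
  if path.startswith(_GCS_FUSE_PREFIX):
    return 'gs://' + path[len(_GCS_FUSE_PREFIX):]
  return path


def is_gcs_path(path: str) -> bool:
  # Hypothetical: treats any gs:// URI as a GCS path.
  return path.startswith('gs://')
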
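eval_lib.EvalConfig is defined elsewhere in the repository; judging only from the call site above, its shape is roughly the following dataclass (field types are guesses):

import dataclasses
from typing import Optional


@dataclasses.dataclass
class EvalConfig:
  # Inferred from the arguments passed in main(); see eval_lib for the
  # authoritative definition.
  per_device_batch_size: int
  limit: Optional[int]
  metric_name: str
  steps: int
  dataset_path: str
  split: str
  template: Optional[str]
  column: str
  tokenize_dataset: bool
  metric_for_best_model: str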