in community-content/vertex_model_garden/model_oss/peft/train/vmg/instruct_lora.py [0:0]
def main(unused_argv: Sequence[str]) -> None:
  """Initializes process state, builds configs from flags, and runs finetuning.

  Args:
    unused_argv: Remaining command-line arguments from absl; not used.
  """
  # PartialState must be initialized before any other PartialState() call.
  utils.init_partial_state(
      timeout=datetime.timedelta(seconds=_NCCL_TIMEOUT.value)
  )
  # Cap this process's GPU memory usage on its local device.
  torch.cuda.set_per_process_memory_fraction(
      _MAX_GPU_MEMORY_FRACTION.value, device=PartialState().local_process_index
  )
  utils.print_library_versions()
  warnings.simplefilter(_WARNINGS_FILTER.value)

  # Resolve the base model location; GCS-hosted models are mirrored to a
  # local path before loading.
  model_path = fileutils.force_gcs_path(_PRETRAINED_MODEL_NAME_OR_PATH.value)
  if dataset_validation_util.is_gcs_path(model_path):
    model_path = dataset_validation_util.download_gcs_uri_to_local(model_path)

  # GCS Fuse does not sync flushed files if not closed. See b/361771727.
  log_dir = fileutils.force_gcs_path(_LOGGING_OUTPUT_DIR.value)

  # Evaluation is opt-in: only build a config when an eval dataset was given.
  eval_config = None
  if _EVAL_DATASET.value:
    eval_config = eval_lib.EvalConfig(
        per_device_batch_size=_PER_DEVICE_EVAL_BATCH_SIZE.value,
        limit=_EVAL_LIMIT.value,
        metric_name=_EVAL_METRIC_NAME.value,
        steps=_EVAL_STEPS.value,
        dataset_path=dataset_validation_util.force_gcs_fuse_path(
            _EVAL_DATASET.value
        ),
        split=_EVAL_SPLIT.value,
        template=_EVAL_TEMPLATE.value,
        column=_EVAL_COLUMN.value,
        tokenize_dataset=False,
        metric_for_best_model=_METRIC_FOR_BEST_MODEL.value,
    )

  # Weights & Biases needs an interactive/env-token login before training
  # starts reporting to it.
  if _REPORT_TO.value == constants.REPORT_TO_WANDB:
    wandb.login()

  finetune_instruct(
      pretrained_model_name_or_path=model_path,
      train_dataset=_TRAIN_DATASET.value,
      output_dir=_OUTPUT_DIR.value,
      logging_output_dir=log_dir,
      precision_mode=_PRECISION_MODE.value,
      lora_rank=_LORA_RANK.value,
      lora_alpha=_LORA_ALPHA.value,
      lora_dropout=_LORA_DROPOUT.value,
      warmup_ratio=_WARMUP_RATIO.value,
      num_train_epochs=_NUM_TRAIN_EPOCHS.value,
      warmup_steps=_WARMUP_STEPS.value,
      max_steps=_MAX_STEPS.value,
      max_seq_length=_MAX_SEQ_LENGTH.value,
      learning_rate=_LEARNING_RATE.value,
      train_column=_TRAIN_COLUMN.value,
      per_device_train_batch_size=_PER_DEVICE_TRAIN_BATCH_SIZE.value,
      optim=_OPTIMIZER.value,
      weight_decay=_WEIGHT_DECAY.value,
      gradient_accumulation_steps=_GRADIENT_ACCUMULATION_STEPS.value,
      gradient_checkpointing=_GRADIENT_CHECKPOINTING.value,
      enable_peft=_ENABLE_PEFT.value,
      train_template=_TRAIN_TEMPLATE.value,
      lr_scheduler_type=_LR_SCHEDULER_TYPE.value,
      save_steps=_SAVE_STEPS.value,
      logging_steps=_LOGGING_STEPS.value,
      train_split=_TRAIN_SPLIT.value,
      eval_config=eval_config,
      report_to=_REPORT_TO.value,
      access_token=_HUGGINGFACE_ACCESS_TOKEN.value,
      train_precision=_TRAIN_PRECISION.value,
      example_packing=_EXAMPLE_PACKING.value,
      attn_implementation=_ATTN_IMPLEMENTATION.value,
      max_grad_norm=_MAX_GRAD_NORM.value,
      input_masking=_INPUT_MASKING.value,
      logger_level=_LOGGER_LEVEL.value,
      benchmark_out_file=_BENCHMARK_OUT_FILE.value,
      tuning_data_stats_file=_TUNING_DATA_STATS_FILE.value,
      target_modules=_TARGET_MODULES.value,
  )
  # Frees the model from GPU.
  utils.force_gc()