# sagemaker/28_train_llms_with_qlora/scripts/run_clm.py
import argparse

import torch


def strtobool(value: str) -> bool:
    """Convert common truthy strings ("1", "true", "yes", ...) to a bool.

    argparse's `type=bool` treats every non-empty string (including "False")
    as True, so boolean flags need an explicit converter.
    """
    return str(value).lower() in ("1", "true", "t", "yes", "y")


def parse_arge():
    """Parse the arguments."""
    parser = argparse.ArgumentParser()
    # add model id and dataset path argument
    parser.add_argument(
        "--model_id",
        type=str,
        default="google/flan-t5-xl",
        help="Model id to use for training.",
    )
    parser.add_argument("--dataset_path", type=str, default="lm_dataset", help="Path to dataset.")
    # add training hyperparameters for epochs, batch size, learning rate, and seed
    parser.add_argument("--epochs", type=int, default=3, help="Number of epochs to train for.")
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=1,
        help="Batch size to use for training.",
    )
    parser.add_argument("--lr", type=float, default=5e-5, help="Learning rate to use for training.")
    parser.add_argument("--seed", type=int, default=42, help="Seed to use for training.")
    parser.add_argument(
        "--gradient_checkpointing",
        type=strtobool,
        default=True,
        help="Whether to use gradient checkpointing.",
    )
    parser.add_argument(
        "--bf16",
        type=strtobool,
        # default to bf16 on Ampere (compute capability 8.x) or newer GPUs;
        # guard with is_available() so CPU-only environments don't crash
        default=torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8,
        help="Whether to use bf16.",
    )
    parser.add_argument(
        "--merge_weights",
        type=strtobool,
        default=True,
        help="Whether to merge LoRA weights with base model.",
    )
    # parse_known_args returns (namespace, unknown_args); keep only the namespace
    args, _ = parser.parse_known_args()
    return args
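

# Minimal usage sketch (illustrative only, not part of the original script):
# SageMaker forwards hyperparameters as command-line strings, so booleans
# arrive as e.g. "True"/"false" and go through strtobool above. Example:
#   python run_clm.py --model_id google/flan-t5-xl --epochs 1 --merge_weights false
if __name__ == "__main__":
    args = parse_arge()
    print(args)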