in sagemaker/22_accelerate_sagemaker_examples/src/seq2seq/run_seq2seq_no_trainer.py [0:0]
def _str_to_bool(value):
    """Convert a command-line token such as ``"True"`` or ``"false"`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``
    because any non-empty string is truthy. This converter interprets the text
    instead. The empty string maps to ``False`` (as ``bool("")`` did).

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError("Expected a boolean value, got %r." % (value,))


def parse_args():
    """Build the command-line parser, parse ``sys.argv`` and validate the result.

    Boolean options accept an explicit value (``--push_to_hub True`` /
    ``--push_to_hub False``) or may be given bare (``--push_to_hub``), which
    means ``True``.

    Returns:
        argparse.Namespace: the parsed arguments.

    Raises:
        ValueError: if no dataset name and no train/validation file is given,
            if a train/validation file is not a ``csv`` or ``json`` file, or if
            ``--push_to_hub`` is set without ``--output_dir``.
        SystemExit: raised by ``argparse`` itself on malformed arguments.
    """
    parser = argparse.ArgumentParser(description="Finetune a transformers model on a summarization task")

    # Shared settings for every boolean option: parse the text properly and let
    # a bare flag (no value) mean True.
    bool_flag = {"type": _str_to_bool, "nargs": "?", "const": True}

    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help="The name of the dataset to use (via the datasets library).",
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The configuration name of the dataset to use (via the datasets library).",
    )
    parser.add_argument(
        "--train_file", type=str, default=None, help="A csv or a json file containing the training data."
    )
    parser.add_argument(
        "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
    )
    parser.add_argument(
        "--ignore_pad_token_for_loss",
        default=True,
        help="Whether to ignore the tokens corresponding to padded labels in the loss computation or not.",
        **bool_flag,
    )
    parser.add_argument(
        "--max_source_length",
        type=int,
        default=1024,
        help=(
            "The maximum total input sequence length after "
            "tokenization. Sequences longer than this will be truncated, sequences shorter will be padded."
        ),
    )
    parser.add_argument(
        "--source_prefix",
        type=str,
        default=None,
        help="A prefix to add before every source text (useful for T5 models).",
    )
    parser.add_argument(
        "--preprocessing_num_workers",
        type=int,
        default=None,
        help="The number of processes to use for the preprocessing.",
    )
    parser.add_argument(
        "--overwrite_cache", default=None, help="Overwrite the cached training and evaluation sets", **bool_flag
    )
    parser.add_argument(
        "--max_target_length",
        type=int,
        default=128,
        help=(
            "The maximum total sequence length for target text after "
            "tokenization. Sequences longer than this will be truncated, sequences shorter will be padded."
        ),
    )
    parser.add_argument(
        "--val_max_target_length",
        type=int,
        default=None,
        help=(
            "The maximum total sequence length for validation "
            "target text after tokenization. Sequences longer than this will be truncated, sequences shorter will be "
            "padded. Will default to `max_target_length`. This argument is also used to override the ``max_length`` "
            "param of ``model.generate``, which is used during ``evaluate`` and ``predict``."
        ),
    )
    # New Code #
    # NOTE(review): the previous help text was a copy of --val_max_target_length
    # (it even claimed a default of `max_target_length`, contradicting default=10).
    # Confirm against the caller how this value is consumed.
    parser.add_argument(
        "--val_min_target_length",
        type=int,
        default=10,
        help="The minimum total sequence length for validation target text after tokenization.",
    )
    parser.add_argument(
        "--n_train",
        type=int,
        default=2000,
        help=("Number of training examples to use. If None, all training examples will be used."),
    )
    parser.add_argument(
        "--n_val",
        type=int,
        default=500,
        help=("Number of validation examples to use. If None, all validation examples will be used."),
    )
    # NOTE(review): this help text duplicates --n_val; the option's actual
    # meaning is not visible here, so the text is left unchanged. TODO confirm.
    parser.add_argument(
        "--n_val_batch_generations",
        type=int,
        default=5,
        help=("Number of validation examples to use. If None, all validation examples will be used."),
    )
    parser.add_argument(
        "--max_length",
        type=int,
        default=128,
        help=(
            "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
            " sequences shorter will be padded if `--pad_to_max_length` is passed."
        ),
    )
    parser.add_argument(
        "--num_beams",
        type=int,
        default=None,
        help=(
            "Number of beams to use for evaluation. This argument will be "
            "passed to ``model.generate``, which is used during ``evaluate`` and ``predict``."
        ),
    )
    parser.add_argument(
        "--pad_to_max_length",
        default=False,
        help="If passed, pad all samples to `max_length`. Otherwise, dynamic padding is used.",
        **bool_flag,
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
        required=False,
    )
    parser.add_argument(
        "--config_name",
        type=str,
        default=None,
        help="Pretrained config name or path if not the same as model_name",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--text_column",
        type=str,
        default=None,
        help="The name of the column in the datasets containing the full texts (for summarization).",
    )
    parser.add_argument(
        "--summary_column",
        type=str,
        default=None,
        help="The name of the column in the datasets containing the summaries (for summarization).",
    )
    parser.add_argument(
        "--use_slow_tokenizer",
        default=False,
        help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
        **bool_flag,
    )
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--per_device_eval_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the evaluation dataloader.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-5,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
    parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--lr_scheduler_type",
        type=SchedulerType,
        default="linear",
        help="The scheduler type to use.",
        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
    )
    parser.add_argument(
        "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--model_type",
        type=str,
        default=None,
        help="Model type to use if training from scratch.",
        choices=MODEL_TYPES,
    )
    parser.add_argument(
        "--push_to_hub", default=False, help="Whether or not to push the model to the Hub.", **bool_flag
    )
    parser.add_argument(
        "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
    )
    parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--checkpointing_steps",
        type=str,
        default=None,
        help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help="If the training should continue from a checkpoint folder.",
    )
    # New Code #
    # Whether to load the best model at the end of training
    parser.add_argument(
        "--load_best_model",
        default=False,
        help="Whether to load the best model at the end of training",
        **bool_flag,
    )
    # New Code #
    parser.add_argument(
        "--logging_steps",
        type=int,
        default=None,
        help="log every n steps",
    )
    parser.add_argument(
        "--with_tracking",
        default=False,
        help="Whether to enable experiment trackers for logging.",
        **bool_flag,
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="all",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
            ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
            " Only applicable when `--with_tracking` is passed."
        ),
    )
    parser.add_argument(
        "--report_name",
        type=str,
        default="chatbot_no_trainer",
        help=("The name of the experiment tracking folder. Only applicable when `--with_tracking` is passed."),
    )
    args = parser.parse_args()

    # Sanity checks. Explicit exceptions are used instead of `assert` so the
    # validation still runs when Python is started with -O.
    if args.dataset_name is None and args.train_file is None and args.validation_file is None:
        raise ValueError("Need either a dataset name or a training/validation file.")

    for file_arg in ("train_file", "validation_file"):
        data_path = getattr(args, file_arg)
        if data_path is None:
            continue
        # The text after the last "." is treated as the extension (case-sensitive).
        if data_path.split(".")[-1] not in ("csv", "json"):
            raise ValueError("`{}` should be a csv or a json file.".format(file_arg))

    if args.push_to_hub and args.output_dir is None:
        raise ValueError("Need an `output_dir` to create a repo when `--push_to_hub` is passed.")

    return args