in rag-end2end-retriever/finetune_rag.py [0:0]
def add_model_specific_args(parser, root_dir):
    """Register the RAG fine-tuning command-line options on ``parser``.

    First delegates to ``BaseTransformer.add_model_specific_args`` and
    ``add_generic_args`` (both defined elsewhere in the project), then adds the
    sequence-length, dataset-size, logging, early-stopping, model-type and
    end-to-end retriever re-encoding options used by this script.

    Args:
        parser: An ``argparse.ArgumentParser`` (or compatible object exposing
            ``add_argument``); it is mutated in place.
        root_dir: Forwarded unchanged to the two delegated registration helpers.

    Returns:
        The same ``parser`` object, with the new arguments registered.
    """
    BaseTransformer.add_model_specific_args(parser, root_dir)
    add_generic_args(parser, root_dir)

    # --- Sequence lengths -------------------------------------------------
    parser.add_argument(
        "--max_source_length",
        default=128,
        type=int,
        help=(
            "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        ),
    )
    # The three target-length options previously reused the *input* help text
    # verbatim; each now describes the target sequence for its own split.
    parser.add_argument(
        "--max_target_length",
        default=25,
        type=int,
        help=(
            "The maximum total target sequence length after tokenization (training). Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        ),
    )
    parser.add_argument(
        "--val_max_target_length",
        default=25,
        type=int,
        help=(
            "The maximum total target sequence length after tokenization (validation). Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        ),
    )
    parser.add_argument(
        "--test_max_target_length",
        default=25,
        type=int,
        help=(
            "The maximum total target sequence length after tokenization (test). Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        ),
    )

    # --- Logging and dataset sizes ----------------------------------------
    parser.add_argument(
        "--logger_name",
        type=str,
        choices=["default", "wandb", "wandb_shared"],
        default="default",
        help="Which logger to use.",
    )
    parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
    parser.add_argument("--n_val", type=int, default=-1, required=False, help="# examples. -1 means use all.")
    parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
    parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
    parser.add_argument(
        "--prefix",
        type=str,
        default=None,
        help="Prefix added at the beginning of each text, typically used with T5-based models.",
    )
    parser.add_argument(
        "--early_stopping_patience",
        type=int,
        default=-1,
        required=False,
        help=(
            "-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So"
            " val_check_interval will affect it."
        ),
    )
    # Note the dash in the flag name: argparse derives the attribute name by
    # replacing '-' with '_', so this is read back as `distributed_port`.
    parser.add_argument(
        "--distributed-port", type=int, default=-1, required=False, help="Port number for distributed training."
    )

    # --- Model / retriever selection --------------------------------------
    parser.add_argument(
        "--model_type",
        choices=["rag_sequence", "rag_token", "bart", "t5"],
        type=str,
        help=(
            "Model type: rag_sequence, rag_token, bart or t5. If none specified, the type is inferred from the"
            " model_name_or_path"
        ),
    )
    parser.add_argument(
        "--context_encoder_name",
        default="facebook/dpr-ctx_encoder-multiset-base",
        type=str,
        help="Name of the pre-trained context encoder checkpoint from the DPR",
    )
    parser.add_argument(
        "--csv_path",
        default=str(Path(__file__).parent / "test_run" / "dummy-kb" / "my_knowledge_dataset.csv"),
        type=str,
        help="path of the raw KB csv",
    )

    # --- End-to-end retriever training / re-encoding ----------------------
    # NOTE(review): --index_gpus, --gpu_order and --indexing_freq have no
    # default (they parse to None); presumably they must be supplied together
    # with --end2end — confirm against the training code.
    parser.add_argument("--end2end", action="store_true", help="whether to train the system end2end or not")
    parser.add_argument("--index_gpus", type=int, help="how many GPUs used in re-encoding process")
    parser.add_argument(
        "--shard_dir",
        type=str,
        default=str(Path(__file__).parent / "test_run" / "kb-shards"),
        help="directory used to keep temporary shards during the re-encode process",
    )
    parser.add_argument(
        "--gpu_order",
        type=str,
        help=(
            "order of the GPU used during the fine-tuning. Used to find free GPUs during the re-encode"
            " process. I do not have many GPUs :)"
        ),
    )
    parser.add_argument("--indexing_freq", type=int, help="frequency of re-encode process")
    return parser