def add_model_specific_args()

in rag-end2end-retriever/finetune_rag.py
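Builds the argument parser for RAG end-to-end fine-tuning: it layers the flags below on top of the base-transformer and generic argument groups (BaseTransformer.add_model_specific_args and add_generic_args, imported elsewhere in finetune_rag.py) and returns the parser.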


    def add_model_specific_args(parser, root_dir):
        BaseTransformer.add_model_specific_args(parser, root_dir)
        add_generic_args(parser, root_dir)
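        # Sequence-length limits for model inputs and generation targets.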
        parser.add_argument(
            "--max_source_length",
            default=128,
            type=int,
            help=(
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            ),
        )
        parser.add_argument(
            "--max_target_length",
            default=25,
            type=int,
            help=(
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            ),
        )
        parser.add_argument(
            "--val_max_target_length",
            default=25,
            type=int,
            help=(
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            ),
        )
        parser.add_argument(
            "--test_max_target_length",
            default=25,
            type=int,
            help=(
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            ),
        )
        parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
        parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_val", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
        parser.add_argument(
            "--prefix",
            type=str,
            default=None,
            help="Prefix added at the beginning of each text, typically used with T5-based models.",
        )
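        # Training control: early stopping and distributed setup.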
        parser.add_argument(
            "--early_stopping_patience",
            type=int,
            default=-1,
            required=False,
            help=(
                "-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So"
                " val_check_interval will effect it."
            ),
        )
        parser.add_argument(
            "--distributed-port", type=int, default=-1, required=False, help="Port number for distributed training."
        )
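        # Model selection, retriever components, and the knowledge base.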
        parser.add_argument(
            "--model_type",
            choices=["rag_sequence", "rag_token", "bart", "t5"],
            type=str,
            help=(
                "RAG model type: sequence or token, if none specified, the type is inferred from the"
                " model_name_or_path"
            ),
        )
        parser.add_argument(
            "--context_encoder_name",
            default="facebook/dpr-ctx_encoder-multiset-base",
            type=str,
            help="Name of the pre-trained context encoder checkpoint from the DPR",
        )
        parser.add_argument(
            "--csv_path",
            default=str(Path(__file__).parent / "test_run" / "dummy-kb" / "my_knowledge_dataset.csv"),
            type=str,
            help="path of the raw KB csv",
        )
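        # End-to-end retriever training: re-encoding the knowledge base during fine-tuning.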
        parser.add_argument("--end2end", action="store_true", help="whether to train the system end2end or not")
        parser.add_argument("--index_gpus", type=int, help="how many GPUs used in re-encoding process")
        parser.add_argument(
            "--shard_dir",
            type=str,
            default=str(Path(__file__).parent / "test_run" / "kb-shards"),
            help="directory used to keep temporary shards during the re-encode process",
        )

        parser.add_argument(
            "--gpu_order",
            type=str,
            help=(
                "order of the GPU used during the fine-tuning.  Used to finding free GPUs during the re-encode"
                " process. I do not have many GPUs :)"
            ),
        )

        parser.add_argument("--indexing_freq", type=int, help="frequency of re-encode process")
        return parser
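
A minimal usage sketch follows. It assumes this method is exposed on the training module class (GenerativeQAModule in the RAG example scripts) and that the base/generic argument groups require flags such as --model_name_or_path, --data_dir, and --output_dir; treat those names as assumptions about the surrounding checkout rather than a fixed contract.

    import argparse
    import os

    # Sketch only: assumes finetune_rag.py exposes this method on GenerativeQAModule.
    from finetune_rag import GenerativeQAModule

    parser = argparse.ArgumentParser()
    parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())

    # --model_name_or_path, --data_dir, and --output_dir come from the base/generic
    # argument groups and are assumed required here; paths are placeholders.
    args = parser.parse_args(
        [
            "--model_name_or_path", "facebook/rag-token-base",
            "--data_dir", "path/to/data",
            "--output_dir", "outputs",
            "--model_type", "rag_token",
            "--end2end",
            "--indexing_freq", "500",
        ]
    )
    print(args.max_source_length)  # 128 (default)
    print(args.end2end)            # True

Note that --distributed-port is the only hyphenated flag in this group; argparse exposes it as args.distributed_port.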