def get_parser()

in arctic_inference/suffix_decoding/simulator.py [0:0]


def get_parser():
    """Build the command-line argument parser for the suffix decoding simulator.

    Apart from the dataset/IO options, most options accept one or more
    values (``nargs="+"``) and are stored as lists even when a single value
    is given. Presumably the caller runs the simulation over the given
    values or their combinations; that logic is not visible here, so verify
    against the caller.

    Returns:
        argparse.ArgumentParser: The configured (not yet parsed) parser.
    """
    parser = argparse.ArgumentParser()

    # ---- Dataset input ----
    parser.add_argument(
        "dataset",
        type=str,
        help="Path to the dataset file",
    )
    # Defaults to None; per the help text, the file extension is used instead.
    parser.add_argument(
        "--format",
        type=str,
        choices=["json", "jsonl", "csv"],
        help="Format of the dataset file, uses its extension if not provided",
    )
    parser.add_argument(
        "--train-dataset",
        type=str,
        help="Path to a separate dataset file for training",
    )
    parser.add_argument(
        "--prompt-column",
        type=str,
        default="prompt",
        help="Column name for the prompts in the dataset",
    )
    parser.add_argument(
        "--response-column",
        type=str,
        default="response",
        help="Column name for the responses in the dataset",
    )

    # ---- Train/eval split ----
    # The "required if --train-dataset is not provided" rule stated in the
    # help text is NOT enforced by this parser; presumably the caller checks it.
    parser.add_argument(
        "--num-train",
        type=int,
        nargs="+",
        help=("Number of examples to use for training (required if "
              "separate --train-dataset is not provided)"),
    )
    parser.add_argument(
        "--num-eval",
        type=int,
        nargs="+",
        help=("Number of examples to use for evaluation (required if "
              "separate --train-dataset is not provided)"),
    )
    parser.add_argument(
        "--seed",
        type=int,
        nargs="+",
        default=[0],
        help="Random seed (for train/eval split)",
    )

    # ---- Tooling / execution ----
    parser.add_argument(
        "--tokenizer",
        type=str,
        help="Name of the HuggingFace tokenizer",
    )
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        help="The path to the output CSV file",
    )
    parser.add_argument(
        "-p",
        "--parallel",
        type=int,
        help="Max number of parallel processes",
    )

    # ---- Suffix tree / speculation parameters (list-valued) ----
    parser.add_argument(
        "--max-depth",
        type=int,
        nargs="+",
        default=[64],
        help="Max depth of the suffix tree",
    )
    # 0 is a sentinel: per the help text it means "use max_depth".
    parser.add_argument(
        "--max-spec-tokens",
        type=int,
        nargs="+",
        default=[0],
        help="Max speculation tokens (if 0, defaults to max_depth)",
    )
    parser.add_argument(
        "--max-spec-factor",
        type=float,
        nargs="+",
        default=[1.0],
        help="Max speculation tokens as a multiplier of the prefix length",
    )
    parser.add_argument(
        "--min-token-prob",
        type=float,
        nargs="+",
        default=[0.1],
        help="Minimum probability of the token to be considered",
    )
    # Boolean options use nargs="+" (not "*"). With "*", passing the flag
    # with no values makes argparse store an empty list instead of the
    # default. That yields an empty value set rather than a clear error.
    parser.add_argument(
        "--use-tree-spec",
        type=bool_arg,
        nargs="+",
        default=[True],
        help="Whether to use tree-based speculation (True/False)",
    )
    parser.add_argument(
        "--use-cached-prompt",
        type=bool_arg,
        nargs="+",
        default=[True],
        help=("Whether to use the cached prompt for the request in addition "
              "to the global cache of previous responses (True/False)"),
    )

    # ---- Cache management (list-valued) ----
    # -1 is a sentinel: per the help text it means "unlimited".
    parser.add_argument(
        "--max-cached-requests",
        type=int,
        nargs="+",
        default=[-1],
        help="Max number of cached requests (if -1, unlimited)",
    )
    parser.add_argument(
        "--evict-fraction",
        type=float,
        nargs="+",
        default=[0.0],
        help="Evict a fraction of the cached sequences before running requests",
    )
    parser.add_argument(
        "--evict-strategy",
        type=str,
        nargs="+",
        choices=["random", "oldest", "newest"],
        default=["random"],
        help="Evict cached sequences based on the specified strategy",
    )
    return parser