def get_arg_parser()

in xformers/benchmarks/LRA/run_tasks.py [0:0]


def _parse_bool_flag(text: str) -> bool:
    """Convert a command-line string such as "true" or "0" into a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is
    ``True`` because any non-empty string is truthy.

    Raises:
        argparse.ArgumentTypeError: if *text* is not a recognised spelling.
    """
    normalized = text.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    # The empty string stays falsy, as it was with ``type=bool``.
    if normalized in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {text!r}")


def _parse_json_dict(text: str) -> dict:
    """Parse a command-line string holding a JSON object into a dict.

    ``type=dict`` cannot work here: ``dict("...")`` raises for any non-empty
    string, so the option could never be supplied from the command line.

    Raises:
        argparse.ArgumentTypeError: if *text* is not a valid JSON object.
    """
    # Function-scope import so this block does not depend on the file's
    # top-level imports, which are not visible from here.
    import json

    # ``dict("")`` used to yield an empty dict; keep that behavior.
    if not text.strip():
        return {}
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError as err:
        raise argparse.ArgumentTypeError(f"invalid JSON: {err}") from err
    if not isinstance(parsed, dict):
        raise argparse.ArgumentTypeError(
            f"expected a JSON object, got {type(parsed).__name__}"
        )
    return parsed


def get_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser.

    Declared options: ``--attention`` and ``--task`` (both required),
    ``--skip_train``, ``--config``, ``--checkpoint_dir``, ``--debug``,
    ``--world_size``, ``--sweep_parameters`` and ``--tb_dir``.

    ``--skip_train`` accepts an explicit boolean value (``--skip_train true``
    or ``--skip_train 0``) or can be passed bare, which means ``True``.
    ``--sweep_parameters`` takes a JSON object, e.g. ``'{"key": 1}'``.

    Returns:
        The configured ``argparse.ArgumentParser``; nothing is parsed here.
    """
    # Without a fallback an unset USER would produce the literal directory
    # name "None" in the default paths below.
    user = os.getenv("USER") or os.getenv("USERNAME") or "unknown"

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--attention",
        type=str,
        help=f"Attention mechanism to chose, among {list(ATTENTION_REGISTRY.keys())}. \
            A list can be passed to test several mechanisms in sequence",
        dest="attention",
        required=True,
    )
    parser.add_argument(
        "--task",
        type=Task,
        help=f"Task to chose, among {[t.value for t in Task]}.",
        dest="task",
        required=True,
    )
    parser.add_argument(
        "--skip_train",
        type=_parse_bool_flag,
        # nargs="?" + const lets a bare ``--skip_train`` mean True while
        # still accepting the historical ``--skip_train <value>`` form.
        nargs="?",
        const=True,
        help="Whether to skip training, and test an existing model",
        dest="skip_train",
        default=False,
    )
    parser.add_argument(
        "--config",
        type=str,
        help="Path to the config being used",
        dest="config",
        default="./config.json",
    )
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        help="Path to the checkpoint directory",
        dest="checkpoint_dir",
        default=f"/checkpoints/{user}/xformers",
    )
    parser.add_argument(
        "--debug",
        help="Make it easier to debug a possible issue",
        dest="debug",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--world_size",
        help="Number of GPUs used",
        dest="world_size",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--sweep_parameters",
        help="Rewrite some hyperparameters in the config",
        dest="sweep_parameters",
        type=_parse_json_dict,
        default=None,
    )
    parser.add_argument(
        "--tb_dir",
        type=str,
        help="Path to the tensorboard directory",
        dest="tb_dir",
        default=f"/checkpoints/{user}/xformers/tb",
    )
    return parser