in xformers/benchmarks/LRA/run_tasks.py
import argparse
import json
import os

from xformers.components.attention import ATTENTION_REGISTRY

# `Task` (the Enum of LRA tasks) is expected to be defined elsewhere in this file.


def get_arg_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--attention",
        type=str,
        help=f"Attention mechanism to choose, among {list(ATTENTION_REGISTRY.keys())}. "
        "A list can be passed to test several mechanisms in sequence",
        dest="attention",
        required=True,
    )
    parser.add_argument(
        "--task",
        type=Task,
        help=f"Task to choose, among {[t.value for t in Task]}.",
        dest="task",
        required=True,
    )
    parser.add_argument(
        "--skip_train",
        # NOTE: `type=bool` would turn any non-empty string (including "False") into True,
        # so this is exposed as a store_true flag instead.
        action="store_true",
        help="Whether to skip training and test an existing model",
        dest="skip_train",
        default=False,
    )
    parser.add_argument(
        "--config",
        type=str,
        help="Path to the config being used",
        dest="config",
        default="./config.json",
    )
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        help="Path to the checkpoint directory",
        dest="checkpoint_dir",
        default=f"/checkpoints/{os.getenv('USER')}/xformers",
    )
    parser.add_argument(
        "--debug",
        help="Make it easier to debug a possible issue",
        dest="debug",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--world_size",
        help="Number of GPUs used",
        dest="world_size",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--sweep_parameters",
        help="Override some hyperparameters in the config, passed as a JSON string",
        dest="sweep_parameters",
        # argparse cannot build a dict via `type=dict`; parse a JSON string into a dict instead.
        type=json.loads,
        default=None,
    )
    parser.add_argument(
        "--tb_dir",
        type=str,
        help="Path to the tensorboard directory",
        dest="tb_dir",
        default=f"/checkpoints/{os.getenv('USER')}/xformers/tb",
    )
    return parser
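
# Illustrative usage sketch, not part of the original file: parse the CLI into an
# argparse.Namespace whose attribute names follow the `dest` fields above
# (e.g. args.attention, args.task, args.world_size, args.sweep_parameters).
if __name__ == "__main__":
    args = get_arg_parser().parse_args()
    print(args)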