def build_arguments_parser()

in tutorials/e2e-distributed-pytorch-image/src/pytorch_dl_train/train.py


import argparse
from typing import Optional

# strtobool parses "true"/"yes"/"1"-style strings into 1/0; note that
# distutils was removed from the standard library in Python 3.12.
from distutils.util import strtobool

# MODEL_ARCH_LIST (the list of supported model architectures) is assumed
# to be defined in the accompanying model module of this tutorial.
from model import MODEL_ARCH_LIST


def build_arguments_parser(parser: Optional[argparse.ArgumentParser] = None):
    """Builds the argument parser for CLI settings."""
    if parser is None:
        parser = argparse.ArgumentParser()

    group = parser.add_argument_group(f"Training Inputs")
    group.add_argument(
        "--train_images",
        type=str,
        required=True,
        help="Path to folder containing training images",
    )
    group.add_argument(
        "--valid_images",
        type=str,
        required=True,
        help="path to folder containing validation images",
    )

    group = parser.add_argument_group(f"Training Outputs")
    group.add_argument(
        "--model_output",
        type=str,
        required=False,
        default=None,
        help="Path to write final model",
    )
    group.add_argument(
        "--register_model_as",
        type=str,
        required=False,
        default=None,
        help="Name to register final model in MLFlow",
    )

    group = parser.add_argument_group(f"Data Loading Parameters")
    group.add_argument(
        "--batch_size",
        type=int,
        required=False,
        default=64,
        help="Train/valid data loading batch size (default: 64)",
    )
    group.add_argument(
        "--num_workers",
        type=int,
        required=False,
        default=None,
        help="Num workers for data loader (default: -1 => all cpus available)",
    )
    group.add_argument(
        "--prefetch_factor",
        type=int,
        required=False,
        default=2,
        help="Data loader prefetch factor (default: 2)",
    )
    group.add_argument(
        "--pin_memory",
        type=strtobool,
        required=False,
        default=True,
        help="Pin Data loader prefetch factor (default: True)",
    )
    group.add_argument(
        "--non_blocking",
        type=strtobool,
        required=False,
        default=False,
        help="Use non-blocking transfer to device (default: False)",
    )

    group = parser.add_argument_group(f"Model/Training Parameters")
    group.add_argument(
        "--model_arch",
        type=str,
        required=False,
        choices=MODEL_ARCH_LIST,
        default="resnet18",
        help="Which model architecture to use (default: resnet18)",
    )
    group.add_argument(
        "--model_arch_pretrained",
        type=strtobool,
        required=False,
        default=True,
        help="Use pretrained model (default: true)",
    )
    group.add_argument(
        "--distributed_backend",
        type=str,
        required=False,
        choices=["nccl", "mpi"],
        default="nccl",
        help="Which distributed backend to use.",
    )
    # DISTRIBUTED: torch.distributed.launch passes this argument to your script;
    # it is deprecated in newer PyTorch releases in favor of os.environ['LOCAL_RANK'],
    # see https://pytorch.org/docs/stable/distributed.html#launch-utility
    # and the fallback sketch in the comment below.
    group.add_argument(
        "--local_rank",
        type=int,
        required=False,
        default=None,
        help="Passed by torch.distributed.launch utility when running from cli.",
    )
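    # A common fallback pattern downstream (a sketch, not verbatim from this
    # repo): read the env var set by torchrun when the flag is absent, e.g.
    #   local_rank = args.local_rank if args.local_rank is not None \
    #       else int(os.environ.get("LOCAL_RANK", 0))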
    group.add_argument(
        "--num_epochs",
        type=int,
        required=False,
        default=1,
        help="Number of epochs to train for",
    )
    group.add_argument(
        "--learning_rate",
        type=float,
        required=False,
        default=0.01,
        help="Learning rate of optimizer",
    )
    group.add_argument(
        "--momentum",
        type=float,
        required=False,
        default=0.01,
        help="Momentum of optimizer",
    )

    group = parser.add_argument_group(f"Monitoring/Profiling Parameters")
    group.add_argument(
        "--enable_profiling",
        type=strtobool,
        required=False,
        default=False,
        help="Enable pytorch profiler.",
    )

    return parser
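
A minimal usage sketch (the __main__ wiring below is illustrative, not taken
from the source file; strtobool returns the ints 0/1 rather than real booleans,
so boolean-ish flags are normalized explicitly):

if __name__ == "__main__":
    parser = build_arguments_parser()
    args = parser.parse_args()

    # type=strtobool yields 0/1 for values supplied on the command line,
    # so convert to a real bool before handing off to a DataLoader etc.
    pin_memory = bool(args.pin_memory)

    print(args.model_arch, args.batch_size, pin_memory)

Because the function accepts an optional existing parser, its argument groups
can also be composed onto a parser defined elsewhere, e.g.
build_arguments_parser(parser=my_existing_parser), where my_existing_parser is
any argparse.ArgumentParser you already constructed.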