in tutorials/e2e-distributed-pytorch-image/src/pytorch_dl_train/train.py
import argparse
from distutils.util import strtobool

# MODEL_ARCH_LIST enumerates the supported torchvision architectures; it is
# assumed here to come from the model helper module shipped alongside this script.
from model import MODEL_ARCH_LIST


def build_arguments_parser(parser: argparse.ArgumentParser = None):
    """Builds the argument parser for CLI settings."""
    if parser is None:
        parser = argparse.ArgumentParser()

    group = parser.add_argument_group("Training Inputs")
    group.add_argument(
        "--train_images",
        type=str,
        required=True,
        help="Path to folder containing training images",
    )
    group.add_argument(
        "--valid_images",
        type=str,
        required=True,
        help="Path to folder containing validation images",
    )
group = parser.add_argument_group(f"Training Outputs")
group.add_argument(
"--model_output",
type=str,
required=False,
default=None,
help="Path to write final model",
)
group.add_argument(
"--register_model_as",
type=str,
required=False,
default=None,
help="Name to register final model in MLFlow",
)
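    # Sketch of the assumed downstream wiring for --register_model_as (not part
    # of this function; `model` and `args` are placeholder names): the trained
    # model is typically pushed to the MLflow registry along these lines:
    #   if args.register_model_as:
    #       mlflow.pytorch.log_model(
    #           model,
    #           artifact_path="final_model",
    #           registered_model_name=args.register_model_as,
    #       )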
group = parser.add_argument_group(f"Data Loading Parameters")
group.add_argument(
"--batch_size",
type=int,
required=False,
default=64,
help="Train/valid data loading batch size (default: 64)",
)
group.add_argument(
"--num_workers",
type=int,
required=False,
default=None,
help="Num workers for data loader (default: -1 => all cpus available)",
)
group.add_argument(
"--prefetch_factor",
type=int,
required=False,
default=2,
help="Data loader prefetch factor (default: 2)",
)
group.add_argument(
"--pin_memory",
type=strtobool,
required=False,
default=True,
help="Pin Data loader prefetch factor (default: True)",
)
group.add_argument(
"--non_blocking",
type=strtobool,
required=False,
default=False,
help="Use non-blocking transfer to device (default: False)",
)
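    # Sketch of how these flags are assumed to feed torch.utils.data.DataLoader
    # downstream (dataset construction not shown; names are placeholders; note
    # that prefetch_factor only applies when num_workers > 0):
    #   train_loader = torch.utils.data.DataLoader(
    #       train_dataset,
    #       batch_size=args.batch_size,
    #       num_workers=args.num_workers if args.num_workers is not None else os.cpu_count(),
    #       prefetch_factor=args.prefetch_factor,
    #       pin_memory=bool(args.pin_memory),
    #   )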
group = parser.add_argument_group(f"Model/Training Parameters")
group.add_argument(
"--model_arch",
type=str,
required=False,
choices=MODEL_ARCH_LIST,
default="resnet18",
help="Which model architecture to use (default: resnet18)",
)
group.add_argument(
"--model_arch_pretrained",
type=strtobool,
required=False,
default=True,
help="Use pretrained model (default: true)",
)
group.add_argument(
"--distributed_backend",
type=str,
required=False,
choices=["nccl", "mpi"],
default="nccl",
help="Which distributed backend to use.",
)
# DISTRIBUTED: torch.distributed.launch is passing this argument to your script
# it is likely to be deprecated in favor of os.environ['LOCAL_RANK']
# see https://pytorch.org/docs/stable/distributed.html#launch-utility
group.add_argument(
"--local_rank",
type=int,
required=False,
default=None,
help="Passed by torch.distributed.launch utility when running from cli.",
)
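    # Sketch of the process-group setup these two flags are assumed to drive
    # (placeholder names; newer launchers set LOCAL_RANK in the environment
    # instead of passing --local_rank):
    #   local_rank = (
    #       args.local_rank
    #       if args.local_rank is not None
    #       else int(os.environ.get("LOCAL_RANK", 0))
    #   )
    #   torch.distributed.init_process_group(backend=args.distributed_backend)
    #   if args.distributed_backend == "nccl":
    #       torch.cuda.set_device(local_rank)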
    group.add_argument(
        "--num_epochs",
        type=int,
        required=False,
        default=1,
        help="Number of epochs to train for",
    )
    group.add_argument(
        "--learning_rate",
        type=float,
        required=False,
        default=0.01,
        help="Learning rate of optimizer",
    )
    group.add_argument(
        "--momentum",
        type=float,
        required=False,
        default=0.01,
        help="Momentum of optimizer",
    )
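    # Sketch of the optimizer these flags are assumed to configure (placeholder
    # names, not part of this function):
    #   optimizer = torch.optim.SGD(
    #       model.parameters(), lr=args.learning_rate, momentum=args.momentum
    #   )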
group = parser.add_argument_group(f"Monitoring/Profiling Parameters")
group.add_argument(
"--enable_profiling",
type=strtobool,
required=False,
default=False,
help="Enable pytorch profiler.",
)
return parser
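

# Minimal sketch of how this builder is typically consumed from the CLI; the
# print() stands in for the actual training entry point, which is not part of
# this snippet.
if __name__ == "__main__":
    args = build_arguments_parser().parse_args()
    print(args)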