in cli/jobs/pipelines/tensorflow-image-segmentation/src/run.py [0:0]
def build_arguments_parser(parser: argparse.ArgumentParser = None):
    """Builds the argument parser for CLI settings.

    Args:
        parser: an existing parser to add argument groups to; when None,
            a fresh argparse.ArgumentParser is created.

    Returns:
        argparse.ArgumentParser: the parser with all training input/output,
        data-loading, model and backend argument groups registered.
    """
    if parser is None:
        parser = argparse.ArgumentParser()

    group = parser.add_argument_group("Training Inputs")
    group.add_argument(
        "--train_images",
        type=str,
        required=True,
        help="Path to folder containing training images",
    )
    group.add_argument(
        "--images_filename_pattern",
        type=str,
        required=False,
        default=r"(.*)\.jpg",  # raw string: regex matched on group(1)
        help="Regex used to find and match images with masks (matched on group(1))",
    )
    group.add_argument(
        "--images_type",
        type=str,
        required=False,
        choices=["png", "jpg"],
        default="png",
        help="png (default) or jpg",
    )
    group.add_argument(
        "--train_masks",
        type=str,
        required=True,
        help="path to folder containing segmentation masks",
    )
    group.add_argument(
        "--masks_filename_pattern",
        type=str,
        required=False,
        default=r"(.*)\.png",  # raw string: regex matched on group(1)
        help="Regex used to find and match images with masks (matched on group(1))",
    )
    group.add_argument(
        "--test_images",
        type=str,
        required=True,
        help="Path to folder containing testing images",
    )
    group.add_argument(
        "--test_masks",
        type=str,
        required=True,
        help="path to folder containing segmentation masks",
    )

    group = parser.add_argument_group("Training Outputs")
    group.add_argument(
        "--model_output",
        type=str,
        required=False,
        default=None,
        help="Path to write final model",
    )
    group.add_argument(
        "--checkpoints",
        type=str,
        required=False,
        default=None,
        help="Path to read/write checkpoints",
    )
    group.add_argument(
        "--register_model_as",
        type=str,
        required=False,
        default=None,
        help="Name to register final model in MLFlow",
    )

    group = parser.add_argument_group("Data Loading Parameters")
    group.add_argument(
        "--batch_size",
        type=int,
        required=False,
        default=64,
        help="Train/valid data loading batch size (default: 64)",
    )
    group.add_argument(
        "--num_workers",
        type=int,
        required=False,
        default=-1,  # -1 is interpreted downstream as tf.data AUTOTUNE
        help="Num workers for data loader (default: AUTOTUNE)",
    )
    group.add_argument(
        "--prefetch_factor",
        type=int,
        required=False,
        default=-1,  # -1 is interpreted downstream as tf.data AUTOTUNE
        help="Data loader prefetch factor (default: AUTOTUNE)",
    )
    group.add_argument(
        "--cache",
        type=str,
        required=False,
        choices=["none", "disk", "memory"],
        default="none",
        help="Use cache either on DISK or in MEMORY, or NONE",
    )

    group = parser.add_argument_group("Model/Training Parameters")
    group.add_argument(
        "--model_arch",
        type=str,
        required=False,
        default="unet",
        help="Which model architecture to use (default: unet)",
    )
    group.add_argument(
        "--model_input_size",
        type=int,
        required=True,
        help="Size of input images (resized)",
    )
    group.add_argument(
        "--num_classes", type=int, required=True, help="Number of classes"
    )
    group.add_argument(
        "--num_epochs",
        type=int,
        required=False,
        default=1,
        help="Number of epochs to train for",
    )
    group.add_argument(
        "--optimizer",
        type=str,
        required=False,
        default="rmsprop",
    )
    group.add_argument(
        "--loss",
        type=str,
        required=False,
        default="sparse_categorical_crossentropy",
    )

    group = parser.add_argument_group("Training Backend Parameters")
    group.add_argument(
        "--enable_profiling",
        type=strtobool,
        required=False,
        default=False,
        help="Enable tensorflow profiler.",
    )
    group.add_argument(
        "--disable_cuda",
        type=strtobool,
        required=False,
        default=False,
        help="set True to force use of cpu (local testing).",
    )
    group.add_argument(
        "--num_gpus",
        type=int,
        required=False,
        default=-1,
        help="limit the number of gpus to use (default: -1 for no limit).",
    )
    group.add_argument(
        "--distributed_strategy",
        type=str,
        required=False,
        # see https://www.tensorflow.org/guide/distributed_training
        choices=[
            "auto",
            "multiworkermirroredstrategy",
            "mirroredstrategy",
            "onedevicestrategy",
            "horovod",
        ],
        default="auto",  # will auto identify
        help="Which distributed strategy to use.",
    )
    group.add_argument(
        "--distributed_backend",
        type=str,
        required=False,
        choices=[
            "auto",
            "ring",
            "nccl",
        ],
        # BUG FIX: was "Auto" — argparse does not validate defaults against
        # `choices`, so the mismatched case leaked through to downstream
        # string comparisons; must be lowercase "auto" like the choices.
        default="auto",  # will auto identify
        help="Which backend (ring, nccl, auto) for MultiWorkerMirroredStrategy collective communication.",
    )

    return parser