in projects/vision-ai-edge-platform/pipelines/segmentation/deeplabv3plus/trainer/main.py [0:0]
def _get_args():
    """Parses command-line arguments for the DeepLabV3+ segmentation trainer.

    The default experiment name is built from the ``CLOUD_ML_JOB_ID``
    environment variable. When that variable is unset (e.g. running the
    trainer locally or just printing ``--help``), ``'local'`` is used in its
    place instead of failing with a ``KeyError``.

    Returns:
        The ``argparse.Namespace`` produced by ``parser.parse_args()``.
        Attribute names are the flag names with dashes replaced by
        underscores (e.g. ``args.img_width``, ``args.loss_function``).
    """
    # NOTE(review): CLOUD_ML_JOB_ID is presumably injected by the managed
    # training service; it is read with a fallback so that --help, local
    # runs, and runs passing --experiment explicitly do not crash.
    cloud_ml_job_id = os.environ.get('CLOUD_ML_JOB_ID', 'local')
    # Materialize the enum member names once: used both as argparse choices
    # and to render readable help text (instead of a dict_keys(...) repr).
    loss_names = list(SegmentationLosses.__members__)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--experiment',
        type=str,
        default=f'experiment-{cloud_ml_job_id}',
        help='experiment name to log metrics and checkpoints, ' +
        'default=experiment-[CLOUD_ML_JOB_ID]')
    parser.add_argument(
        '--img-width',
        type=int,
        default=512,
        help='input image width assumed by model, default=512')
    parser.add_argument(
        '--img-height',
        type=int,
        default=512,
        help='input image height assumed by model, default=512')
    parser.add_argument(
        '--deeplab-preset',
        default='efficientnetv2_b0_imagenet',
        type=str,
        choices=keras_cv.models.DeepLabV3Plus.presets,
        help='preset to load backbone with weights from, ' +
        'default=efficientnetv2_b0_imagenet')
    parser.add_argument(
        '--num-epochs',
        type=int,
        default=100,
        help='number of times to go through the data, default=100')
    parser.add_argument(
        '--batch-size',
        default=1,
        type=int,
        help='number of records to read during each training step, default=1')
    parser.add_argument(
        '--optimizer',
        default='adam',
        type=str,
        help='optimizer to use for training, default=adam')
    parser.add_argument(
        '--learning-rate',
        default=.001,
        type=float,
        help='learning rate for optimizer, default=.001')
    parser.add_argument(
        '--loss-function',
        default='dice_focal',
        type=str,
        choices=loss_names,
        help='loss function, one of: ' + ', '.join(loss_names) +
        ', default=dice_focal')
    parser.add_argument(
        '--patience-epochs',
        type=int,
        default=5,
        help='number of epochs to wait before early stopping, default=5')
    parser.add_argument(
        '--checkpoint-frequency',
        type=int,
        default=1000,
        help='number of steps between checkpoints, default=1000')
    parser.add_argument(
        '--augmentation-factor',
        type=int,
        default=10,
        help='factor by which to increase dataset size with augmentations, ' +
        'default=10')
    parser.add_argument(
        '--verbosity',
        choices=['DEBUG', 'ERROR', 'FATAL', 'INFO', 'WARN'],
        default='INFO',
        help='logging verbosity level, default=INFO')
    return parser.parse_args()