def args_fn()

in distributed_training/src_dir/main_trainer.py [0:0]


def _str2bool(value):
    """Convert a command-line token such as ``"False"`` into a real bool.

    ``type=bool`` is a trap with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy.  This converter understands the
    usual textual spellings instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy, as it was under the old ``type=bool``.
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


def _json_list(value):
    """Decode a JSON array passed on the command line into a Python list.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not valid JSON or does not
            decode to a list.
    """
    try:
        decoded = json.loads(value)
    except ValueError as err:  # json.JSONDecodeError is a ValueError subclass
        raise argparse.ArgumentTypeError(
            'expected a JSON list, got {!r}'.format(value)) from err
    if not isinstance(decoded, list):
        raise argparse.ArgumentTypeError(
            'expected a JSON list, got {!r}'.format(value))
    return decoded


def args_fn(argv=None):
    """Build the training argument parser and parse the command line.

    Args:
        argv: Optional list of argument strings.  When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which matches the previous
            zero-argument behaviour.

    Returns:
        The ``argparse.Namespace`` holding every parsed option.

    Raises:
        KeyError: if one of the ``SM_*`` environment variables used as a
            default (``SM_HOSTS``, ``SM_CURRENT_HOST``, ``SM_MODEL_DIR``,
            ``SM_CHANNEL_TRAINING``, ``SM_NUM_GPUS``, ``SM_OUTPUT_DATA_DIR``)
            is not set.
        SystemExit: raised by argparse on invalid command-line input.
    """
    parser = argparse.ArgumentParser(description='PyTorch Resnet50 Example')

    # Default Setting
    parser.add_argument(
        '--log-interval',
        type=int,
        default=5,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument(
        '--backend',
        type=str,
        default='nccl',
        help=
        'backend for distributed training (tcp, gloo on cpu and gloo, nccl on gpu)'
    )
    # Boolean switches use _str2bool so that "--flag False" really disables.
    parser.add_argument('--channels-last', type=_str2bool, default=True)
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('-p',
                        '--print-freq',
                        default=10,
                        type=int,
                        metavar='N',
                        help='print frequency (default: 10)')

    # Hyperparameter Setting
    parser.add_argument('--model_name', type=str, default='resnet50')
    parser.add_argument('--height', type=int, default=224)
    parser.add_argument('--width', type=int, default=224)
    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--num-classes', type=int, default=10)
    parser.add_argument('--num-epochs', type=int, default=3)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=200,
                        metavar='N',
                        help='input batch size for testing (default: 200)')

    # Setting for Distributed Training
    parser.add_argument('--data_parallel', type=_str2bool, default=False)
    parser.add_argument('--model_parallel', type=_str2bool, default=False)
    parser.add_argument('--apex', type=_str2bool, default=False)
    parser.add_argument('--opt-level', type=str, default='O0')
    parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
    parser.add_argument('--loss-scale', type=str, default=None)
    parser.add_argument('--sync_bn',
                        action='store_true',
                        help='enabling apex sync BN.')
    parser.add_argument('--prof',
                        default=-1,
                        type=int,
                        help='Only run 10 iterations for profiling.')

    # Setting for Model Parallel
    parser.add_argument("--horovod", type=int, default=0)
    parser.add_argument('--mp_parameters', type=str, default='')
    parser.add_argument("--ddp", type=int, default=0)
    parser.add_argument("--amp", type=int, default=0)
    parser.add_argument("--pipeline", type=str, default="interleaved")
    parser.add_argument("--assert-losses", type=int, default=0)
    parser.add_argument("--partial-checkpoint",
                        type=str,
                        default="",
                        help="The checkpoint path to load")
    parser.add_argument("--full-checkpoint",
                        type=str,
                        default="",
                        help="The checkpoint path to load")
    # Both spellings used to be registered as two separate actions writing to
    # the same destination; the first one's default (True) always won and the
    # value could never become False.  A single action now accepts either
    # spelling, as a bare flag ("--save-full-model") or with an explicit
    # value ("--save_full_model False"), keeping the effective default True.
    parser.add_argument("--save_full_model",
                        "--save-full-model",
                        type=_str2bool,
                        nargs="?",
                        const=True,
                        default=True,
                        help="For Saving the current Model")
    parser.add_argument(
        "--save-partial-model",
        action="store_true",
        default=False,
        help="For Saving the current Model",
    )

    # SageMaker Container environment
    # --hosts is decoded as JSON (same format as SM_HOSTS); ``type=list``
    # would have split the raw string into single characters.
    parser.add_argument('--hosts',
                        type=_json_list,
                        default=json.loads(os.environ['SM_HOSTS']))
    parser.add_argument('--current-host',
                        type=str,
                        default=os.environ['SM_CURRENT_HOST'])
    parser.add_argument('--model-dir',
                        type=str,
                        default=os.environ['SM_MODEL_DIR'])
    parser.add_argument('--data-dir',
                        type=str,
                        default=os.environ['SM_CHANNEL_TRAINING'])
    # The default is a string from the environment; argparse runs string
    # defaults through ``type``, so the parsed value is still an int.
    parser.add_argument('--num-gpus',
                        type=int,
                        default=os.environ['SM_NUM_GPUS'])
    parser.add_argument('--output_data_dir',
                        type=str,
                        default=os.environ['SM_OUTPUT_DATA_DIR'])
    parser.add_argument('--rank', type=int, default=0)

    return parser.parse_args(argv)