in distributed_training/src_dir/main_trainer.py [0:0]
def _str2bool(value):
    """Parse a boolean command-line value.

    argparse's ``type=bool`` is a trap: ``bool('False')`` is True because
    every non-empty string is truthy, so user-supplied values were being
    silently ignored.  This converter accepts the common spellings.

    Args:
        value: Raw option value (str, or bool when used as a default).

    Returns:
        bool: The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognized
            boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('invalid boolean value: %r' % (value,))


def args_fn():
    """Define and parse the CLI arguments for the ResNet-50 trainer.

    Argument groups: logging defaults, model hyperparameters,
    distributed-training switches, model-parallel options, and the
    SageMaker container environment.  SageMaker defaults are read from
    the ``SM_*`` environment variables with standard-path fallbacks so
    the parser also works outside a container.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(description='PyTorch Resnet50 Example')
    # Default Setting
    parser.add_argument(
        '--log-interval',
        type=int,
        default=5,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument(
        '--backend',
        type=str,
        default='nccl',
        help=
        'backend for distributed training (tcp, gloo on cpu and gloo, nccl on gpu)'
    )
    # Boolean options use _str2bool, not type=bool: bool('False') is True,
    # so type=bool ignored whatever value the user actually passed.
    parser.add_argument('--channels-last', type=_str2bool, default=True)
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('-p',
                        '--print-freq',
                        default=10,
                        type=int,
                        metavar='N',
                        help='print frequency (default: 10)')
    # Hyperparameter Setting
    parser.add_argument('--model_name', type=str, default='resnet50')
    parser.add_argument('--height', type=int, default=224)
    parser.add_argument('--width', type=int, default=224)
    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--num-classes', type=int, default=10)
    parser.add_argument('--num-epochs', type=int, default=3)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=200,
                        metavar='N',
                        help='input batch size for testing (default: 200)')
    # Setting for Distributed Training
    parser.add_argument('--data_parallel', type=_str2bool, default=False)
    parser.add_argument('--model_parallel', type=_str2bool, default=False)
    parser.add_argument('--apex', type=_str2bool, default=False)
    parser.add_argument('--opt-level', type=str, default='O0')
    parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
    parser.add_argument('--loss-scale', type=str, default=None)
    parser.add_argument('--sync_bn',
                        action='store_true',
                        help='enabling apex sync BN.')
    parser.add_argument('--prof',
                        default=-1,
                        type=int,
                        help='Only run 10 iterations for profiling.')
    # Setting for Model Parallel
    parser.add_argument("--horovod", type=int, default=0)
    parser.add_argument('--mp_parameters', type=str, default='')
    parser.add_argument("--ddp", type=int, default=0)
    parser.add_argument("--amp", type=int, default=0)
    # NOTE: '--save_full_model' and '--save-full-model' (below) resolve to
    # the SAME dest (save_full_model).  argparse applies the default of the
    # first-registered action, so the effective default is True; the
    # store_true spelling is kept only for CLI backward compatibility.
    parser.add_argument("--save_full_model", type=_str2bool, default=True)
    parser.add_argument("--pipeline", type=str, default="interleaved")
    parser.add_argument("--assert-losses", type=int, default=0)
    parser.add_argument("--partial-checkpoint",
                        type=str,
                        default="",
                        help="The checkpoint path to load")
    parser.add_argument("--full-checkpoint",
                        type=str,
                        default="",
                        help="The checkpoint path to load")
    parser.add_argument("--save-full-model",
                        action="store_true",
                        default=False,
                        help="For Saving the current Model")
    parser.add_argument(
        "--save-partial-model",
        action="store_true",
        default=False,
        help="For Saving the current Model",
    )
    # SageMaker Container environment.  Defaults come from the SM_*
    # variables when running in a container; the .get fallbacks are the
    # standard SageMaker paths so the parser also works locally (the
    # original os.environ[...] raised KeyError outside a container).
    parser.add_argument('--hosts',
                        # was type=list, which splits a CLI value into
                        # single characters; the value is a JSON list.
                        type=json.loads,
                        default=json.loads(
                            os.environ.get('SM_HOSTS', '["localhost"]')))
    parser.add_argument('--current-host',
                        type=str,
                        default=os.environ.get('SM_CURRENT_HOST',
                                               'localhost'))
    parser.add_argument('--model-dir',
                        type=str,
                        default=os.environ.get('SM_MODEL_DIR',
                                               '/opt/ml/model'))
    parser.add_argument('--data-dir',
                        type=str,
                        default=os.environ.get('SM_CHANNEL_TRAINING',
                                               '/opt/ml/input/data/training'))
    parser.add_argument('--num-gpus',
                        type=int,
                        default=int(os.environ.get('SM_NUM_GPUS', 0)))
    parser.add_argument('--output_data_dir',
                        type=str,
                        default=os.environ.get('SM_OUTPUT_DATA_DIR',
                                               '/opt/ml/output/data'))
    parser.add_argument('--rank', type=int, default=0)
    args = parser.parse_args()
    return args