in cli/jobs/nebulaml/cifar10_deepspeed/src/cifar10_deepspeed.py [0:0]
def _str2bool(value):
    """Convert a command-line token such as "true"/"False"/"1"/"no" to a bool."""
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def add_argument(argv=None):
    """Build the CIFAR training CLI parser and parse the command line.

    The parser covers device/EMA options, training hyperparameters,
    DeepSpeed mixture-of-experts (MoE) options and Azure ML logging. It is
    then extended with DeepSpeed's own configuration arguments through
    ``deepspeed.add_config_arguments``.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, which matches the previous
            behavior.

    Returns:
        argparse.Namespace: The parsed arguments.

    Raises:
        SystemExit: Raised by argparse on invalid arguments or ``--help``.
    """
    parser = argparse.ArgumentParser(description="CIFAR")

    # runtime / model options
    parser.add_argument(
        "--with_cuda",
        default=False,
        action="store_true",
        help="use CPU in case there's no GPU support",
    )
    parser.add_argument(
        "--use_ema",
        default=False,
        action="store_true",
        help="whether use exponential moving average",
    )

    # train
    parser.add_argument(
        "-b", "--batch_size", default=32, type=int, help="mini-batch size (default: 32)"
    )
    parser.add_argument(
        "-e",
        "--epochs",
        default=30,
        type=int,
        help="number of total epochs (default: 30)",
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="local rank passed from distributed launcher",
    )
    parser.add_argument(
        "--log-interval",
        type=int,
        default=2000,
        help="output logging information at a given interval",
    )

    # DeepSpeed mixture of experts (MoE)
    parser.add_argument(
        "--moe",
        default=False,
        action="store_true",
        help="use deepspeed mixture of experts (moe)",
    )
    parser.add_argument(
        "--ep-world-size", default=1, type=int, help="(moe) expert parallel world size"
    )
    parser.add_argument(
        "--num-experts",
        type=int,
        nargs="+",
        # argparse uses this default object directly when the option is absent.
        default=[1],
        help="number of experts list, MoE related.",
    )
    parser.add_argument(
        "--mlp-type",
        type=str,
        default="standard",
        # The help text documents these as the only accepted values. Enforcing
        # them here makes a typo fail at parse time instead of passing through.
        choices=["standard", "residual"],
        help="Only applicable when num-experts > 1, accepts [standard, residual]",
    )
    parser.add_argument(
        "--top-k", default=1, type=int, help="(moe) gating top 1 and 2 supported"
    )
    parser.add_argument(
        "--min-capacity",
        default=0,
        type=int,
        help="(moe) minimum capacity of an expert regardless of the capacity_factor",
    )
    parser.add_argument(
        "--noisy-gate-policy",
        default=None,
        type=str,
        help="(moe) noisy gating (only supported with top-1). Valid values are None, RSample, and Jitter",
    )
    parser.add_argument(
        "--moe-param-group",
        default=False,
        action="store_true",
        help="(moe) create separate moe param groups, required when using ZeRO w. MoE",
    )

    parser.add_argument("--global_rank", default=-1, type=int, help="global rank")

    # Previously this option kept any supplied value as a raw string, so
    # "--with_aml_log False" was truthy. The value is now parsed as a boolean.
    # nargs="?" + const=True keeps "--with_aml_log True" working and also
    # accepts the bare flag "--with_aml_log".
    parser.add_argument(
        "--with_aml_log",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="Use Azure ML metric logging",
    )

    # Include DeepSpeed configuration arguments
    parser = deepspeed.add_config_arguments(parser)
    args = parser.parse_args(argv)
    return args