in grok/training.py [0:0]
def add_model_specific_args(parser: ArgumentParser) -> ArgumentParser:
    """
    Register the hyperparameter command line options used by this class.

    Meant to be called while the command line parser is being built;
    the options are added directly to ``parser`` in a fixed order.

    :param parser: an argparse.ArgumentParser created by the caller
    :returns: the same parser object, now carrying the options below
    """

    def add_valued(*specs):
        # Each spec is (option string, converter, default value).
        for option, convert, fallback in specs:
            parser.add_argument(option, type=convert, default=fallback)

    def add_switch(dest):
        # Boolean switch: False unless the flag is given on the command line.
        parser.add_argument("--" + dest, dest=dest, action="store_true")
        parser.set_defaults(**{dest: False})

    # Parsed as a float so that a fractional value can be supplied; the
    # accepted encodings are spelled out in the help text.
    parser.add_argument(
        "--batchsize",
        type=float,
        default=0,
        help="-1 -> entire dataset, 0 -> auto-calculate, 0<N<1 -> fraction of dataset, N>1 -> N",
    )

    add_valued(
        ("--n_layers", int, 2),
        ("--n_heads", int, 4),
        ("--d_model", int, 128),
        ("--dropout", float, 0.0),
        ("--weight_noise", float, 0.0),
        ("--non_linearity", str, "relu"),
        ("--max_context_len", int, 50),
        ("--math_operator", str, "+"),
    )

    # No explicit default here, so the parsed value is None when omitted.
    parser.add_argument(
        "--operand_length",
        type=int,
        help="for list operations, the length of the lists",
    )

    add_valued(
        ("--train_data_pct", float, 5),
        ("--warmup_steps", int, 10),
        ("--anneal_lr_steps", int, 100000),
    )
    add_switch("anneal_lr")

    add_valued(
        ("--max_lr", float, 1e-3),
        ("--weight_decay", float, 0),
        ("--weight_decay_kind", str, "to_zero"),
        ("--noise_factor", float, 0),
    )

    add_switch("save_activations")
    add_switch("save_outputs")

    # Directory defaults come from module-level constants.
    add_valued(
        ("--logdir", str, DEFAULT_LOG_DIR),
        ("--datadir", str, DEFAULT_DATA_DIR),
    )

    return parser