in train_weak_to_strong.py
def main(
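# Overall training batch size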
batch_size: int = 32,
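# Maximum context length in tokens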
max_ctx: int = 1024,
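# Name of the dataset to train on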
ds_name: str = "sciq",
transfer_loss: Union[str, Sequence[str]] = "xent,logconf",
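# Number of training and test documents to use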
n_docs: int = 10000,
n_test_docs: int = 200,
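# Model used as the weak supervisor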
weak_model_size: str = "gpt2",
weak_lr: Optional[float] = None,
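# Model trained as the strong student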
strong_model_size: str = "gpt2-xl",
strong_lr: Optional[float] = None,
# Defaults to strong_lr
transfer_lr: Optional[float] = None,
# Optimizers default to the default_optimizer set in the model definitions
weak_optim: Optional[str] = None,
strong_optim: Optional[str] = None,
transfer_optim: Optional[str] = None,
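# Epochs of training on ground-truth labels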
gt_epochs: int = 2,
# Defaults to gt_epochs
transfer_epochs: Optional[int] = None,
force_retrain: bool = False,
seed: int = 0,
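# Per-device minibatch size; used with gradient accumulation when smaller than batch_size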
minibatch_size_per_device: Optional[int] = None,
train_with_dropout: bool = False,
results_folder: str = "/tmp/results",
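# Train only a linear probe on top of frozen activations instead of full finetuning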
linear_probe: bool = False,
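# Learning rate schedule (cosine annealing by default)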
lr_schedule: str = "cosine_anneal",
log_prefix: str = "",
# Set to an absurdly high value so we don't do intermediate evals by default.
eval_every: int = 100000000,