in weak_to_strong/train.py [0:0]
def train_and_save_model(
model_config: ModelConfig,
train_ds: datasets.Dataset,
test_ds: datasets.Dataset,
inference_ds: Optional[datasets.Dataset] = None,
*,
batch_size: int,
lr: float,
epochs: int,
eval_batch_size: Optional[int] = None,
minibatch_size_per_device: Optional[int] = None,
save_path: Optional[str] = None,
loss_fn: Callable = xent_loss,
label: str = "default",
force_retrain: bool = False,
train_with_dropout: bool = False,
linear_probe: bool = False,
lr_schedule: str = "constant",
optimizer_name: str = "adam",
eval_every: Optional[int] = None,