in weak_to_strong/train.py [0:0]
def train_model(
model: torch.nn.Module,
ds: datasets.Dataset,
batch_size: int,
lr: float = 1e-5,
loss_fn: Callable = xent_loss,
log_every: int = 10,
eval_every: int = 100,
eval_batch_size: int = 256,
minibatch_size: int = 8,
eval_ds: Optional[datasets.Dataset] = None,
gradient_checkpointing: bool = False,
train_with_dropout: bool = False,
epochs: int = 1,
lr_schedule: str = "cosine_anneal",
optimizer_name: str = "adam",