in train_weak_to_strong.py [0:0]
def train_model(
model_config: ModelConfig,
train_ds: torch.utils.data.Dataset,
test_ds: torch.utils.data.Dataset,
*,
loss_type: str,
label: str,
subpath,
lr,
eval_batch_size,
epochs=1,
inference_ds: Optional[torch.utils.data.Dataset] = None,
linear_probe: bool = False,
optimizer_name: str = "adam",