# modeling.py

import logging

import torch
import torch.nn as nn
from torch.optim import LBFGS, SGD
from torch.optim.lr_scheduler import StepLR

# project-local helpers used below (import paths may differ in the actual repo):
import dataloading
import util

def train_model(model, dataset, optimizer="lbfgs", batch_size=128, num_epochs=100,
                learning_rate=1.0, criterion=None, augmentation=False, momentum=0.9,
                use_lr_scheduler=True, visualizer=None, title=None):
"""
Trains `model` on samples from the specified `dataset` using the specified
`optimizer` ("lbfgs" or "sgd") with batch size `batch_size` for `num_epochs`
epochs to minimize the specified `criterion` (default = `nn.CrossEntropyLoss`).
For L-BFGS, the batch size is ignored and full gradients are used. The
`learning_rate` is only used as initial value; step sizes are determined by
checking the Wolfe conditions.
For SGD, the initial learning rate is set to `learning_rate` and is reduced
by a factor of 10 four times during training. Training uses Nesterov momentum
of 0.9. Optionally, data `augmentation` can be enabled as well.
Training progress is shown in the visdom `visualizer` in a window with the
specified `title`.
"""
    # put the model in training mode and set up the criterion and learning-curve state:
model.train()
device = next(model.parameters()).device
if criterion is None:
criterion = nn.CrossEntropyLoss()
if visualizer is not None:
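        # the visdom window handle is kept in a one-element list and re-used
        # below so that successive epochs update the same learning-curve window: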
window = [None]
# set up optimizer and learning rate scheduler:
if optimizer == "sgd":
        optimizer = SGD(
            model.parameters(), lr=learning_rate, momentum=momentum, nesterov=True
        )
scheduler = StepLR(optimizer, step_size=max(1, num_epochs // 4), gamma=0.1)
elif optimizer == "lbfgs":
assert not augmentation, "Cannot use data augmentation with L-BFGS."
use_lr_scheduler = False
optimizer = LBFGS(
model.parameters(),
lr=learning_rate,
tolerance_grad=1e-4,
line_search_fn="strong_wolfe",
)
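        # use the full dataset as a single batch so that the line search sees
        # deterministic (full-batch) gradients: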
batch_size = len(dataset["targets"])
else:
raise ValueError(f"Unknown optimizer: {optimizer}")
# create data sampler:
transform = dataloading.data_augmentation() if augmentation else None
datasampler = dataloading.load_datasampler(
dataset, batch_size=batch_size, transform=transform
)
# perform training epochs:
for epoch in range(num_epochs):
num_samples, total_loss = 0, 0.
for sample in datasampler():
# copy sample to correct device if needed:
for key in sample.keys():
if sample[key].device != device:
sample[key] = sample[key].to(device=device)
# closure that performs forward-backward pass:
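            # (L-BFGS may evaluate this closure several times per step during
            # its line search; SGD evaluates it exactly once)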
def loss_closure():
optimizer.zero_grad()
predictions = model(sample["features"])
loss = criterion(predictions, sample["targets"])
loss.backward()
return loss
# perform parameter update:
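            # `optimizer.step` returns the loss from the first closure
            # evaluation, which is used for monitoring below: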
loss = optimizer.step(closure=loss_closure)
# aggregate loss values for monitoring:
total_loss += (loss.item() * sample["features"].size(0))
num_samples += sample["features"].size(0)
# decay learning rate (SGD only):
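        # (the step after the final epoch is skipped because it would only
        # affect updates that never happen)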
if use_lr_scheduler and epoch != num_epochs - 1:
scheduler.step()
# print statistics:
if epoch % 10 == 0:
average_loss = total_loss / float(num_samples)
logging.info(f" => epoch {epoch + 1}: loss = {average_loss}")
if visualizer is not None:
window[0] = util.learning_curve(
visualizer,
torch.LongTensor([epoch + 1]),
torch.DoubleTensor([average_loss]),
window=window[0],
title=title,
)
# we are done training:
model.eval()
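

# Minimal usage sketch (not part of the original module): it assumes that
# `dataloading.load_datasampler` accepts a plain dict of tensors with the keys
# "features" and "targets" used above; the random data and the small MLP below
# are purely illustrative.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # toy classification problem: 512 samples, 20 features, 3 classes
    dataset = {
        "features": torch.randn(512, 20),
        "targets": torch.randint(0, 3, (512,)),
    }
    model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 3))

    # full-batch L-BFGS training (batch_size is ignored for L-BFGS):
    train_model(model, dataset, optimizer="lbfgs", num_epochs=20)

    # mini-batch SGD training with the default learning-rate schedule:
    train_model(
        model, dataset, optimizer="sgd", batch_size=64, num_epochs=50,
        learning_rate=0.1,
    )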