in torchmoji/finetuning.py
def fit_model(model, loss_op, optim_op, train_gen, val_gen, epochs,
              checkpoint_path, patience):
    """ Analog to the Keras fit_generator function.

    # Arguments:
        model: Model to be finetuned.
        loss_op: Loss operation (e.g. BCEWithLogitsLoss or CrossEntropyLoss).
        optim_op: Optimization operation (e.g. Adam).
        train_gen: Training data iterator (DataLoader).
        val_gen: Validation data iterator (DataLoader).
        epochs: Maximum number of epochs.
        checkpoint_path: Filepath where weights will be checkpointed to
            during training. This file will be rewritten by the function.
        patience: Number of epochs without validation-loss improvement
            allowed before training is stopped early.

    # Returns:
        None. The weights with the best validation loss are saved to
        checkpoint_path.
    """
    # Save original checkpoint
    torch.save(model.state_dict(), checkpoint_path)

    model.eval()
    best_loss = np.mean([calc_loss(loss_op, model(Variable(xv)), Variable(yv)).data.cpu().numpy()[0] for xv, yv in val_gen])
    print("original val loss", best_loss)

    epoch_without_impr = 0
    for epoch in range(epochs):
        for i, data in enumerate(train_gen):
            X_train, y_train = data
            X_train = Variable(X_train, requires_grad=False)
            y_train = Variable(y_train, requires_grad=False)
            model.train()
            optim_op.zero_grad()
            output = model(X_train)
            loss = calc_loss(loss_op, output, y_train)
            loss.backward()
            clip_grad_norm(model.parameters(), 1)
            optim_op.step()

            acc = evaluate_using_acc(model, [(X_train.data, y_train.data)])
            print("== Epoch", epoch, "step", i, "train loss", loss.data.cpu().numpy()[0], "train acc", acc)

        model.eval()
        acc = evaluate_using_acc(model, val_gen)
        print("val acc", acc)

        val_loss = np.mean([calc_loss(loss_op, model(Variable(xv)), Variable(yv)).data.cpu().numpy()[0] for xv, yv in val_gen])
        print("val loss", val_loss)
        if best_loss is not None and val_loss >= best_loss:
            epoch_without_impr += 1
            print('No improvement over previous best loss: ', best_loss)

        # Save checkpoint
        if best_loss is None or val_loss < best_loss:
            best_loss = val_loss
            torch.save(model.state_dict(), checkpoint_path)
            print('Saving model at', checkpoint_path)

        # Early stopping
        if epoch_without_impr >= patience:
            break
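
A minimal usage sketch, not part of torchmoji/finetuning.py: it wires fit_model up with a toy binary classifier and DataLoaders just to show the call signature described in the docstring. The toy model, tensor shapes, checkpoint path, and hyperparameters below are illustrative assumptions, and the call assumes whatever pre-1.0 PyTorch version fit_model itself targets (it relies on Variable and clip_grad_norm).

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Toy data: 256 training / 64 validation examples, 10 features, binary labels.
X_tr, y_tr = torch.randn(256, 10), (torch.rand(256, 1) > 0.5).float()
X_va, y_va = torch.randn(64, 10), (torch.rand(64, 1) > 0.5).float()
train_gen = DataLoader(TensorDataset(X_tr, y_tr), batch_size=32, shuffle=True)
val_gen = DataLoader(TensorDataset(X_va, y_va), batch_size=32)

# Stand-in for a torchMoji model: any nn.Module producing one logit per example.
model = nn.Sequential(nn.Linear(10, 16), nn.ReLU(), nn.Linear(16, 1))
loss_op = nn.BCEWithLogitsLoss()   # binary objective, as suggested by the docstring
optim_op = torch.optim.Adam(model.parameters(), lr=1e-3)

# Trains for up to 10 epochs, rewriting the checkpoint whenever validation loss
# improves and stopping early after 3 epochs without improvement.
fit_model(model, loss_op, optim_op, train_gen, val_gen,
          epochs=10, checkpoint_path='/tmp/finetune_checkpoint.pt', patience=3)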