def fit_model()

in torchmoji/finetuning.py [0:0]


def _loss_value(loss):
    """ Extract the scalar held by a loss tensor/Variable as a NumPy scalar.

    # Arguments:
        loss: Result of a loss computation holding a single value.

    # Returns:
        The first (only) element of the loss as a NumPy scalar.
    """
    # Flatten before indexing: older PyTorch yields a 1-element array here,
    # while PyTorch >= 0.4 yields a 0-dim array on which `[0]` would raise
    # IndexError. reshape(-1) makes both cases indexable.
    return loss.data.cpu().numpy().reshape(-1)[0]


def _mean_val_loss(model, loss_op, val_gen):
    """ Average the per-batch loss of `model` over every batch in `val_gen`.

    # Arguments:
        model: Model to evaluate (the caller selects train/eval mode).
        loss_op: loss operation passed through to `calc_loss`.
        val_gen: Iterable yielding (inputs, targets) batches.

    # Returns:
        Mean of the per-batch loss values (unweighted by batch size).
    """
    batch_losses = [
        _loss_value(calc_loss(loss_op, model(Variable(xv)), Variable(yv)))
        for xv, yv in val_gen
    ]
    return np.mean(batch_losses)


def fit_model(model, loss_op, optim_op, train_gen, val_gen, epochs,
              checkpoint_path, patience):
    """ Analog to Keras fit_generator function.

    Trains `model` for up to `epochs` epochs, checkpointing the weights
    whenever the mean validation loss improves and stopping early after
    `patience` consecutive epochs without improvement.

    # Arguments:
        model: Model to be finetuned.
        loss_op: loss operation (BCEWithLogitsLoss or CrossEntropy for e.g.)
        optim_op: optimization operation (Adam e.g.)
        train_gen: Training data iterator (DataLoader)
        val_gen: Validation data iterator (DataLoader)
        epochs: Number of epochs.
        checkpoint_path: Filepath where weights will be checkpointed to
            during training. This file will be rewritten by the function.
        patience: Number of consecutive epochs without validation loss
            improvement after which training is stopped.

    # Returns:
        None. The best weights seen (by validation loss) are stored at
        `checkpoint_path`; the model itself is left with the weights of
        the last epoch that was run.
    """
    # Save original checkpoint so the file always holds the best weights,
    # even if no epoch improves on the starting point.
    torch.save(model.state_dict(), checkpoint_path)

    model.eval()
    best_loss = _mean_val_loss(model, loss_op, val_gen)
    print("original val loss", best_loss)

    epoch_without_impr = 0
    for epoch in range(epochs):
        for i, data in enumerate(train_gen):
            X_train, y_train = data
            X_train = Variable(X_train, requires_grad=False)
            y_train = Variable(y_train, requires_grad=False)
            # Re-enable training mode on every step: the accuracy evaluation
            # at the end of the previous step may have switched the model to
            # eval mode (assumption - depends on evaluate_using_acc).
            model.train()
            optim_op.zero_grad()
            output = model(X_train)
            loss = calc_loss(loss_op, output, y_train)
            loss.backward()
            clip_grad_norm(model.parameters(), 1)
            optim_op.step()

            acc = evaluate_using_acc(model, [(X_train.data, y_train.data)])
            print("== Epoch", epoch, "step", i, "train loss", _loss_value(loss), "train acc", acc)

        model.eval()
        acc = evaluate_using_acc(model, val_gen)
        print("val acc", acc)

        val_loss = _mean_val_loss(model, loss_op, val_gen)
        print("val loss", val_loss)

        if val_loss < best_loss:
            # Improvement: restart the patience window and save checkpoint.
            epoch_without_impr = 0
            best_loss = val_loss
            torch.save(model.state_dict(), checkpoint_path)
            print('Saving model at', checkpoint_path)
        else:
            # Also reached when val_loss is NaN (all comparisons are False),
            # so a diverged run still triggers early stopping.
            epoch_without_impr += 1
            print('No improvement over previous best loss: ', best_loss)

        # Early stopping
        if epoch_without_impr >= patience:
            break