def tune_trainable()

in torchmoji/finetuning.py [0:0]


def tune_trainable(model, loss_op, optim_op, train_gen, val_gen, test_gen,
                   nb_epochs, checkpoint_path, patience=5, evaluate='acc',
                   verbose=2):
    """ Finetunes the model, reloads the checkpointed weights and optionally
    evaluates the result on the test set.

    # Arguments:
        model: Model to be finetuned. Only its parameters with
            `requires_grad=True` are reported as trainable.
        loss_op: Loss operation, forwarded to `fit_model`.
        optim_op: Optimizer, forwarded to `fit_model`.
        train_gen: Training data, forwarded to `fit_model`.
        val_gen: Validation data, forwarded to `fit_model` and, when
            `evaluate == 'weighted_f1'`, to `evaluate_using_weighted_f1`.
        test_gen: Test data used for the pre- and post-training evaluation.
        nb_epochs: Number of epochs, forwarded to `fit_model`.
        checkpoint_path: Filepath handed to `fit_model` for checkpointing;
            the state dict stored there is loaded back into `model` after
            training (presumably overwritten by `fit_model` — confirm).
        patience: Patience value, forwarded to `fit_model`.
        evaluate: Evaluation method to use. Can be 'acc' or 'weighted_f1';
            any other value skips evaluation.
        verbose: Verbosity level. Any truthy value prints the trainable
            parameter names and a pre-training test score; a value >= 2
            also reports which checkpoint was reloaded.

    # Returns:
        The test-set score from `evaluate_using_acc` ('acc') or
        `evaluate_using_weighted_f1` ('weighted_f1'); None when `evaluate`
        is neither of those.
    """
    supported = ('acc', 'weighted_f1')

    def _score():
        # Single dispatch point for the requested metric, shared by the
        # pre-training report and the final return value.
        if evaluate == 'acc':
            return evaluate_using_acc(model, test_gen)
        if evaluate == 'weighted_f1':
            # The weighted-F1 evaluator also receives the validation data.
            return evaluate_using_weighted_f1(model, test_gen, val_gen)
        return None

    if verbose:
        trainable = [n for n, p in model.named_parameters() if p.requires_grad]
        print("Trainable weights: {}".format(trainable))
        print("Training...")
        # Only report a baseline score when a known metric was requested.
        if evaluate in supported:
            print("Evaluation on test set prior training:", _score())

    fit_model(model, loss_op, optim_op, train_gen, val_gen, nb_epochs,
              checkpoint_path, patience)

    # Reload the best weights found to avoid overfitting.
    # Wait a bit to allow proper closing of the weights file.
    sleep(1)
    model.load_state_dict(torch.load(checkpoint_path))
    if verbose >= 2:
        print("Loaded weights from {}".format(checkpoint_path))

    return _score()