def class_avg_tune_trainable()

in torchmoji/class_avg_finetuning.py
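The excerpt depends on torch and time.sleep, on the helpers prepare_labels and prepare_generators defined earlier in the same file, and on fit_model and find_f1_threshold, which appear to come from torchmoji.finetuning.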


def class_avg_tune_trainable(model, nb_classes, loss_op, optim_op, train, val, test,
                             epoch_size, nb_epochs, batch_size,
                             init_weight_path, checkpoint_weight_path, patience=5,
                             verbose=True):
    """ Finetunes the given model using the F1 measure.

    # Arguments:
        model: Model to be finetuned.
        nb_classes: Number of classes in the given dataset.
        train: Training data, given as a tuple of (inputs, outputs)
        val: Validation data, given as a tuple of (inputs, outputs)
        test: Testing data, given as a tuple of (inputs, outputs)
        epoch_size: Number of samples in an epoch.
        nb_epochs: Number of epochs.
        batch_size: Batch size.
        init_weight_path: Filepath where weights will be initially saved before
            training each class. This file will be rewritten by the function.
        checkpoint_weight_path: Filepath where weights will be checkpointed to
            during training. This file will be rewritten by the function.
        verbose: Verbosity flag.

    # Returns:
        F1 score of the trained model
    """
    total_f1 = 0
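    # For binary problems a single pass suffices; otherwise every class is
    # trained one-vs-rest in turn and the per-class F1 scores are averaged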
    nb_iter = nb_classes if nb_classes > 2 else 1

    # Unpack args
    X_train, y_train = train
    X_val, y_val = val
    X_test, y_test = test

    # Save and reload initial weights after running for
    # each class to avoid learning across classes
    torch.save(model.state_dict(), init_weight_path)
    for i in range(nb_iter):
        if verbose:
            print('Iteration number {}/{}'.format(i+1, nb_iter))

        model.load_state_dict(torch.load(init_weight_path))
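        # Rebinarise the labels of all three splits for class i (one-vs-rest)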
        y_train_new, y_val_new, y_test_new = prepare_labels(y_train, y_val,
                                                            y_test, i, nb_classes)
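        # Build a batched training generator and a resampled validation
        # subset for this binary sub-problem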
        train_gen, X_val_resamp, y_val_resamp = \
            prepare_generators(X_train, y_train_new, X_val, y_val_new,
                               batch_size, epoch_size)

        if verbose:
            print("Training..")
        fit_model(model, loss_op, optim_op, train_gen, [(X_val_resamp, y_val_resamp)],
                  nb_epochs, checkpoint_weight_path, patience, verbose=0)

        # Reload the best weights found to avoid overfitting
        # Wait a bit to allow proper closing of weights file
        sleep(1)
        model.load_state_dict(torch.load(checkpoint_weight_path))

        # Evaluate on the full validation and test sets without autograd,
        # so no graph is built and .numpy() is safe to call
        model.eval()
        with torch.no_grad():
            y_pred_val = model(X_val).cpu().numpy()
            y_pred_test = model(X_test).cpu().numpy()

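        # Choose the decision threshold that maximises F1 on the validation
        # set, then report the test-set F1 obtained with that threshold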
        f1_test, best_t = find_f1_threshold(y_val_new, y_pred_val,
                                            y_test_new, y_pred_test)
        if verbose:
            print('f1_test: {}'.format(f1_test))
            print('best_t:  {}'.format(best_t))
        total_f1 += f1_test

    return total_f1 / nb_iter
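
A minimal calling sketch follows. Everything except class_avg_tune_trainable itself is hypothetical: the toy model, random data, paths, and hyper-parameters stand in for a real torchmoji transfer model and a tokenised dataset, and the sigmoid-output/BCE pairing and label encoding are assumptions that may need adjusting to match the repo's helpers.

# Hypothetical usage sketch -- model, data, and hyper-parameters are
# placeholders, not part of the torchmoji repo.
import torch
import torch.nn as nn
import torch.optim as optim

from torchmoji.class_avg_finetuning import class_avg_tune_trainable

class TinyClassifier(nn.Module):
    """Stand-in for a torchmoji transfer model: token ids -> class scores."""
    def __init__(self, vocab_size=1000, dim=16, nb_classes=4):
        super(TinyClassifier, self).__init__()
        self.embed = nn.EmbeddingBag(vocab_size, dim)
        self.out = nn.Linear(dim, nb_classes)

    def forward(self, x):
        return torch.sigmoid(self.out(self.embed(x)))

def random_split(n, seq_len=20, vocab_size=1000, nb_classes=4):
    # Random token ids and one-hot labels, purely for illustration
    X = torch.randint(0, vocab_size, (n, seq_len))
    y = torch.eye(nb_classes)[torch.randint(0, nb_classes, (n,))]
    return X, y

nb_classes = 4
model = TinyClassifier(nb_classes=nb_classes)
f1 = class_avg_tune_trainable(
    model, nb_classes,
    loss_op=nn.BCELoss(),
    optim_op=optim.Adam(model.parameters(), lr=1e-3),
    train=random_split(1000), val=random_split(200), test=random_split(200),
    epoch_size=500, nb_epochs=10, batch_size=32,
    init_weight_path='/tmp/init_weights.bin',
    checkpoint_weight_path='/tmp/checkpoint_weights.bin')
print('class-average F1: {:.3f}'.format(f1))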