def chain_thaw()

in torchmoji/finetuning.py [0:0]


def chain_thaw(model, train_gen, val_gen, test_gen, nb_epochs, checkpoint_path, loss_op,
               patience=5, initial_lr=0.001, next_lr=0.0001, embed_l2=1E-6, evaluate='acc', verbose=1):
    """ Finetunes given model using chain-thaw and evaluates it on the test data.

    # Arguments:
        model: Model to be finetuned.
        train_gen: Training data, passed through to `train_by_chain_thaw`.
        val_gen: Validation data. Passed to `train_by_chain_thaw`, and also to
            `evaluate_using_weighted_f1` when evaluate == 'weighted_f1'.
        test_gen: Testing data used for the final evaluation.
        nb_epochs: Number of epochs.
        checkpoint_path: Filepath where weights will be checkpointed to
            during training. This file will be rewritten by the function.
        loss_op: Loss function to be used during training.
        patience: Passed to `train_by_chain_thaw` (presumably the
            early-stopping patience; confirm against that function).
        initial_lr: Initial learning rate. Will only be used for the first
            training step (i.e. the output_layer layer)
        next_lr: Learning rate for every subsequent step.
        embed_l2: Passed to `train_by_chain_thaw` (presumably the L2 penalty
            on the embedding weights; confirm against that function).
        evaluate: Evaluation method to use. Can be 'acc' or 'weighted_f1'.
        verbose: Verbosity flag.

    # Returns:
        The test score: accuracy when evaluate == 'acc', weighted F1 when
        evaluate == 'weighted_f1'.

    # Raises:
        ValueError: If `evaluate` is not one of the supported methods. This is
            checked before training so no training time is wasted.
    """
    valid_methods = ('acc', 'weighted_f1')
    # Validate up front: previously an unknown method was only noticed after
    # training finished, and the function then silently returned None.
    if evaluate not in valid_methods:
        raise ValueError('evaluate must be one of {}, got {!r}'
                         .format(valid_methods, evaluate))

    if verbose:
        print('Training..')

    # Train using chain-thaw
    train_by_chain_thaw(model, train_gen, val_gen, loss_op, patience, nb_epochs, checkpoint_path,
                        initial_lr, next_lr, embed_l2, verbose)

    if evaluate == 'acc':
        return evaluate_using_acc(model, test_gen)
    # Only 'weighted_f1' can reach here thanks to the validation above.
    return evaluate_using_weighted_f1(model, test_gen, val_gen)