in torchmoji/finetuning.py [0:0]
def chain_thaw(model, train_gen, val_gen, test_gen, nb_epochs, checkpoint_path, loss_op,
               patience=5, initial_lr=0.001, next_lr=0.0001, embed_l2=1E-6, evaluate='acc', verbose=1):
""" Finetunes given model using chain-thaw and evaluates using accuracy.
# Arguments:
model: Model to be finetuned.
train: Training data, given as a tuple of (inputs, outputs)
val: Validation data, given as a tuple of (inputs, outputs)
test: Testing data, given as a tuple of (inputs, outputs)
batch_size: Batch size.
loss: Loss function to be used during training.
epoch_size: Number of samples in an epoch.
nb_epochs: Number of epochs.
checkpoint_weight_path: Filepath where weights will be checkpointed to
during training. This file will be rewritten by the function.
initial_lr: Initial learning rate. Will only be used for the first
training step (i.e. the output_layer layer)
next_lr: Learning rate for every subsequent step.
seed: Random number generator seed.
verbose: Verbosity flag.
evaluate: Evaluation method to use. Can be 'acc' or 'weighted_f1'.
# Returns:
Accuracy of the finetuned model.
"""
    if verbose:
        print('Training..')

    # Train using chain-thaw: unfreeze and finetune one layer at a time
    # (output layer first), then finetune the full model.
    train_by_chain_thaw(model, train_gen, val_gen, loss_op, patience, nb_epochs,
                        checkpoint_path, initial_lr, next_lr, embed_l2, verbose)

    if evaluate == 'acc':
        return evaluate_using_acc(model, test_gen)
    elif evaluate == 'weighted_f1':
        return evaluate_using_weighted_f1(model, test_gen, val_gen)
    else:
        raise ValueError("evaluate must be 'acc' or 'weighted_f1', got {!r}".format(evaluate))
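
A minimal usage sketch follows. Building the model via torchmoji_transfer mirrors the repo's finetuning examples, but the random stand-in data, the DataLoader wrapping, the binary criterion, and all hyperparameter values are illustrative assumptions, not taken from this file.

# Hypothetical usage sketch -- data, criterion, and hyperparameters are
# assumptions for illustration; only chain_thaw's call signature comes
# from the code above.
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from torchmoji.model_def import torchmoji_transfer
from torchmoji.global_variables import PRETRAINED_PATH  # requires the downloaded weights

nb_classes = 2
model = torchmoji_transfer(nb_classes, PRETRAINED_PATH)  # pretrained torchMoji + new output layer

VOCAB_SIZE, MAXLEN = 50000, 30

def random_gen(n, batch_size=32):
    # Stand-in for real tokenized sentences (token-id matrices) and labels.
    X = torch.randint(0, VOCAB_SIZE, (n, MAXLEN))
    y = torch.randint(0, nb_classes, (n,))
    return DataLoader(TensorDataset(X, y), batch_size=batch_size)

train_gen, val_gen, test_gen = random_gen(512), random_gen(128), random_gen(128)

acc = chain_thaw(model, train_gen, val_gen, test_gen,
                 nb_epochs=100,
                 checkpoint_path='/tmp/chain_thaw_ckpt.bin',
                 loss_op=nn.BCEWithLogitsLoss(),  # assumed criterion for a binary task
                 evaluate='acc')
print('Test accuracy:', acc)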