in torchmoji/finetuning.py
# Imports this function relies on (defined at module level in the original file)
from time import sleep

import torch


def tune_trainable(model, loss_op, optim_op, train_gen, val_gen, test_gen,
                   nb_epochs, checkpoint_path, patience=5, evaluate='acc',
                   verbose=2):
""" Finetunes the given model using the accuracy measure.
# Arguments:
model: Model to be finetuned.
nb_classes: Number of classes in the given dataset.
train: Training data, given as a tuple of (inputs, outputs)
val: Validation data, given as a tuple of (inputs, outputs)
test: Testing data, given as a tuple of (inputs, outputs)
epoch_size: Number of samples in an epoch.
nb_epochs: Number of epochs.
batch_size: Batch size.
checkpoint_weight_path: Filepath where weights will be checkpointed to
during training. This file will be rewritten by the function.
patience: Patience for callback methods.
evaluate: Evaluation method to use. Can be 'acc' or 'weighted_f1'.
verbose: Verbosity flag.
# Returns:
Accuracy of the trained model, ONLY if 'evaluate' is set.
"""
    if verbose:
        print("Trainable weights: {}".format([n for n, p in model.named_parameters() if p.requires_grad]))
        print("Training...")
        if evaluate == 'acc':
            print("Evaluation on test set prior to training:", evaluate_using_acc(model, test_gen))
        elif evaluate == 'weighted_f1':
            print("Evaluation on test set prior to training:", evaluate_using_weighted_f1(model, test_gen, val_gen))
    fit_model(model, loss_op, optim_op, train_gen, val_gen, nb_epochs,
              checkpoint_path, patience)

    # Reload the best weights found to avoid overfitting
    # Wait a bit to allow proper closing of weights file
    sleep(1)
    model.load_state_dict(torch.load(checkpoint_path))
    if verbose >= 2:
        print("Loaded weights from {}".format(checkpoint_path))

    if evaluate == 'acc':
        return evaluate_using_acc(model, test_gen)
    elif evaluate == 'weighted_f1':
        return evaluate_using_weighted_f1(model, test_gen, val_gen)
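

# Usage sketch (not part of the original file): a minimal, hedged example of
# how tune_trainable might be called. The toy model, random data, and
# hyperparameters are illustrative assumptions; a real call would pass a
# torchMoji classifier and sentence-encoded DataLoaders, and whether this toy
# setup runs end-to-end depends on the module's internal helpers (fit_model,
# evaluate_using_acc), which may expect torchMoji-specific model attributes.
if __name__ == '__main__':
    import torch.nn as nn
    from torch.optim import Adam
    from torch.utils.data import DataLoader, TensorDataset

    # Toy stand-in classifier: freeze everything except the final layer, the
    # usual setup whose trainable weights tune_trainable prints at the start.
    model = nn.Sequential(nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 3))
    for p in model.parameters():
        p.requires_grad = False
    for p in model[-1].parameters():
        p.requires_grad = True
    model.nb_classes = 3  # assumption: helpers in this module may read this attribute

    loss_op = nn.CrossEntropyLoss()
    optim_op = Adam((p for p in model.parameters() if p.requires_grad), lr=1e-3)

    def make_gen():
        # Dummy (inputs, outputs) loader matching the train/val/test_gen shape.
        X, y = torch.randn(128, 64), torch.randint(0, 3, (128,))
        return DataLoader(TensorDataset(X, y), batch_size=32)

    test_score = tune_trainable(model, loss_op, optim_op,
                                make_gen(), make_gen(), make_gen(),
                                nb_epochs=2, checkpoint_path='/tmp/ckpt.bin')
    print("Test accuracy:", test_score)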