in torchmoji/class_avg_finetuning.py [0:0]
def class_avg_chainthaw(model, nb_classes, loss_op, train, val, test, batch_size,
                        epoch_size, nb_epochs, checkpoint_weight_path,
                        f1_init_weight_path, patience=5,
                        initial_lr=0.001, next_lr=0.0001, verbose=True):
    """ Finetunes given model using chain-thaw and evaluates using F1.
        For a dataset with multiple classes, the model is trained once for
        each class, relabeling those classes into a binary classification task.
        The result is an average of all F1 scores for each class.
        When nb_classes is 2 or less, a single iteration is run.

    # Arguments:
        model: Model to be finetuned.
        nb_classes: Number of classes in the given dataset.
        loss_op: Loss function to be used during training.
        train: Training data, given as a tuple of (inputs, outputs)
        val: Validation data, given as a tuple of (inputs, outputs)
        test: Testing data, given as a tuple of (inputs, outputs)
        batch_size: Batch size.
        epoch_size: Number of samples in an epoch.
        nb_epochs: Number of epochs.
        checkpoint_weight_path: Filepath where weights will be checkpointed to
            during training. This file will be rewritten by the function.
        f1_init_weight_path: Filepath where weights will be saved to and
            reloaded from before training each class. This ensures that
            each class is trained independently. This file will be rewritten.
        patience: Patience value forwarded unchanged to train_by_chain_thaw
            (presumably the early-stopping patience; confirm against that
            function).
        initial_lr: Initial learning rate. Will only be used for the first
            training step (i.e. the softmax layer)
        next_lr: Learning rate for every subsequent step.
        verbose: Verbosity flag.

    # Returns:
        Averaged F1 score.
    """
    # Unpack args
    X_train, y_train = train
    X_val, y_val = val
    X_test, y_test = test

    total_f1 = 0
    nb_iter = nb_classes if nb_classes > 2 else 1

    # Snapshot the starting weights so that every iteration can restart from
    # the same state instead of continuing from the previous class's result.
    torch.save(model.state_dict(), f1_init_weight_path)

    for i in range(nb_iter):
        if verbose:
            print('Iteration number {}/{}'.format(i+1, nb_iter))

        # Reset to the initial weights before training on this class.
        model.load_state_dict(torch.load(f1_init_weight_path))

        y_train_new, y_val_new, y_test_new = prepare_labels(
            y_train, y_val, y_test, i, nb_classes)
        train_gen, X_val_resamp, y_val_resamp = prepare_generators(
            X_train, y_train_new, X_val, y_val_new, batch_size, epoch_size)

        if verbose:
            print("Training..")

        # Train using chain-thaw
        train_by_chain_thaw(model=model, train_gen=train_gen,
                            val_gen=[(X_val_resamp, y_val_resamp)],
                            loss_op=loss_op, patience=patience,
                            nb_epochs=nb_epochs,
                            checkpoint_path=checkpoint_weight_path,
                            initial_lr=initial_lr, next_lr=next_lr,
                            verbose=verbose)

        # Evaluate.
        # Inference must not track gradients: Tensor.numpy() raises a
        # RuntimeError on tensors that require grad, and building an autograd
        # graph over the full validation/test sets wastes memory. no_grad()
        # avoids building the graph; detach() guarantees the conversion works
        # even if the model re-enables grad internally.
        with torch.no_grad():
            y_pred_val = model(X_val).detach().cpu().numpy()
            y_pred_test = model(X_test).detach().cpu().numpy()

        f1_test, best_t = find_f1_threshold(y_val_new, y_pred_val,
                                            y_test_new, y_pred_test)

        if verbose:
            print('f1_test: {}'.format(f1_test))
            print('best_t: {}'.format(best_t))
        total_f1 += f1_test

    return total_f1 / nb_iter