# torchmoji/class_avg_finetuning.py
import uuid

import torch.nn as nn
import torch.optim as optim

# Package-level helpers; class_avg_tune_trainable and class_avg_chainthaw are
# defined elsewhere in this module.
from torchmoji.global_variables import FINETUNING_METHODS, WEIGHTS_DIR
from torchmoji.finetuning import freeze_layers

def class_avg_finetune(model, texts, labels, nb_classes, batch_size,
                       method, epoch_size=5000, nb_epochs=1000, embed_l2=1E-6,
                       verbose=True):
    """ Compiles and finetunes the given model.

    # Arguments:
        model: Model to be finetuned
        texts: List of three lists, containing tokenized inputs for training,
            validation and testing (in that order).
        labels: List of three lists, containing labels for training,
            validation and testing (in that order).
        nb_classes: Number of classes in the dataset.
        batch_size: Batch size.
        method: Finetuning method to be used. For available methods, see
            FINETUNING_METHODS in global_variables.py. Note that the model
            should be defined accordingly (see docstring for torchmoji_transfer())
        epoch_size: Number of samples in an epoch.
        nb_epochs: Number of epochs. Doesn't matter much as early stopping is used.
        embed_l2: L2 regularization for the embedding layer.
        verbose: Verbosity flag.

    # Returns:
        Model after finetuning,
        score after finetuning using the class average F1 metric.
    """
    if method not in FINETUNING_METHODS:
        raise ValueError('ERROR (class_avg_tune_trainable): '
                         'Invalid method parameter. '
                         'Available options: {}'.format(FINETUNING_METHODS))

    (X_train, y_train) = (texts[0], labels[0])
    (X_val, y_val) = (texts[1], labels[1])
    (X_test, y_test) = (texts[2], labels[2])

    checkpoint_path = '{}/torchmoji-checkpoint-{}.bin' \
                      .format(WEIGHTS_DIR, str(uuid.uuid4()))
    f1_init_path = '{}/torchmoji-f1-init-{}.bin' \
                   .format(WEIGHTS_DIR, str(uuid.uuid4()))
    if method in ['last', 'new']:
        lr = 0.001
    elif method in ['full', 'chain-thaw']:
        lr = 0.0001
    # Binary cross-entropy on raw logits; targets are expected as
    # one-hot / multi-hot float tensors.
    loss_op = nn.BCEWithLogitsLoss()

    # Freeze layers if using last
    if method == 'last':
        model = freeze_layers(model, unfrozen_keyword='output_layer')
    # Define optimizer, for chain-thaw we define it later (after freezing)
    if method == 'last':
        adam = optim.Adam((p for p in model.parameters() if p.requires_grad), lr=lr)
    elif method in ['full', 'new']:
        # Add L2 regularization on embeddings only
        special_params = [id(p) for p in model.embed.parameters()]
        base_params = [p for p in model.parameters()
                       if id(p) not in special_params and p.requires_grad]
        embed_parameters = [p for p in model.parameters()
                            if id(p) in special_params and p.requires_grad]
        adam = optim.Adam([
            {'params': base_params},
            {'params': embed_parameters, 'weight_decay': embed_l2},
        ], lr=lr)
    # Training
    if verbose:
        print('Method: {}'.format(method))
        print('Classes: {}'.format(nb_classes))
    if method == 'chain-thaw':
        result = class_avg_chainthaw(model, nb_classes=nb_classes,
                                     loss_op=loss_op,
                                     train=(X_train, y_train),
                                     val=(X_val, y_val),
                                     test=(X_test, y_test),
                                     batch_size=batch_size,
                                     epoch_size=epoch_size,
                                     nb_epochs=nb_epochs,
                                     checkpoint_weight_path=checkpoint_path,
                                     f1_init_weight_path=f1_init_path,
                                     verbose=verbose)
    else:
        result = class_avg_tune_trainable(model, nb_classes=nb_classes,
                                          loss_op=loss_op,
                                          optim_op=adam,
                                          train=(X_train, y_train),
                                          val=(X_val, y_val),
                                          test=(X_test, y_test),
                                          epoch_size=epoch_size,
                                          nb_epochs=nb_epochs,
                                          batch_size=batch_size,
                                          init_weight_path=f1_init_path,
                                          checkpoint_weight_path=checkpoint_path,
                                          verbose=verbose)

    return model, result
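

# Example usage (a minimal sketch, not part of the original module): it assumes a
# transfer model built with torchmoji_transfer() from torchmoji.model_def, pretrained
# weights at PRETRAINED_PATH from torchmoji.global_variables, and tokenized data
# already split into train/val/test lists (the X_* / y_* names are placeholders).
#
#     from torchmoji.model_def import torchmoji_transfer
#     from torchmoji.global_variables import PRETRAINED_PATH
#
#     nb_classes = 2
#     model = torchmoji_transfer(nb_classes, PRETRAINED_PATH)
#     model, f1 = class_avg_finetune(model,
#                                    texts=[X_train, X_val, X_test],
#                                    labels=[y_train, y_val, y_test],
#                                    nb_classes=nb_classes,
#                                    batch_size=32,
#                                    method='last')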