in torchmoji/finetuning.py [0:0]
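This excerpt assumes the module-level imports and constants of torchmoji/finetuning.py are in scope, and that get_data_loader, freeze_layers, chain_thaw and tune_trainable are helpers defined elsewhere in the same file. A sketch of the assumed imports:

import uuid

import torch.nn as nn
import torch.optim as optim

from torchmoji.global_variables import (FINETUNING_METHODS, FINETUNING_METRICS,
                                        WEIGHTS_DIR)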
def finetune(model, texts, labels, nb_classes, batch_size, method,
             metric='acc', epoch_size=5000, nb_epochs=1000, embed_l2=1E-6,
             verbose=1):
""" Compiles and finetunes the given pytorch model.
# Arguments:
model: Model to be finetuned
texts: List of three lists, containing tokenized inputs for training,
validation and testing (in that order).
labels: List of three lists, containing labels for training,
validation and testing (in that order).
nb_classes: Number of classes in the dataset.
batch_size: Batch size.
method: Finetuning method to be used. For available methods, see
FINETUNING_METHODS in global_variables.py.
metric: Evaluation metric to be used. For available metrics, see
FINETUNING_METRICS in global_variables.py.
epoch_size: Number of samples in an epoch.
nb_epochs: Number of epochs. Doesn't matter much as early stopping is used.
embed_l2: L2 regularization for the embedding layer.
verbose: Verbosity flag.
# Returns:
Model after finetuning,
score after finetuning using the provided metric.
"""
    if method not in FINETUNING_METHODS:
        raise ValueError('ERROR (finetune): Invalid method parameter. '
                         'Available options: {}'.format(FINETUNING_METHODS))
    if metric not in FINETUNING_METRICS:
        raise ValueError('ERROR (finetune): Invalid metric parameter. '
                         'Available options: {}'.format(FINETUNING_METRICS))
    # Data loaders for train/val/test; only the training loader uses the
    # extended batch sampler with a fixed epoch_size
    train_gen = get_data_loader(texts[0], labels[0], batch_size,
                                extended_batch_sampler=True, epoch_size=epoch_size)
    val_gen = get_data_loader(texts[1], labels[1], batch_size,
                              extended_batch_sampler=False)
    test_gen = get_data_loader(texts[2], labels[2], batch_size,
                               extended_batch_sampler=False)
    # Save checkpoints under a unique (uuid4) filename in WEIGHTS_DIR
    checkpoint_path = '{}/torchmoji-checkpoint-{}.bin' \
                      .format(WEIGHTS_DIR, str(uuid.uuid4()))
    # Higher learning rate when only new/last layers are trained ('last', 'new');
    # lower when pretrained weights are also updated ('full', 'chain-thaw')
    if method in ['last', 'new']:
        lr = 0.001
    elif method in ['full', 'chain-thaw']:
        lr = 0.0001
    loss_op = nn.BCEWithLogitsLoss() if nb_classes <= 2 \
        else nn.CrossEntropyLoss()
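    # Note: nn.BCEWithLogitsLoss expects float targets shaped like the logits,
    # while nn.CrossEntropyLoss expects long class indices; the data loaders are
    # assumed to yield labels in the matching dtype.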
    # For 'last', freeze everything except the output layer
    if method == 'last':
        model = freeze_layers(model, unfrozen_keyword='output_layer')
    # Define the optimizer; for chain-thaw it is created later (after freezing)
    if method == 'last':
        adam = optim.Adam((p for p in model.parameters() if p.requires_grad), lr=lr)
    elif method in ['full', 'new']:
        # Add L2 regularization on the embeddings only
        embed_params_id = [id(p) for p in model.embed.parameters()]
        output_layer_params_id = [id(p) for p in model.output_layer.parameters()]
        base_params = [p for p in model.parameters()
                       if id(p) not in embed_params_id
                       and id(p) not in output_layer_params_id
                       and p.requires_grad]
        embed_params = [p for p in model.parameters()
                        if id(p) in embed_params_id and p.requires_grad]
        output_layer_params = [p for p in model.parameters()
                               if id(p) in output_layer_params_id and p.requires_grad]
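        # Three parameter groups: the embeddings get weight_decay=embed_l2 (L2
        # applied to the embedding layer only), the output layer keeps lr=0.001
        # regardless of the method's base lr, and all remaining parameters use
        # the defaults passed to the optimizer below.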
        adam = optim.Adam([
            {'params': base_params},
            {'params': embed_params, 'weight_decay': embed_l2},
            {'params': output_layer_params, 'lr': 0.001},
        ], lr=lr)
    # Training
    if verbose:
        print('Method: {}'.format(method))
        print('Metric: {}'.format(metric))
        print('Classes: {}'.format(nb_classes))
    if method == 'chain-thaw':
        # chain-thaw handles its own optimizer internally (see the comment above)
        result = chain_thaw(model, train_gen, val_gen, test_gen, nb_epochs,
                            checkpoint_path, loss_op, embed_l2=embed_l2,
                            evaluate=metric, verbose=verbose)
    else:
        result = tune_trainable(model, loss_op, adam, train_gen, val_gen, test_gen,
                                nb_epochs, checkpoint_path,
                                evaluate=metric, verbose=verbose)
    return model, result
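For context, a minimal usage sketch in the style of the repository's example scripts; DATASET_PATH and nb_classes are placeholders, and load_benchmark (from this same module) is assumed to return the texts/labels splits and batch size in the format finetune expects:

import json

from torchmoji.finetuning import load_benchmark, finetune
from torchmoji.model_def import torchmoji_transfer
from torchmoji.global_variables import PRETRAINED_PATH, VOCAB_PATH

DATASET_PATH = 'data/SS-Twitter/raw.pickle'  # placeholder: any benchmark-format pickle
nb_classes = 2

with open(VOCAB_PATH, 'r') as f:
    vocab = json.load(f)

# Tokenize and split the dataset with the pretrained vocabulary
data = load_benchmark(DATASET_PATH, vocab)

# Pretrained torchMoji encoder with a fresh output layer for nb_classes
model = torchmoji_transfer(nb_classes, PRETRAINED_PATH)

model, acc = finetune(model, data['texts'], data['labels'], nb_classes,
                      data['batch_size'], method='last')
print('Accuracy: {}'.format(acc))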