in torchmoji/finetuning.py [0:0]
def train_by_chain_thaw(model, train_gen, val_gen, loss_op, patience, nb_epochs, checkpoint_path,
                        initial_lr=0.001, next_lr=0.0001, embed_l2=1E-6, verbose=1):
    """ Finetunes model using the chain-thaw method.

    This is done as follows:
    1) Freeze every layer except the last layer (the last direct child of
       `model` that owns parameters, presumably the output layer) and train it.
    2) Freeze every layer except the first layer and train it.
    3) Freeze every layer except the second etc., until the second last layer.
    4) Unfreeze all layers and train entire model.

    After every step the state dict stored at `checkpoint_path` is loaded back
    into `model` before the next step starts.

    # Arguments:
        model: Model to be trained. Its direct children that own at least one
            parameter are the "layers" thawed one at a time. Must expose an
            `embed` submodule whose parameters get their own weight decay.
        train_gen: Training sample generator, forwarded to `fit_model`.
        val_gen: Validation sample generator, forwarded to `fit_model`.
        loss_op: Loss function, forwarded to `fit_model`.
        patience: Early-stopping patience, forwarded to `fit_model`.
        nb_epochs: Number of epochs passed to `fit_model` for each step.
        checkpoint_path: Path handed to `fit_model`; the state dict at this
            path is reloaded into `model` after each step (assumes `fit_model`
            writes the best weights there -- confirm against fit_model).
        initial_lr: Learning rate for the first training step only (i.e. the
            last layer).
        next_lr: Learning rate for every subsequent step.
        embed_l2: Adam `weight_decay` applied to the trainable parameters of
            `model.embed`; all other parameters use Adam's default (0).
        verbose: Verbosity flag. Truthy prints which layer is being finetuned;
            >= 2 also reports each checkpoint reload.

    # Raises:
        ValueError: If `model` has no direct child module with parameters.
    """
    # Trainable "layers" are the direct children that own at least one parameter.
    layers = [m for m in model.children() if next(m.parameters(), None) is not None]
    if not layers:
        raise ValueError('train_by_chain_thaw: model has no child module with parameters')

    # Bring last layer to front so it is trained first
    layers.insert(0, layers.pop())
    # Add None to the end to signify finetuning all layers at once
    layers.append(None)

    # Embedding parameters go into a separate optimizer group with their own
    # weight decay. load_state_dict copies in place, so these ids stay valid
    # for every step and the set can be built once.
    embed_param_ids = {id(p) for p in model.embed.parameters()}

    # Finetune each layer one by one and finetune all of them at once
    # at the end
    for step, layer in enumerate(layers):
        lr = initial_lr if step == 0 else next_lr

        # Freeze all except current layer (the final None step unfreezes all)
        for _layer in layers:
            if _layer is not None:
                trainable = layer is None or _layer is layer
                change_trainable(_layer, trainable=trainable, verbose=False)

        # Verify we froze the right layers (nothing to check when all are thawed)
        if layer is not None:
            for child in model.children():
                assert all(p.requires_grad == (child is layer) for p in child.parameters())

        if verbose:
            if layer is None:
                print('Finetuning all layers')
            else:
                print('Finetuning {}'.format(layer))

        # Only parameters left trainable by change_trainable are optimized.
        base_params = [p for p in model.parameters()
                       if p.requires_grad and id(p) not in embed_param_ids]
        embed_parameters = [p for p in model.parameters()
                            if p.requires_grad and id(p) in embed_param_ids]
        adam = optim.Adam([
            {'params': base_params},
            {'params': embed_parameters, 'weight_decay': embed_l2},
        ], lr=lr)

        fit_model(model, loss_op, adam, train_gen, val_gen, nb_epochs,
                  checkpoint_path, patience)

        # Reload the best weights found to avoid overfitting
        # Wait a bit to allow proper closing of weights file
        sleep(1)
        model.load_state_dict(torch.load(checkpoint_path))
        if verbose >= 2:
            print("Loaded weights from {}".format(checkpoint_path))