def train_by_chain_thaw()

in torchmoji/finetuning.py [0:0]
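
This excerpt relies on module-level imports and on helpers defined elsewhere in
torchmoji/finetuning.py; a sketch of what it assumes (not the file's exact header):

    from time import sleep

    import torch
    from torch import optim

    # change_trainable() and fit_model() are assumed to be helper functions
    # available in the same module.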


def train_by_chain_thaw(model, train_gen, val_gen, loss_op, patience, nb_epochs, checkpoint_path,
                        initial_lr=0.001, next_lr=0.0001, embed_l2=1E-6, verbose=1):
    """ Finetunes model using the chain-thaw method.

    This is done as follows:
    1) Freeze every layer except the last (output_layer) layer and train it.
    2) Freeze every layer except the first layer and train it.
    3) Freeze every layer except the second etc., until the second last layer.
    4) Unfreeze all layers and train entire model.

    # Arguments:
        model: Model to be trained.
        train_gen: Training sample generator.
        val_gen: Validation sample generator.
        loss_op: Loss function to be used.
        patience: Number of epochs with no improvement on the validation set
            before early stopping (passed on to fit_model).
        nb_epochs: Number of epochs.
        checkpoint_path: Where weight checkpoints should be saved.
        initial_lr: Initial learning rate. Will only be used for the first
            training step (i.e. the output layer).
        next_lr: Learning rate for every subsequent step.
        embed_l2: L2 regularization (weight decay) applied to the embedding
            layer only.
        verbose: Verbosity flag.
    """
    # Get the layers that have at least one parameter (i.e. are trainable)
    layers = [m for m in model.children() if len(list(m.parameters())) != 0]

    # Bring last layer to front
    layers.insert(0, layers.pop())

    # Add None to the end to signify finetuning all layers
    layers.append(None)

    lr = None
    # Finetune each layer one by one and finetune all of them at once
    # at the end
    for layer in layers:
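        # The first pass (the output layer) trains with initial_lr; every
        # subsequent pass, including the final all-layers pass, uses next_lr.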
        if lr is None:
            lr = initial_lr
        elif lr == initial_lr:
            lr = next_lr

        # Freeze all except current layer
        for _layer in layers:
            if _layer is not None:
                trainable = _layer is layer or layer is None
                change_trainable(_layer, trainable=trainable, verbose=False)

        # Verify we froze the right layers
        for _layer in model.children():
            assert layer is None or all(p.requires_grad == (_layer is layer)
                                        for p in _layer.parameters())

        if verbose:
            if layer is None:
                print('Finetuning all layers')
            else:
                print('Finetuning {}'.format(layer))

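        # Weight decay (embed_l2) is applied only to the embedding parameters;
        # all other trainable parameters are optimized without regularization.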
        special_params = [id(p) for p in model.embed.parameters()]
        base_params = [p for p in model.parameters() if id(p) not in special_params and p.requires_grad]
        embed_parameters = [p for p in model.parameters() if id(p) in special_params and p.requires_grad]
        adam = optim.Adam([
            {'params': base_params},
            {'params': embed_parameters, 'weight_decay': embed_l2},
            ], lr=lr)

        fit_model(model, loss_op, adam, train_gen, val_gen, nb_epochs,
                  checkpoint_path, patience)

        # Reload the best weights found to avoid overfitting
        # Wait a bit to allow proper closing of weights file
        sleep(1)
        model.load_state_dict(torch.load(checkpoint_path))
        if verbose >= 2:
            print("Loaded weights from {}".format(checkpoint_path))