def finetune()

in torchmoji/finetuning.py


import uuid

import torch.nn as nn
import torch.optim as optim

# get_data_loader, freeze_layers, chain_thaw and tune_trainable are helper
# functions defined elsewhere in this module.
from torchmoji.global_variables import (FINETUNING_METHODS,
                                        FINETUNING_METRICS, WEIGHTS_DIR)


def finetune(model, texts, labels, nb_classes, batch_size, method,
             metric='acc', epoch_size=5000, nb_epochs=1000, embed_l2=1E-6,
             verbose=1):
    """ Compiles and finetunes the given pytorch model.

    # Arguments:
        model: Model to be finetuned
        texts: List of three lists, containing tokenized inputs for training,
            validation and testing (in that order).
        labels: List of three lists, containing labels for training,
            validation and testing (in that order).
        nb_classes: Number of classes in the dataset.
        batch_size: Batch size.
        method: Finetuning method to be used. For available methods, see
            FINETUNING_METHODS in global_variables.py.
        metric: Evaluation metric to be used. For available metrics, see
            FINETUNING_METRICS in global_variables.py.
        epoch_size: Number of samples in an epoch.
        nb_epochs: Number of epochs. The exact value matters little,
            as early stopping is used.
        embed_l2: L2 regularization for the embedding layer.
        verbose: Verbosity flag.

    # Returns:
        Model after finetuning,
        score after finetuning using the provided metric.
    """

    if method not in FINETUNING_METHODS:
        raise ValueError('ERROR (finetune): Invalid method parameter. '
                         'Available options: {}'.format(FINETUNING_METHODS))
    if metric not in FINETUNING_METRICS:
        raise ValueError('ERROR (finetune): Invalid metric parameter. '
                         'Available options: {}'.format(FINETUNING_METRICS))
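    # (In torchmoji/global_variables.py these are FINETUNING_METHODS =
    # ['last', 'full', 'new', 'chain-thaw'] and FINETUNING_METRICS =
    # ['acc', 'weighted'].)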

    train_gen = get_data_loader(texts[0], labels[0], batch_size,
                                extended_batch_sampler=True, epoch_size=epoch_size)
    val_gen = get_data_loader(texts[1], labels[1], batch_size,
                              extended_batch_sampler=False)
    test_gen = get_data_loader(texts[2], labels[2], batch_size,
                               extended_batch_sampler=False)
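    # The extended batch sampler makes a training "epoch" epoch_size samples
    # drawn from the training split; the validation and test loaders simply
    # iterate over their splits once.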

    checkpoint_path = '{}/torchmoji-checkpoint-{}.bin'.format(
        WEIGHTS_DIR, str(uuid.uuid4()))
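    # The random UUID keeps concurrent runs from overwriting each other's
    # checkpoints; the training routines save the best weights seen during
    # early stopping to this path.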

    if method in ['last', 'new']:
        lr = 0.001
    elif method in ['full', 'chain-thaw']:
        lr = 0.0001
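    # Methods that mostly train the (new) output layer can take a higher
    # learning rate; methods that also update pretrained weights use a
    # lower one so those weights are not destroyed.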

    loss_op = (nn.BCEWithLogitsLoss() if nb_classes <= 2
               else nn.CrossEntropyLoss())
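    # Binary tasks use a single-logit output with BCEWithLogitsLoss (the
    # sigmoid is applied internally; targets are floats), while multiclass
    # tasks use CrossEntropyLoss over nb_classes logits (integer targets).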

    # Freeze layers if using last
    if method == 'last':
        model = freeze_layers(model, unfrozen_keyword='output_layer')
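        # freeze_layers (defined in this module) disables gradients for all
        # parameters except those under modules matching the keyword.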

    # Define the optimizer; for chain-thaw it is created later (after freezing)
    if method == 'last':
        adam = optim.Adam((p for p in model.parameters() if p.requires_grad), lr=lr)
    elif method in ['full', 'new']:
        # Add L2 regularization on the embeddings only
        embed_params_id = [id(p) for p in model.embed.parameters()]
        output_layer_params_id = [id(p) for p in model.output_layer.parameters()]
        base_params = [p for p in model.parameters()
                       if id(p) not in embed_params_id and id(p) not in output_layer_params_id and p.requires_grad]
        embed_params = [p for p in model.parameters() if id(p) in embed_params_id and p.requires_grad]
        output_layer_params = [p for p in model.parameters() if id(p) in output_layer_params_id and p.requires_grad]
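        # Param groups fall back to lr=lr unless overridden: embeddings get
        # L2 weight decay, and the output layer trains at the higher 0.001.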
        adam = optim.Adam([
            {'params': base_params},
            {'params': embed_params, 'weight_decay': embed_l2},
            {'params': output_layer_params, 'lr': 0.001},
            ], lr=lr)

    # Training
    if verbose:
        print('Method:  {}'.format(method))
        print('Metric:  {}'.format(metric))
        print('Classes: {}'.format(nb_classes))

    if method == 'chain-thaw':
        result = chain_thaw(model, train_gen, val_gen, test_gen, nb_epochs,
                            checkpoint_path, loss_op, embed_l2=embed_l2,
                            evaluate=metric, verbose=verbose)
    else:
        result = tune_trainable(model, loss_op, adam, train_gen, val_gen,
                                test_gen, nb_epochs, checkpoint_path,
                                evaluate=metric, verbose=verbose)
    return model, result
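
Typical usage, sketched after the repository's example scripts (the dataset
path and nb_classes are illustrative; load_benchmark pads and splits a
benchmark dataset and suggests a batch size, and torchmoji_transfer builds
the pretrained model with a fresh nb_classes-way output layer):

import json

from torchmoji.model_def import torchmoji_transfer
from torchmoji.global_variables import PRETRAINED_PATH, VOCAB_PATH
from torchmoji.finetuning import load_benchmark, finetune

nb_classes = 2  # binary sentiment task

with open(VOCAB_PATH, 'r') as f:
    vocab = json.load(f)

data = load_benchmark('data/SS-Youtube/raw.pickle', vocab)

model = torchmoji_transfer(nb_classes, PRETRAINED_PATH)
model, acc = finetune(model, data['texts'], data['labels'], nb_classes,
                      data['batch_size'], method='last')
print('Test accuracy: {}'.format(acc))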