def class_avg_finetune()

in torchmoji/class_avg_finetuning.py
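The function body below relies on module-level imports from class_avg_finetuning.py. A minimal reconstruction, inferred from the names used in the body (the repository file may group or extend them differently), would be:

import uuid

import torch.nn as nn
import torch.optim as optim

from torchmoji.global_variables import FINETUNING_METHODS, WEIGHTS_DIR
from torchmoji.finetuning import freeze_layers

The helpers class_avg_chainthaw and class_avg_tune_trainable are defined later in the same module.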


def class_avg_finetune(model, texts, labels, nb_classes, batch_size,
                       method, epoch_size=5000, nb_epochs=1000, embed_l2=1E-6,
                       verbose=True):
    """ Compiles and finetunes the given model.

    # Arguments:
        model: Model to be finetuned
        texts: List of three lists, containing tokenized inputs for training,
            validation and testing (in that order).
        labels: List of three lists, containing labels for training,
            validation and testing (in that order).
        nb_classes: Number of classes in the dataset.
        batch_size: Batch size.
        method: Finetuning method to be used. For available methods, see
            FINETUNING_METHODS in global_variables.py. Note that the model
            should be defined accordingly (see the docstring for torchmoji_transfer()).
        epoch_size: Number of samples in an epoch.
        nb_epochs: Number of epochs. Its exact value matters little, since early stopping is used.
        embed_l2: L2 regularization for the embedding layer.
        verbose: Verbosity flag.

    # Returns:
        Model after finetuning,
        score after finetuning using the class average F1 metric.
    """

    if method not in FINETUNING_METHODS:
        raise ValueError('ERROR (class_avg_finetune): '
                         'Invalid method parameter. '
                         'Available options: {}'.format(FINETUNING_METHODS))

    (X_train, y_train) = (texts[0], labels[0])
    (X_val, y_val) = (texts[1], labels[1])
    (X_test, y_test) = (texts[2], labels[2])

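    # Fresh, uniquely named weight files under WEIGHTS_DIR: one for training
    # checkpoints, one for the initial weights used by the per-class finetuning runs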
    checkpoint_path = '{}/torchmoji-checkpoint-{}.bin' \
                      .format(WEIGHTS_DIR, str(uuid.uuid4()))

    f1_init_path = '{}/torchmoji-f1-init-{}.bin' \
                   .format(WEIGHTS_DIR, str(uuid.uuid4()))

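    # Higher learning rate when only new/last layers are trained,
    # lower when the pretrained weights are also updated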
    if method in ['last', 'new']:
        lr = 0.001
    elif method in ['full', 'chain-thaw']:
        lr = 0.0001

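    # Sigmoid + binary cross-entropy: each output is scored independently,
    # consistent with the class-average F1 metric this function returns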
    loss_op = nn.BCEWithLogitsLoss()

    # Freeze layers if using last
    if method == 'last':
        model = freeze_layers(model, unfrozen_keyword='output_layer')

    # Define the optimizer; for chain-thaw it is defined later (after freezing)
    if method == 'last':
        adam = optim.Adam((p for p in model.parameters() if p.requires_grad), lr=lr)
    elif method in ['full', 'new']:
        # Add L2 regularization on the embeddings only
        special_params = [id(p) for p in model.embed.parameters()]
        base_params = [p for p in model.parameters() if id(p) not in special_params and p.requires_grad]
        embed_parameters = [p for p in model.parameters() if id(p) in special_params and p.requires_grad]
        adam = optim.Adam([
            {'params': base_params},
            {'params': embed_parameters, 'weight_decay': embed_l2},
            ], lr=lr)

    # Training
    if verbose:
        print('Method:  {}'.format(method))
        print('Classes: {}'.format(nb_classes))

    if method == 'chain-thaw':
        result = class_avg_chainthaw(model, nb_classes=nb_classes,
                                     loss_op=loss_op,
                                     train=(X_train, y_train),
                                     val=(X_val, y_val),
                                     test=(X_test, y_test),
                                     batch_size=batch_size,
                                     epoch_size=epoch_size,
                                     nb_epochs=nb_epochs,
                                     checkpoint_weight_path=checkpoint_path,
                                     f1_init_weight_path=f1_init_path,
                                     verbose=verbose)
    else:
        result = class_avg_tune_trainable(model, nb_classes=nb_classes,
                                          loss_op=loss_op,
                                          optim_op=adam,
                                          train=(X_train, y_train),
                                          val=(X_val, y_val),
                                          test=(X_test, y_test),
                                          epoch_size=epoch_size,
                                          nb_epochs=nb_epochs,
                                          batch_size=batch_size,
                                          init_weight_path=f1_init_path,
                                          checkpoint_weight_path=checkpoint_path,
                                          verbose=verbose)
    return model, result
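
For context, a call typically follows the pattern of the repository's finetuning examples: tokenize the three data splits, build a transfer model with torchmoji_transfer() (see its docstring, as noted above), and pass everything in. The sketch below is illustrative and makes a few assumptions: VOCAB_PATH and PRETRAINED_PATH from torchmoji.global_variables point at the bundled vocabulary and pretrained weights, labels are integer class indices, and train_sents/val_sents/test_sents with their matching *_labels are placeholders for data loaded elsewhere.

import json

from torchmoji.sentence_tokenizer import SentenceTokenizer
from torchmoji.model_def import torchmoji_transfer
from torchmoji.global_variables import PRETRAINED_PATH, VOCAB_PATH
from torchmoji.class_avg_finetuning import class_avg_finetune

nb_classes = 2
maxlen = 80  # fixed token length used by the tokenizer

# Placeholder splits: raw sentences and integer labels loaded elsewhere
splits = [(train_sents, train_labels),
          (val_sents, val_labels),
          (test_sents, test_labels)]

# Tokenize each split with the pretrained vocabulary
with open(VOCAB_PATH, 'r') as f:
    vocab = json.load(f)
st = SentenceTokenizer(vocab, maxlen)
texts = [st.tokenize_sentences(sents)[0] for sents, _ in splits]
labels = [labs for _, labs in splits]

# Pretrained torchMoji encoder with a fresh output layer for nb_classes
model = torchmoji_transfer(nb_classes, PRETRAINED_PATH)

# Finetune only the new output layer and report the class-average F1 score
model, f1 = class_avg_finetune(model, texts, labels, nb_classes,
                               batch_size=32, method='last')
print('Class-average F1 after finetuning:', f1)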