def trainLF()

in training/NeuralPatternMatchingTraining.py

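Mini-batch training loop for the pattern-matching model: it minimises cross-entropy plus an L1 penalty on the final linear layer (model.lin) to encourage sparse activations, steps the LR scheduler once per epoch, optionally evaluates on a validation set every 10 epochs, and returns the list of per-epoch average training losses.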

import torch


def trainLF(train, model, l1_coeff, optimizer, scheduler, max_epochs,
            validation=None, device='cpu'):

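    # CrossEntropyLoss expects raw (unnormalised) logits and integer class targets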
    loss_function = torch.nn.CrossEntropyLoss()

    model.train()
    print('Training...')

    epochs_loss = []
    for epoch in range(1, max_epochs + 1):  # normally you would not train for e.g. 300 epochs; it is fine here on toy data

        epoch_losses = []

        b_idx = 0
        for inputs, annotations, mask_input, targets, _ in train:

            b_idx += 1

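            # Each of the first four fields is a (tensor, metadata) pair;
            # keep the tensors, plus the `batch` vector paired with `inputs`,
            # which the model consumes alongside the other three tensors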
            inputs, batch = inputs
            annotations, _ = annotations
            mask_input, _ = mask_input
            targets, _ = targets

            inputs = inputs.to(device)
            batch = batch.to(device)
            annotations = annotations.long().to(device)
            mask_input = mask_input.to(device)
            targets = targets.to(device)

            # Reset the gradients accumulated from the previous mini-batch
            optimizer.zero_grad()

            # Run the forward pass.
            inputs = (inputs, annotations, mask_input, batch)
            out, _ = model(*inputs)

            # Compute the loss; the backward pass and optimizer.step() follow below
            loss = loss_function(out, targets.long())

            # L1 regularization on linear model to get sparse activations
            l1_norm = torch.norm(model.lin.weight, p=1)
            loss += l1_norm * l1_coeff

            loss.backward()
            optimizer.step()

            epoch_losses.append(float(loss))

            # Drop references to the batch tensors so GPU memory can be freed
            inputs = None
            batch = None
            annotations = None
            mask_input = None
            targets = None
            out = None
            loss = None

        epoch_avg_loss = sum(epoch_losses) / len(epoch_losses)
        epochs_loss.append(epoch_avg_loss)

        if scheduler is not None:
            scheduler.step()

        if validation is not None and epoch % 10 == 0:
            # computeLossLF (defined elsewhere in this module) defaults to
            # reduction='none', so request the mean explicitly
            valid_loss, _, _, _ = computeLossLF(validation, model, reduction='mean')
            print(f'Epoch {epoch}, train avg loss is {epoch_avg_loss}, valid avg loss is {valid_loss}')
        elif epoch == 1 or epoch % 10 == 0:
            print(f'Epoch {epoch}, train avg loss is {epoch_avg_loss}')

    return epochs_loss
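

A minimal usage sketch. All names below (ToyModel, train_loader, the tensor shapes) are hypothetical; the toy model and loader only mimic the two interfaces trainLF relies on: a model that exposes a `lin` linear layer and returns an `(out, extra)` pair, and a loader that yields five fields per batch, the first four being (tensor, metadata) pairs.

import torch


class ToyModel(torch.nn.Module):
    # Only the contract trainLF needs: a `lin` layer (used for the L1
    # penalty) and forward(inputs, annotations, mask_input, batch)
    # returning a (logits, extra) pair
    def __init__(self, in_dim=16, n_classes=3):
        super().__init__()
        self.lin = torch.nn.Linear(in_dim, n_classes)

    def forward(self, inputs, annotations, mask_input, batch):
        return self.lin(inputs * mask_input), None


# One fake mini-batch, shaped exactly the way the loop above unpacks it
x = torch.randn(8, 16)
batch_vec = torch.zeros(8, dtype=torch.long)   # e.g. a graph-batch assignment vector
ann = torch.zeros(8)
mask = torch.ones(8, 16)
y = torch.randint(0, 3, (8,))
train_loader = [((x, batch_vec), (ann, None), (mask, None), (y, None), None)]

model = ToyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

losses = trainLF(train_loader, model, l1_coeff=1e-4, optimizer=optimizer,
                 scheduler=scheduler, max_epochs=30)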