def train_ft()

in src/run_paraphrase.py [0:0]
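Fine-tuning loop for the paraphrase classifier: it wraps the pretrained RoBERTa
encoder in a BertScoreModel over its upper layers (17-24), trains with binary
cross-entropy, and prints train/dev loss and accuracy after each epoch. The
helpers BertScoreModel, torchify_dataset, and collater are defined elsewhere in
this file; torch and torch.nn (as nn) are assumed to be imported at module level.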


def train_ft(roberta, train_data, dev_data, logit_scale, batch_size, learning_rate, num_epochs):
    """Fine-tune a BertScoreModel on paraphrase data, reporting per-epoch metrics."""
    # Build the classifier over the upper RoBERTa layers (17-24), with
    # logit_scale passed through to the model.
    model = BertScoreModel(roberta, list(range(17, 25)), logit_scale)
    model.to(device=roberta.device)
    train_loader = torch.utils.data.DataLoader(
            torchify_dataset(train_data, roberta.device), 
            batch_size=batch_size, collate_fn=collater, shuffle=True)
    dev_loader = torch.utils.data.DataLoader(
            torchify_dataset(dev_data, roberta.device), 
            batch_size=batch_size, collate_fn=collater, shuffle=False)
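    # BCE on the raw logits; the Adam betas/eps follow the standard RoBERTa
    # fine-tuning recipe.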
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, 
                                 betas=(0.9, 0.98), eps=1e-6)
    for epoch in range(num_epochs):
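        # Switch to train mode and reset the per-epoch running metrics.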
        model.train()
        train_loss = 0.0
        train_num_correct = 0
        dev_loss = 0.0
        dev_num_correct = 0
        # Train
        for ex in train_loader:
            x1, x2, len1, len2, y = ex
            optimizer.zero_grad()
            out = model(x1, x2, len1, len2)
            loss = criterion(out, y)
            train_loss += loss.item()
            # A prediction is correct when the logit sign matches the label mapped to ±1.
            train_num_correct += (out * (2 * y - 1) > 0).sum().item()
            loss.backward()
            optimizer.step()
        # Evaluate on dev
        model.eval()
        with torch.no_grad():
            for ex in dev_loader:
                x1, x2, len1, len2, y = ex
                out = model(x1, x2, len1, len2)
                loss = criterion(out, y)
                dev_loss += loss.item()
                dev_num_correct += (out * (2 * y - 1) > 0).sum().item()
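        # Losses are summed over batches; accuracies are computed over the full splits.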
        train_acc = 100 * train_num_correct / len(train_data)
        dev_acc = 100 * dev_num_correct / len(dev_data)
        print(f'Epoch {epoch}: train loss={train_loss:.5f}, acc={train_acc:.2f}%; dev loss={dev_loss:.5f}, acc={dev_acc:.2f}%')
    return model
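
A minimal driver sketch for calling train_ft. load_roberta and load_split are
hypothetical placeholders (they do not appear in this file), and the
hyperparameter values are illustrative only:

roberta = load_roberta()            # hypothetical loader; must expose .device
train_data = load_split('train')    # hypothetical; format must match torchify_dataset
dev_data = load_split('dev')
model = train_ft(roberta, train_data, dev_data, logit_scale=1.0,
                 batch_size=32, learning_rate=1e-5, num_epochs=3)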