def main()

in model/disambiguate/train_model.py

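main() below relies on a handful of module-level imports. A plausible set is sketched here; the project-local import paths (Dataloader, Disambiguator, evaluate_model) are assumptions rather than lines taken from the file:

import os

import torch.nn as nn
from torch.optim import AdamW  # the repo may instead use transformers' AdamW
from transformers import GPT2Tokenizer

from dataloader import Dataloader          # module path assumed
from disambiguator import Disambiguator    # module path assumed
from evaluate import evaluate_model        # module path assumed
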

def main(args):
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tokenizer.padding_side = "left"
    # Define PAD Token = EOS Token = 50256
    tokenizer.pad_token = tokenizer.eos_token
    num_added_tokens = tokenizer.add_special_tokens(
        {"additional_special_tokens": ["<USER>", "<SYS>"]}
    )
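    # NB: adding special tokens grows the vocabulary, so the model is expected
    # to resize its token embeddings to len(tokenizer) (presumably handled by
    # Disambiguator, which receives the tokenizer).
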
    # Data loaders for each split.
    train_loader = Dataloader(tokenizer, args["train_file"], args)
    val_loader = Dataloader(tokenizer, args["dev_file"], args)
    devtest_loader = Dataloader(tokenizer, args["devtest_file"], args)
    teststd_loader = Dataloader(
        tokenizer, args["teststd_file"], args, hidden_labels=True
    )
    model = Disambiguator(tokenizer, args)

    model.train()
    # loss function.
    criterion = nn.CrossEntropyLoss()
    # Prepare the optimizer (total_steps below could feed a linear warmup/decay
    # schedule, though none is created here).
    optimizer = AdamW(
        model.parameters(), lr=args["learning_rate"], eps=args["adam_epsilon"]
    )

    total_steps = (
        int(train_loader.num_instances / args["batch_size"] * args["num_epochs"]) + 1
    )
    num_iters_epoch = train_loader.num_instances // args["batch_size"]
    num_iters_epoch_float = train_loader.num_instances / args["batch_size"]
    next_eval_iter = 0
    num_iters = 0
    best_performance = {"dev": 0.}
    total_loss = None
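    # Training loop: sample random batches until num_epochs epochs have elapsed.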
    while True:
        epoch = num_iters / num_iters_epoch_float

        batch = train_loader.get_random_batch(args["batch_size"])
        output = model(batch)
        loss = criterion(output, batch["gt_label"])

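        # Smooth the logged loss with an exponential moving average.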
        if total_loss:
            total_loss = 0.95 * total_loss + 0.05 * loss.item()
        else:
            total_loss = loss.item()

        if num_iters % 100 == 0:
            print(f"[Ep: {epoch:.2f}][Loss: {total_loss:.2f}]")

        loss.backward()
        optimizer.step()
        model.zero_grad()

        # Evaluate the model once per epoch.
        if num_iters == next_eval_iter:
            model.eval()
            print("Evaluating ..")
            # Get dev results.
            accuracy = evaluate_model(model, val_loader, args["batch_size"] * 5)
            print(f"Accuracy [dev]: {accuracy}")

            # Evaluate on devtest and teststd if better dev performance.
            if best_performance["dev"] < accuracy:
                best_performance["dev"] = accuracy
                best_performance["iter_id"] = num_iters
                best_performance["epoch"] = epoch

                # Get devtest results.
                if args["result_save_path"]:
                    save_path = os.path.join(
                        args["result_save_path"], f"results_devtest_{num_iters}.json"
                    )
                else:
                    save_path = None
                accuracy = evaluate_model(
                    model, devtest_loader, args["batch_size"] * 5, save_path
                )
                best_performance["devtest"] = accuracy
                print(f"Accuracy [devtest]: {accuracy}")

                # Get teststd predictions.
                if args["result_save_path"]:
                    save_path = os.path.join(
                        args["result_save_path"], f"results_teststd_{num_iters}.json"
                    )
                else:
                    save_path = None
                accuracy = evaluate_model(
                    model, teststd_loader, args["batch_size"] * 5, save_path, hidden_test=True
                )
                best_performance["teststd"] = accuracy
                print(f"Accuracy [teststd]: {accuracy}")
                print(f"Current best performance: {best_performance}")
            model.train()

        num_iters += 1
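        # Schedule the next evaluation for the first iteration of the next epoch.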
        next_eval_iter = int(int(epoch + 1) * num_iters_epoch_float)
        if epoch > args["num_epochs"]:
            break
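
main() only needs a plain dict holding the keys read in the body above. A minimal driver sketch (argument names mirror the dict lookups; every default value is an assumption) could look like:

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Train the disambiguation model")
    parser.add_argument("--train_file", required=True)
    parser.add_argument("--dev_file", required=True)
    parser.add_argument("--devtest_file", required=True)
    parser.add_argument("--teststd_file", required=True)
    parser.add_argument("--result_save_path", default=None)
    parser.add_argument("--batch_size", type=int, default=16)  # default assumed
    parser.add_argument("--num_epochs", type=int, default=10)  # default assumed
    parser.add_argument("--learning_rate", type=float, default=5e-5)  # default assumed
    parser.add_argument("--adam_epsilon", type=float, default=1e-8)  # default assumed
    main(vars(parser.parse_args()))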