def train()

in scripts/adapet/ADAPET/src/train.py


def train(config):
    '''
    Trains the ADAPET model specified by config and checkpoints the best model by
    dev-set accuracy.

    :param config: experiment configuration (model, data and optimization settings)
    :return: None; the best checkpoint and its dev logits are written to config.exp_dir
    '''

    tokenizer = AutoTokenizer.from_pretrained(config.pretrained_weight)
    batcher = Batcher(config, tokenizer, config.dataset)
    dataset_reader = batcher.get_dataset_reader()
    model = adapet(config, tokenizer, dataset_reader).to(device)

    ### Create Optimizer
    # Exclude biases and LayerNorm weights from weight decay
    no_decay_param = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.model.named_parameters() if not any(nd in n for nd in no_decay_param)],
         'weight_decay': config.weight_decay,
         'lr': config.lr},
        {'params': [p for n, p in model.model.named_parameters() if any(nd in n for nd in no_decay_param)],
         'weight_decay': 0.0,
         'lr': config.lr},
    ]
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, eps=1e-8)

    # Altered from the original (best_dev_acc = 0): start at -inf so the first
    # dev evaluation is always checkpointed
    best_dev_acc = -float('inf')
    train_iter = batcher.get_train_batch()
    dict_val_store = None

    # config.num_batches counts effective batches; each one is formed from
    # grad_accumulation_factor micro-batches
    tot_num_batches = config.num_batches * config.grad_accumulation_factor
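    # e.g. num_batches=1000 with grad_accumulation_factor=4 (illustrative values) means
    # 4000 forward/backward passes but only 1000 optimizer updates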

    # Warmup steps and total steps are based on batches, not epochs
    num_warmup_steps = config.num_batches * config.warmup_ratio
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, config.num_batches)
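    # e.g. warmup_ratio=0.06 with num_batches=1000 (illustrative values) warms the learning
    # rate up over the first 60 optimizer updates, then decays it linearly to 0 by update 1000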

    for i in range(tot_num_batches):
        # Index of the current effective batch (i.e. the current optimizer update)
        batch_idx = (i // config.grad_accumulation_factor)

        model.train()
        sup_batch = next(train_iter)
        loss, dict_val_update = model(sup_batch)
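        # Scale the loss so gradients summed over grad_accumulation_factor micro-batches
        # match a single effective-batch update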
        loss = loss / config.grad_accumulation_factor
        loss.backward()

        if (i+1) % config.grad_accumulation_factor == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip_norm)
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()

        dict_val_store = update_dict_val_store(dict_val_store, dict_val_update, config.grad_accumulation_factor)
        print("Finished %d batches" % batch_idx, end='\r')

        if (batch_idx + 1) % config.eval_every == 0 and i % config.grad_accumulation_factor == 0:
            dict_avg_val = get_avg_dict_val_store(dict_val_store, config.eval_every)
            dict_val_store = None
            dev_acc, dev_logits = dev_eval(config, model, batcher, batch_idx, dict_avg_val)
            # Altered from the original, though not exercised here: some metrics come back
            # as a formatted string, so parse out the first number before comparing
            if isinstance(dev_acc, str):
                f1s = re.findall(r"[-+]?\d*\.\d+|\d+", dev_acc)
                dev_acc = float(f1s[0])

            print("Global Step: %d Acc: %.3f" % (batch_idx, float(dev_acc)) + '\n')
            
            if dev_acc > best_dev_acc:
                best_dev_acc = dev_acc
                torch.save(model.state_dict(), os.path.join(config.exp_dir, "best_model.pt"))
                with open(os.path.join(config.exp_dir, "dev_logits.npy"), 'wb') as f:
                    np.save(f, dev_logits)
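
A minimal usage sketch, not part of the original file: it shows how train() might be
invoked, assuming the module-level imports and helpers of train.py (Batcher, adapet,
dev_eval, device, os, ...) are available. The SimpleNamespace config and every field
value below are hypothetical placeholders; only the attributes read directly in train()
are listed, and the repo's own Config class will require more.

if __name__ == "__main__":
    from types import SimpleNamespace

    # Hypothetical config: field names mirror the attributes accessed in train(),
    # values are illustrative and not the repo's defaults.
    config = SimpleNamespace(
        pretrained_weight="albert-xxlarge-v2",   # any Hugging Face checkpoint name
        dataset="fewglue/BoolQ",                 # placeholder dataset identifier
        exp_dir="exp_out/boolq",                 # where best_model.pt / dev_logits.npy land
        lr=1e-5,
        weight_decay=1e-2,
        grad_clip_norm=1.0,
        num_batches=1000,                        # optimizer updates
        grad_accumulation_factor=4,              # micro-batches per update
        warmup_ratio=0.06,
        eval_every=100,                          # evaluate every 100 effective batches
    )

    os.makedirs(config.exp_dir, exist_ok=True)   # train() assumes exp_dir already exists
    train(config)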