scripts/adapet/ADAPET/src/train.py
def train(config):
    '''
    Trains the model.

    :param config: experiment configuration (model, optimization, and evaluation settings)
    :return:
    '''
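    # Load the pretrained tokenizer, build the batcher / dataset reader, and move the ADAPET model to the target device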
    tokenizer = AutoTokenizer.from_pretrained(config.pretrained_weight)
    batcher = Batcher(config, tokenizer, config.dataset)
    dataset_reader = batcher.get_dataset_reader()
    model = adapet(config, tokenizer, dataset_reader).to(device)
    ### Create Optimizer
    # Do not apply weight decay to bias and LayerNorm parameters
    no_decay_param = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.model.named_parameters() if not any(nd in n for nd in no_decay_param)],
         'weight_decay': config.weight_decay,
         'lr': config.lr},
        {'params': [p for n, p in model.model.named_parameters() if any(nd in n for nd in no_decay_param)],
         'weight_decay': 0.0,
         'lr': config.lr},
    ]
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, eps=1e-8)
    # altered: start from -inf so the first evaluated checkpoint is always saved
    # best_dev_acc = 0
    best_dev_acc = -float('inf')
    train_iter = batcher.get_train_batch()
    dict_val_store = None
    # config.num_batches counts optimizer steps; each step accumulates
    # grad_accumulation_factor micro-batches, so the loop below runs longer
    tot_num_batches = config.num_batches * config.grad_accumulation_factor

    # Warmup and total steps for the LR schedule are counted in optimizer steps (batches), not epochs
    num_warmup_steps = config.num_batches * config.warmup_ratio
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, config.num_batches)
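    # Training loop over micro-batches; one optimizer step per grad_accumulation_factor micro-batches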
    for i in range(tot_num_batches):
        # Get the true batch_idx (index of the effective batch / optimizer step)
        batch_idx = (i // config.grad_accumulation_factor)

        model.train()
        sup_batch = next(train_iter)

        loss, dict_val_update = model(sup_batch)
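        # Scale the loss so gradients accumulated over grad_accumulation_factor
        # micro-batches average out to one effective batch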
        loss = loss / config.grad_accumulation_factor
        loss.backward()
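        # Clip and step the optimizer only once a full effective batch has been accumulated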
        if (i+1) % config.grad_accumulation_factor == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip_norm)
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()
        dict_val_store = update_dict_val_store(dict_val_store, dict_val_update, config.grad_accumulation_factor)

        print("Finished %d batches" % batch_idx, end='\r')
        if (batch_idx + 1) % config.eval_every == 0 and i % config.grad_accumulation_factor == 0:
            dict_avg_val = get_avg_dict_val_store(dict_val_store, config.eval_every)
            dict_val_store = None

            dev_acc, dev_logits = dev_eval(config, model, batcher, batch_idx, dict_avg_val)
            # altered but not used: if dev_eval returns a formatted string (e.g. F1 scores),
            # parse out the first number so the comparison below operates on a float
            if isinstance(dev_acc, str):
                f1s = re.findall(r"[-+]?\d*\.\d+|\d+", dev_acc)
                dev_acc = float(f1s[0])

            print("Global Step: %d Acc: %.3f" % (batch_idx, float(dev_acc)) + '\n')
            if dev_acc > best_dev_acc:
                best_dev_acc = dev_acc
                torch.save(model.state_dict(), os.path.join(config.exp_dir, "best_model.pt"))
                with open(os.path.join(config.exp_dir, "dev_logits.npy"), 'wb') as f:
                    np.save(f, dev_logits)