# CodeXGLUE/Text-Code/text-to-code/code/run.py
# Imports used by this excerpt (assumed from the surrounding run.py; `logger`,
# `set_seed`, and `eval_bleu` are defined elsewhere in that file):
import os

import numpy as np
import torch
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.tensorboard import SummaryWriter
from transformers import AdamW, get_linear_schedule_with_warmup


def train(args, train_dataset, model, tokenizer, fh, pool):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        args.tensorboard_dir = os.path.join(args.output_dir, 'tensorboard')
        if not os.path.exists(args.tensorboard_dir):
            os.makedirs(args.tensorboard_dir)
        tb_writer = SummaryWriter(args.tensorboard_dir)
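    # Scale the per-GPU batch size by the number of GPUs on this node, build the
    # training dataloader, and compute the effective global batch size
    # (distributed workers x gradient accumulation).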
    args.batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.batch_size, drop_last=True)
    total_examples = len(train_dataset) * (
        torch.distributed.get_world_size() if args.local_rank != -1 else 1)
    batch_size = args.batch_size * args.gradient_accumulation_steps * (
        torch.distributed.get_world_size() if args.local_rank != -1 else 1)
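    # Derive the total number of optimization steps (t_total) from the requested
    # epoch count and the effective global batch size.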
    # if args.max_steps > 0:
    #     t_total = args.max_steps
    #     args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    if args.num_train_epochs > 0:
        t_total = total_examples // batch_size * args.num_train_epochs
    args.max_steps = t_total
    model.to(args.device)
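    # In distributed runs, ranks other than 0 wait at this barrier so that rank 0
    # sets up the optimizer/scheduler (and any cached state) first.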
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()
    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps,
                                                num_training_steps=t_total)
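    # Resume optimizer/scheduler state from checkpoint-last if a previous run left one behind.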
    checkpoint_last = os.path.join(args.output_dir, 'checkpoint-last')
    scheduler_last = os.path.join(checkpoint_last, 'scheduler.pt')
    optimizer_last = os.path.join(checkpoint_last, 'optimizer.pt')
    if os.path.exists(scheduler_last):
        scheduler.load_state_dict(torch.load(scheduler_last, map_location="cpu"))
    if os.path.exists(optimizer_last):
        optimizer.load_state_dict(torch.load(optimizer_last, map_location="cpu"))
    if args.local_rank == 0:
        torch.distributed.barrier()
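    # Optional mixed-precision training via NVIDIA Apex.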
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank % args.gpu_per_node],
                                                          output_device=args.local_rank % args.gpu_per_node,
                                                          find_unused_parameters=True)
    # Train!
    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", total_examples)
    logger.info(" Num epoch = %d", t_total * batch_size // total_examples)
    logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", batch_size)
    logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info(" Total optimization steps = %d", t_total)

    global_step = args.start_step
    tr_loss, logging_loss, avg_loss, tr_nb = 0.0, 0.0, 0.0, 0
    # model.resize_token_embeddings(len(tokenizer))
    model.zero_grad()
    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
    best_bleu = 0.0
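    # Outer loop over epochs, inner loop over mini-batches from the training dataloader.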
    for idx in range(args.start_epoch, int(args.num_train_epochs)):
        for step, (batch, token_labels) in enumerate(train_dataloader):
            inputs = batch.to(args.device)
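            # Build an attention mask over non-padding positions (token label != 0)
            # and a loss mask over the target-code positions (token label == 2).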
            attn_mask = torch.tensor(token_labels.clone().detach() != 0, dtype=torch.uint8, device=args.device)
            loss_mask = torch.tensor(token_labels.clone().detach() == 2, dtype=torch.uint8, device=args.device)
            model.train()
            # outputs = model(inputs, attention_mask=attn_mask, labels=inputs, loss_mask=loss_mask)
            # loss = outputs[0]
            outputs = model(inputs, attention_mask=attn_mask)
            logits = outputs[0]
            labels = inputs
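            # Standard next-token objective: shift logits/labels by one position and
            # compute cross-entropy only at positions selected by the loss mask.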
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            flatten_shift_loss_mask = loss_mask[..., :-1].contiguous().view(-1)
            ids = torch.nonzero(flatten_shift_loss_mask).view(-1)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1))[ids], shift_labels.view(-1)[ids])

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

            tr_loss += loss.item()
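            # Take an optimizer (and LR scheduler) step once every
            # gradient_accumulation_steps micro-batches, then log perplexity
            # and TensorBoard metrics.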
            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()
                global_step += 1
                output_flag = True
                avg_loss = round(np.exp((tr_loss - logging_loss) / (global_step - tr_nb)), 4)
                if global_step % args.logging_steps == 0:
                    logger.info(" steps: %s ppl: %s", global_step, round(avg_loss, 5))
                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    tb_writer.add_scalar('lr', scheduler.get_last_lr()[0], global_step)
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss) / args.logging_steps, global_step)
                    logging_loss = tr_loss
                    tr_nb = global_step
                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    checkpoint_prefix = "checkpoint"
                    # Save model checkpoint
                    if args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                        # results = evaluate(args, model, tokenizer, eval_when_training=True)
                        # for key, value in results.items():
                        #     tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
                        #     logger.info(" %s = %s", key, round(value, 4))
                        # output_dir = os.path.join(args.output_dir, '{}-{}-{}'.format(checkpoint_prefix, global_step, round(results['perplexity'], 4)))
                        dev_bleu, dev_EM = eval_bleu(args, model, tokenizer, file_type='dev', num=100)
                        logger.info(f"dev bleu: {dev_bleu}, dev EM: {dev_EM}")
                        output_dir = os.path.join(args.output_dir, '{}-{}-{}'.format(checkpoint_prefix, global_step, round(dev_bleu, 2)))
                        if dev_bleu > best_bleu:
                            best_bleu = dev_bleu
                            logger.info(f"best bleu updated. saved in {output_dir}")
                            logger.info(f"best bleu: {best_bleu}")
                    else:
                        output_dir = os.path.join(args.output_dir, "{}-{}".format(checkpoint_prefix, global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)
                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)
                    # _rotate_checkpoints(args, checkpoint_prefix)
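                    # Also refresh checkpoint-last so an interrupted run can resume from the most recent state.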
                    last_output_dir = os.path.join(args.output_dir, 'checkpoint-last')
                    if not os.path.exists(last_output_dir):
                        os.makedirs(last_output_dir)
                    model_to_save.save_pretrained(last_output_dir)
                    tokenizer.save_pretrained(last_output_dir)
                    idx_file = os.path.join(last_output_dir, 'idx_file.txt')
                    with open(idx_file, 'w', encoding='utf-8') as idxf:
                        idxf.write(str(0) + '\n')

                    torch.save(optimizer.state_dict(), os.path.join(last_output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(), os.path.join(last_output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s", last_output_dir)

                    step_file = os.path.join(last_output_dir, 'step_file.txt')
                    with open(step_file, 'w', encoding='utf-8') as stepf:
                        stepf.write(str(global_step) + '\n')

                    # torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                    # torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                    # logger.info("Saving optimizer and scheduler states to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step