in train.py
import torch.nn.functional as F
import picotron.process_group_manager as pgm  # exposes pgm.process_group_manager


def train_step(model, data_loader, device):
    acc_loss = 0.0

    # gradients only need to be synchronized when there is more than one
    # context-parallel / data-parallel rank
    requires_grad_sync = pgm.process_group_manager.cp_dp_world_size > 1

    for i in range(data_loader.grad_acc_steps):
        # get the next micro-batch
        batch = next(data_loader)
        input_ids = batch["input_ids"].to(device)
        target_ids = batch["target_ids"].to(device)

        # disable gradient synchronization for all but the last micro-batch
        if requires_grad_sync:
            model.require_backward_grad_sync = (i == data_loader.grad_acc_steps - 1)

        outputs = model(input_ids=input_ids)

        # compute the loss over all tokens; dividing by the number of
        # gradient-accumulation steps keeps the accumulated gradient
        # equivalent to one large-batch step
        batch_size, seq_len = input_ids.shape
        target_ids = target_ids.reshape(-1)
        outputs = outputs.view(seq_len * batch_size, -1)
        loss = F.cross_entropy(outputs, target_ids, reduction="mean") / data_loader.grad_acc_steps

        loss.backward()
        acc_loss += loss.item()

    return acc_loss
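
# A minimal sketch of an outer loop driving train_step, assuming a standard
# torch.optim optimizer. The names `optimizer` and `num_train_steps` are
# illustrative and not part of the snippet above; the rank check assumes the
# process group manager exposes a `global_rank` attribute. train_step only
# accumulates gradients, so the caller applies the update and clears them.
def training_loop(model, data_loader, optimizer, device, num_train_steps):
    for step in range(num_train_steps):
        optimizer.zero_grad()
        loss = train_step(model, data_loader, device)  # accumulate over micro-batches
        optimizer.step()                               # apply the accumulated gradients
        if pgm.process_group_manager.global_rank == 0:
            print(f"step {step}: loss {loss:.4f}")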