step5_data_parallel_naive/train.py
# `F` (torch.nn.functional) and `pgm` (the project's process group manager)
# are expected to be imported at the top of train.py.
def train_step(model, dataloader, device):
    acc_loss = 0.0

    # gradients only need to be synchronized across data-parallel ranks when
    # there is more than one rank
    requires_grad_sync = pgm.process_group_manager.dp_world_size > 1

    for i in range(dataloader.grad_acc_steps):
        # get the next micro-batch
        batch = next(dataloader)
        input_ids = batch["input_ids"].to(device)
        target_ids = batch["target_ids"].to(device)

        # enable gradient synchronization for the last micro-batch only, so the
        # all-reduce happens once per optimizer step instead of once per micro-batch
        if requires_grad_sync:
            model.require_backward_grad_sync = (i == dataloader.grad_acc_steps - 1)

        outputs = model(input_ids=input_ids)

        # compute the loss: flatten (batch, seq_len) into a single token dimension
        # and scale by the number of accumulation steps so the accumulated gradient
        # matches a single full-batch gradient
        batch_size, seq_len = input_ids.shape
        target_ids = target_ids.reshape(-1)
        outputs = outputs.view(batch_size * seq_len, -1)
        loss = F.cross_entropy(outputs, target_ids, reduction="mean") / dataloader.grad_acc_steps

        loss.backward()
        acc_loss += loss.item()

    return acc_loss
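
# --- illustrative sketch (not part of train.py) ------------------------------
# The `require_backward_grad_sync` flag set in train_step only has an effect if
# `model` is wrapped in something that reads it during the backward pass. The
# class below is a minimal sketch of such a naive data-parallel wrapper; the
# name `DataParallelNaive`, the `dp_group` argument, and the use of
# register_post_accumulate_grad_hook (PyTorch >= 2.1) are assumptions for
# illustration, not necessarily how this repo implements it.
import torch.distributed as dist
import torch.nn as nn

class DataParallelNaive(nn.Module):
    def __init__(self, module, dp_group=None):
        super().__init__()
        self.module = module
        self.dp_group = dp_group
        # train_step toggles this flag so the all-reduce only runs on the last
        # micro-batch of each gradient-accumulation cycle
        self.require_backward_grad_sync = True
        # fire once per backward pass, after each parameter's gradient has been
        # accumulated into .grad
        for param in self.module.parameters():
            if param.requires_grad:
                param.register_post_accumulate_grad_hook(self._allreduce_grad)

    def _allreduce_grad(self, param):
        # skip communication for non-final micro-batches
        if not self.require_backward_grad_sync:
            return
        # average the accumulated gradient across all data-parallel ranks
        dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=self.dp_group)
        param.grad /= dist.get_world_size(group=self.dp_group)

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

# Deferring the all-reduce to the last micro-batch means the communication cost
# is paid once per optimizer step rather than once per micro-batch, which is the
# main benefit of combining gradient accumulation with data parallelism.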