def train_step()

in step5_data_parallel_naive/train.py


def train_step(model, dataloader, device):
    """Run one optimizer step's worth of micro-batches with gradient accumulation
    and return the accumulated (mean) loss. `F` (torch.nn.functional) and `pgm`
    (the process group manager) come from train.py's module-level imports."""
    acc_loss = 0.0

    # gradients only need to be synchronized across data-parallel ranks
    # when there is more than one of them
    requires_grad_sync = pgm.process_group_manager.dp_world_size > 1

    for i in range(dataloader.grad_acc_steps):
        # get the next batch
        batch = next(dataloader)
        input_ids = batch["input_ids"].to(device)
        target_ids = batch["target_ids"].to(device)

        # all-reduce gradients across data-parallel ranks only on the last
        # micro-batch, so we pay for a single synchronization per step
        if requires_grad_sync:
            model.require_backward_grad_sync = (i == dataloader.grad_acc_steps - 1)

        outputs = model(input_ids=input_ids)

        # flatten logits to (batch_size * seq_len, vocab_size) and targets to
        # (batch_size * seq_len,) for token-level cross-entropy
        batch_size, seq_len = input_ids.shape
        target_ids = target_ids.reshape(-1)
        outputs = outputs.view(batch_size * seq_len, -1)
        # divide by grad_acc_steps so the accumulated gradient equals the mean
        # over all micro-batches
        loss = F.cross_entropy(outputs, target_ids, reduction='mean') / dataloader.grad_acc_steps

        loss.backward()

        # each term is already scaled by 1/grad_acc_steps, so acc_loss ends up
        # as the mean loss over the micro-batches
        acc_loss += loss.item()

    return acc_loss
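
The require_backward_grad_sync flag assumes the model has been wrapped in a naive data-parallel module that all-reduces gradients from backward hooks only when the flag is True. The wrapper below is an illustrative sketch of that idea under stated assumptions, not the actual step5 implementation; the class name, the use of register_post_accumulate_grad_hook, and the plain world-size averaging are assumptions.

import torch.nn as nn
import torch.distributed as dist

class DataParallelNaive(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module
        # when False, backward passes accumulate gradients locally without communication
        self.require_backward_grad_sync = True
        for param in self.module.parameters():
            if param.requires_grad:
                # fires once per parameter after its gradient has been accumulated
                param.register_post_accumulate_grad_hook(self._allreduce_grad)

    def _allreduce_grad(self, param):
        if self.require_backward_grad_sync:
            # average the (already accumulated) gradient across data-parallel ranks
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= dist.get_world_size()

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)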
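
For orientation, here is a minimal sketch of an outer loop that could drive train_step. The function name train, the optimizer, num_train_steps, and the global_rank attribute on the process group manager are illustrative assumptions, not taken from train.py.

def train(model, dataloader, optimizer, device, num_train_steps):
    model.train()
    for step in range(num_train_steps):
        optimizer.zero_grad()

        # runs grad_acc_steps micro-batches; the gradient all-reduce across
        # data-parallel ranks happens only on the last backward pass
        loss = train_step(model, dataloader, device)

        optimizer.step()

        if pgm.process_group_manager.global_rank == 0:
            print(f"step {step:5d} | loss: {loss:.4f}")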