src/loss.py [54:60]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    initial_pred_shape = pred_logits.shape[:-1]
    pred_logits = torch.reshape(pred_logits, [-1, pred_logits.size(-1)])
    labels = torch.reshape(labels, [-1]) #same as flatten
    all_loss = ce_loss_no_reduction(pred_logits, labels)
    all_loss = all_loss.view(initial_pred_shape)

    return all_loss, n_tokens
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



src/loss.py [88:93]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    initial_pred_shape = pred_logits.shape[:-1]
    pred_logits = torch.reshape(pred_logits, [-1, pred_logits.size(-1)])
    labels = torch.reshape(labels, [-1]) #same as flatten
    all_loss = ce_loss_no_reduction(pred_logits, labels)
    all_loss = all_loss.view(initial_pred_shape)
    return all_loss, n_tokens
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



