in sample_workloads/lit-gpt-demo/openwebtext.py [0:0]
def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader, max_iters: int) -> torch.Tensor:
    """Return the mean validation loss over at most ``max_iters`` batches.

    The model is switched to eval mode for the duration of the call. Its
    previous train/eval mode is restored afterwards, even if an exception
    is raised part-way through.

    Args:
        fabric: Fabric instance; used for rank-aware printing and to pick
            the device the loss buffer is allocated on.
        model: Model to evaluate. Called as ``model(input_ids)`` and expected
            to return logits accepted by ``chunked_cross_entropy``.
        val_dataloader: Iterable yielding ``(input_ids, targets)`` pairs.
        max_iters: Maximum number of batches to evaluate. If the loader is
            exhausted earlier, only the batches actually seen are averaged.

    Returns:
        A scalar tensor with the mean loss over the evaluated batches. If no
        batches were evaluated (e.g. ``max_iters == 0``), the result is NaN,
        as it is for the mean of an empty tensor.
    """
    fabric.print("Validating ...")
    was_training = model.training
    model.eval()
    try:
        # Validation never calls backward(), so recording the autograd graph
        # would only waste activation memory.
        with torch.no_grad():
            losses = torch.zeros(max_iters, device=fabric.device)
            num_batches = 0
            # `range` goes first so that no extra batch is pulled from the
            # loader once max_iters is reached. `zip` also stops cleanly
            # (instead of leaking StopIteration) if the loader runs out first.
            for k, (input_ids, targets) in zip(range(max_iters), val_dataloader):
                logits = model(input_ids)
                losses[k] = chunked_cross_entropy(logits, targets, chunk_size=0)
                num_batches += 1
            # Average only the filled slots. Unused zero entries would
            # otherwise bias the mean downward.
            out = losses[:num_batches].mean()
    finally:
        # Restore the caller's mode even if the forward pass or loss raised.
        model.train(was_training)
    return out