in grok/training.py [0:0]
def validation_step(self, batch, batch_idx):
    """
    Run one no-grad forward validation pass over a single batch.

    Called by pytorch_lightning for every validation batch. Whether the
    batch is evaluated depends on ``self.next_epoch_to_eval``:

    - If it lags behind ``self.current_epoch``, it is advanced to the current
      epoch and the batch is evaluated.
    - If it is ahead of the current epoch, the batch is skipped.

    :param batch: The batch of equations to process
    :param batch_idx: which batch this is in the epoch.
    :returns: ``{}`` when evaluation is skipped this epoch. Otherwise, a dict
              containing:

              - ``partial_val_loss`` and ``partial_val_accuracy``: loss and
                accuracy, each multiplied by the ``coeff`` returned from
                ``self._step`` (presumably the batch's weight within the
                full validation set -- TODO confirm in ``_step``).
              - ``y_hat_rhs``, ``partial_attentions`` and ``partial_values``:
                passed through from ``self._step``.
              - ``x_lhs``: included during epoch 0 only.
    """
    epoch = self.current_epoch

    # Catch a stale evaluation target up to the current epoch, or bail out
    # early while the target is still in the future.
    if epoch > self.next_epoch_to_eval:
        self.next_epoch_to_eval = epoch
    elif epoch < self.next_epoch_to_eval:
        return {}

    with torch.no_grad():
        (
            val_loss,
            val_accuracy,
            weight,
            lhs,
            rhs_prediction,
            attn,
            vals,
        ) = self._step(batch=batch, batch_idx=batch_idx, train=False)

    result = dict(
        partial_val_loss=weight * val_loss,
        partial_val_accuracy=weight * val_accuracy,
        y_hat_rhs=rhs_prediction,
        partial_attentions=attn,
        partial_values=vals,
    )
    # The left-hand-side inputs are only attached on the first epoch.
    if self.current_epoch == 0:
        result["x_lhs"] = lhs
    return result