def validation_step()

in sagemaker_notebook_instance/containers/relationship_extraction/package/models.py [0:0]


    def validation_step(self, batch, batch_idx):
        """Run the model on one validation batch; log loss, accuracy and F1.

        Args:
            batch: mapping that supplies ``'token_ids'``, ``'attention_mask'``
                and ``'label_id'`` entries (presumably tensors produced by
                the validation dataloader -- TODO confirm against caller).
            batch_idx: position of the batch; not used by this method.

        Returns:
            The cross-entropy loss between the model output and
            ``batch['label_id']``.
        """
        # Read the three batch fields in the same order as before.
        inputs, mask, targets = (
            batch[key] for key in ('token_ids', 'attention_mask', 'label_id')
        )

        # Model output is passed straight to cross_entropy, so it is presumably
        # unnormalised per-class scores -- verify against self.model.
        scores = self.model(inputs, mask)
        val_loss = torch.nn.functional.cross_entropy(scores, targets)
        self.log('valid_loss', val_loss)

        # Update each metric with this batch, then log the metric object itself
        # with both per-step and per-epoch logging requested.
        for metric_name, metric in (
            ('valid_acc', self.valid_acc),
            ('valid_f1', self.valid_f1),
        ):
            metric(scores, targets)
            self.log(metric_name, metric, on_step=True, on_epoch=True)

        return val_loss