def training_step()

in sagemaker_notebook_instance/containers/relationship_extraction/package/models.py [0:0]


    def training_step(self, batch, batch_idx):
        """Run the forward pass and loss computation for one training batch.

        Passes the batch's token ids and attention mask through ``self.model``.
        Scores the resulting output against the gold labels with
        cross-entropy. Logs the loss, then updates and logs the accuracy and
        F1 metric objects at step level only.

        Args:
            batch: Mapping with ``'token_ids'``, ``'attention_mask'`` and
                ``'label_id'`` entries. These are presumably tensors produced
                by the DataLoader; confirm against the data module.
            batch_idx: Position of the batch within the epoch; not used here.

        Returns:
            The cross-entropy loss tensor, the same object that was logged as
            ``'train_loss'``. Presumably the trainer back-propagates it.
        """
        model_inputs = (batch['token_ids'], batch['attention_mask'])
        targets = batch['label_id']

        logits = self.model(*model_inputs)
        step_loss = torch.nn.functional.cross_entropy(logits, targets)
        self.log('train_loss', step_loss)

        # Each metric is first updated with this batch, then the metric
        # object itself is handed to the logger. Logging is step-level only
        # (on_step=True, on_epoch=False).
        step_metrics = (
            ('train_acc', self.train_acc),
            ('train_f1', self.train_f1),
        )
        for metric_name, metric in step_metrics:
            metric(logits, targets)
            self.log(metric_name, metric, on_step=True, on_epoch=False)

        return step_loss