in sagemaker_notebook_instance/containers/relationship_extraction/package/models.py [0:0]
def training_step(self, batch, batch_idx):
    """Run one training step on ``batch`` and return its loss.

    Reads ``token_ids``, ``attention_mask`` and ``label_id`` from ``batch``
    and feeds the first two to ``self.model``. It scores the output against
    ``label_id`` with cross-entropy, which is logged as ``train_loss``.

    The output is also pushed through the ``train_acc`` and ``train_f1``
    metrics. Each metric is logged per step only (``on_step=True``,
    ``on_epoch=False``).

    NOTE(review): this looks like a PyTorch Lightning ``LightningModule`` hook.
    The enclosing class is not visible here, so confirm. It presumably
    expects ``self.model`` to return ``(batch, num_classes)`` logits and
    ``label_id`` to hold integer class indices. Verify against the data
    module.

    Args:
        batch: Mapping with the ``token_ids``, ``attention_mask`` and
            ``label_id`` entries.
        batch_idx: Index of the batch; unused, but part of the hook signature.

    Returns:
        The cross-entropy loss tensor.
    """
    ids = batch['token_ids']
    mask = batch['attention_mask']
    labels = batch['label_id']

    logits = self.model(ids, mask)
    step_loss = torch.nn.functional.cross_entropy(logits, labels)
    self.log('train_loss', step_loss)

    # Update each metric, then hand the metric object itself to the logger so
    # it computes the per-step value. Attributes are fetched lazily, in the
    # same order as before.
    for metric_name in ('train_acc', 'train_f1'):
        metric = getattr(self, metric_name)
        metric(logits, labels)
        self.log(metric_name, metric, on_step=True, on_epoch=False)

    return step_loss