Dassl.pytorch/dassl/engine/ssl/sup_baseline.py (23 lines of code) (raw):

from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy


@TRAINER_REGISTRY.register()
class SupBaseline(TrainerXU):
    """Supervised baseline trainer.

    Only the labeled batch contributes to the loss; the unlabeled batch
    is accepted to satisfy the two-batch method signatures but is never
    read by this class.
    """

    def forward_backward(self, batch_x, batch_u):
        """Run one training step on the labeled batch.

        Args:
            batch_x: Labeled batch; must provide the "img" and "label"
                entries consumed by ``parse_batch_train``.
            batch_u: Unlabeled batch; forwarded to ``parse_batch_train``,
                which ignores it.

        Returns:
            dict: ``{"loss": ..., "acc": ...}`` holding the cross-entropy
            loss and the first value returned by ``compute_accuracy``
            (presumably top-1 accuracy -- confirm against dassl.metrics),
            both converted with ``.item()``.
        """
        images, targets = self.parse_batch_train(batch_x, batch_u)

        logits = self.model(images)
        loss = F.cross_entropy(logits, targets)
        # The backward pass / parameter update is delegated to the base
        # trainer helper.
        self.model_backward_and_update(loss)

        summary = {
            "loss": loss.item(),
            "acc": compute_accuracy(logits, targets)[0].item(),
        }

        # Trigger update_lr only once the batch counter reaches the final
        # batch (i.e. once per pass over num_batches batches).
        reached_last_batch = (self.batch_idx + 1) == self.num_batches
        if reached_last_batch:
            self.update_lr()

        return summary

    def parse_batch_train(self, batch_x, batch_u):
        """Extract images and labels from the labeled batch.

        Args:
            batch_x: Mapping with "img" and "label" entries whose values
                support ``.to(device)``.
            batch_u: Unused.

        Returns:
            tuple: ``(images, labels)`` after ``.to(self.device)``.
        """
        images, targets = batch_x["img"], batch_x["label"]
        return images.to(self.device), targets.to(self.device)