def _train()

in mdr/retrieval/single_trainer.py [0:0]


    def _train(self) -> Optional[float]:
        job_env = submitit.JobEnvironment()

        loss_fct = CrossEntropyLoss()
        batch_step = 0 # forward batch count
        best_mrr = 0
        train_loss_meter = AverageMeter()
        print(f"Start training", flush=True)
        # Start from the loaded epoch
        start_epoch = self._state.epoch
        global_step = self._state.global_step
        for epoch in range(start_epoch, self._train_cfg.num_train_epochs):
            print(f"Start epoch {epoch}", flush=True)
            self._state.model.train()
            self._state.epoch = epoch

            for batch in self._train_loader:
                batch_step += 1
                batch = move_to_cuda(batch)
                outputs = self._state.model(batch)
                q = outputs['q']          # question encodings, [batch_size, hidden]
                c = outputs['c']          # aligned positive passage encodings, [batch_size, hidden]
                neg_c = outputs['neg_c']  # one hard-negative passage encoding per question
                # Score each question against all in-batch passages plus its own hard negative;
                # cross-entropy then treats the aligned (diagonal) passage as the positive class.
                product_in_batch = torch.mm(q, c.t())
                product_neg = (q * neg_c).sum(-1).unsqueeze(1)
                product = torch.cat([product_in_batch, product_neg], dim=-1)
                target = torch.arange(product.size(0)).to(product.device)
                loss = loss_fct(product, target)

                # Average the loss over accumulation steps so the accumulated gradient
                # matches that of a single large batch.
                if self._train_cfg.gradient_accumulation_steps > 1:
                    loss = loss / self._train_cfg.gradient_accumulation_steps
                if self._train_cfg.fp16:
                    # apex amp scales the loss to avoid fp16 gradient underflow
                    with amp.scale_loss(loss, self._state.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                
                train_loss_meter.update(loss.item())
                self.tb_logger.add_scalar('batch_train_loss', loss.item(), global_step)
                self.tb_logger.add_scalar('smoothed_train_loss', train_loss_meter.avg, global_step)

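                # Step the optimizer only every `gradient_accumulation_steps` forward batches,
                # clipping gradients first (on the amp master params when fp16 is enabled).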
                if (batch_step + 1) % self._train_cfg.gradient_accumulation_steps == 0:
                    if self._train_cfg.fp16:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(self._state.optimizer), self._train_cfg.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(
                            self._state.model.parameters(), self._train_cfg.max_grad_norm)
                    self._state.optimizer.step()    # We have accumulated enough gradients
                    self._state.model.zero_grad()
                    global_step += 1
                    self._state.global_step = global_step

            # Checkpoint at the end of each epoch (the master-only guard is currently disabled)
            # if job_env.global_rank == 0:
            self.checkpoint(rm_init=False)
            mrr = self._eval()
            self.tb_logger.add_scalar('dev_mrr', mrr*100, epoch)
            self._state.lr_scheduler.step(mrr)
            if best_mrr < mrr:
                print("Saving model with best MRR %.2f -> MRR %.2f on epoch=%d" % (best_mrr*100, mrr*100, epoch))
                torch.save(self._state.model.state_dict(), os.path.join(self._train_cfg.output_dir, "checkpoint_best.pt"))
                best_mrr = mrr
            self.log({
                "best_mrr": best_mrr,
                "curr_mrr": mrr,
                "smoothed_loss": train_loss_meter.avg,
                "epoch": epoch
            })
        return best_mrr
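
For reference, the objective computed above is a standard in-batch negatives contrastive loss with one extra hard negative per question: each question is scored against every positive passage encoding in the batch plus its own hard negative, and cross-entropy picks the aligned (diagonal) passage as the correct class. The snippet below is a minimal, self-contained sketch of that construction, using random tensors in place of the encoder outputs; the names q, c and neg_c mirror the code above, while the batch and hidden sizes are arbitrary.

    import torch
    from torch.nn import CrossEntropyLoss

    batch_size, hidden = 4, 8
    q = torch.randn(batch_size, hidden)      # question encodings
    c = torch.randn(batch_size, hidden)      # aligned positive passage encodings
    neg_c = torch.randn(batch_size, hidden)  # one hard negative per question

    product_in_batch = torch.mm(q, c.t())                        # [B, B]: q_i vs. every in-batch passage
    product_neg = (q * neg_c).sum(-1).unsqueeze(1)                # [B, 1]: q_i vs. its own hard negative
    product = torch.cat([product_in_batch, product_neg], dim=-1)  # [B, B+1]
    target = torch.arange(product.size(0))                        # the positive for q_i is column i
    loss = CrossEntropyLoss()(product, target)
    print(loss.item())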