mdr/qa/qa_trainer.py [229:253]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
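                # Scale the loss so gradients accumulated over several micro-batches
                # average out to the gradient of a single full-sized batch.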
                if self._train_cfg.gradient_accumulation_steps > 1:
                    loss = loss / self._train_cfg.gradient_accumulation_steps
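                # Under fp16, apex amp scales the loss before backward to avoid gradient underflow.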
                if self._train_cfg.fp16:
                    with amp.scale_loss(loss, self._state.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                train_loss_meter.update(loss.item())
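                # Take an optimizer step only once every gradient_accumulation_steps micro-batches.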
                if (batch_step + 1) % self._train_cfg.gradient_accumulation_steps == 0:
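                    # With fp16, clip the fp32 master parameters held by amp rather than
                    # the model's fp16 parameters.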
                    if self._train_cfg.fp16:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(self._state.optimizer), self._train_cfg.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(
                            self._state.model.parameters(), self._train_cfg.max_grad_norm)
                    self._state.optimizer.step()
                    self._state.lr_scheduler.step()
                    self._state.model.zero_grad()
                    global_step += 1
                    self._state.global_step = global_step

                    self.tb_logger.add_scalar('batch_train_loss', loss.item(), global_step)
                    self.tb_logger.add_scalar('smoothed_train_loss', train_loss_meter.avg, global_step)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



mdr/retrieval/mhop_trainer.py [223:250]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
                if self._train_cfg.gradient_accumulation_steps > 1:
                    loss = loss / self._train_cfg.gradient_accumulation_steps
                if self._train_cfg.fp16:
                    with amp.scale_loss(loss, self._state.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                
                train_loss_meter.update(loss.item())

                if (batch_step + 1) % self._train_cfg.gradient_accumulation_steps == 0:
                    if self._train_cfg.fp16:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(self._state.optimizer), self._train_cfg.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(
                            self._state.model.parameters(), self._train_cfg.max_grad_norm)
                    self._state.optimizer.step()
                    self._state.lr_scheduler.step()
                    self._state.model.zero_grad()

                    global_step += 1
                    self._state.global_step = global_step

                    self.tb_logger.add_scalar('batch_train_loss', loss.item(), global_step)
                    self.tb_logger.add_scalar('smoothed_train_loss', train_loss_meter.avg, global_step)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
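
Note: the two excerpts above share the same optimizer-step logic. Below is a minimal sketch of how that logic could be factored into a shared helper; the helper name accumulate_and_step and the state/cfg parameter names are hypothetical, and the sketch assumes apex amp as the fp16 backend, as in both trainers.

import torch

try:
    from apex import amp  # only required when cfg.fp16 is set
except ImportError:
    amp = None


def accumulate_and_step(loss, state, cfg, batch_step, train_loss_meter, tb_logger):
    """Backpropagate one micro-batch and step the optimizer every
    cfg.gradient_accumulation_steps micro-batches, mirroring the loops above."""
    if cfg.gradient_accumulation_steps > 1:
        loss = loss / cfg.gradient_accumulation_steps

    if cfg.fp16:
        # apex amp scales the loss to avoid fp16 gradient underflow
        with amp.scale_loss(loss, state.optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()
    train_loss_meter.update(loss.item())

    if (batch_step + 1) % cfg.gradient_accumulation_steps == 0:
        # under fp16, clip the fp32 master params held by amp
        params = amp.master_params(state.optimizer) if cfg.fp16 else state.model.parameters()
        torch.nn.utils.clip_grad_norm_(params, cfg.max_grad_norm)
        state.optimizer.step()
        state.lr_scheduler.step()
        state.model.zero_grad()
        state.global_step += 1

        tb_logger.add_scalar('batch_train_loss', loss.item(), state.global_step)
        tb_logger.add_scalar('smoothed_train_loss', train_loss_meter.avg, state.global_step)
    return state.global_step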



