def training_step()

in reagent/training/ranking/seq2slate_trainer.py


    def training_step(self, batch: rlt.PreprocessedRankingInput, batch_idx: int):
        assert type(batch) is rlt.PreprocessedRankingInput

        batch_size = batch.state.float_features.shape[0]

        reward = batch.slate_reward
        assert reward is not None

        optimizers = self.optimizers()
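        # Manual optimization: optimizers() returns the configured optimizers,
        # with the RL (Seq2Slate) optimizer first and, if a baseline net is
        # present, the baseline optimizer second.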
        if self.baseline_net:
            assert len(optimizers) == 2
            baseline_opt = optimizers[1]
        else:
            assert len(optimizers) == 1
        rl_opt = optimizers[0]

        if self.baseline_net:
            # Train baseline
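            # Regress the baseline net onto the observed slate reward (MSE);
            # its detached prediction is later used as the variance-reducing
            # baseline b in the REINFORCE advantage.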
            b = self.baseline_net(batch)
            baseline_loss = 1.0 / batch_size * torch.sum((b - reward) ** 2)
            baseline_opt.zero_grad()
            self.manual_backward(baseline_loss)
            baseline_opt.step()
        else:
            b = torch.zeros_like(reward)
            baseline_loss = torch.zeros(1)

        # Train Seq2Slate using REINFORCE
        # Propensities of the logged target sequences under the current model
        # (exp of the per-sequence log probabilities).
        model_propensities = torch.exp(
            self.seq2slate_net(
                batch, mode=Seq2SlateMode.PER_SEQ_LOG_PROB_MODE
            ).log_probs
        )
        b = b.detach()
        assert (
            b.shape == reward.shape == model_propensities.shape
        ), f"{b.shape} {reward.shape} {model_propensities.shape}"

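        # Importance sampling ratio between the model's propensities and the
        # logged propensities (batch.tgt_out_probs); the clamped variant caps
        # the ratio to control the variance of the off-policy gradient.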
        impt_smpl, clamped_impt_smpl = self._compute_impt_smpl(
            model_propensities, batch.tgt_out_probs
        )
        assert (
            impt_smpl.shape == clamped_impt_smpl.shape == reward.shape
        ), f"{impt_smpl.shape} {clamped_impt_smpl.shape} {reward.shape}"
        # The gradient is taken only w.r.t. model_propensities
        assert (
            not reward.requires_grad
            # pyre-fixme[16]: `Optional` has no attribute `requires_grad`.
            and not batch.tgt_out_probs.requires_grad
            and impt_smpl.requires_grad
            and clamped_impt_smpl.requires_grad
            and not b.requires_grad
        )
        # Negative sign because the optimizer performs gradient descent, while
        # we want to maximize the expected reward.
        batch_obj_loss = -clamped_impt_smpl * (reward - b)
        obj_loss = torch.mean(batch_obj_loss)

        # Accumulate the policy gradient when either:
        # 1. there is no baseline, or
        # 2. the baseline has passed its warm-up stage.
        # The accumulated gradient is applied (optimizer step) only once every
        # policy_gradient_interval minibatches.
        if (
            self.baseline_net is None
            or (self.all_batches_processed + 1) >= self.baseline_warmup_num_batches
        ):
            self.manual_backward(obj_loss)
            if (self.all_batches_processed + 1) % self.policy_gradient_interval == 0:
                rl_opt.step()
                rl_opt.zero_grad()
        else:
            logger.info("Not update RL model because now is baseline warmup phase")

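        # Detached copies of the losses, ratios, and advantages, used only for
        # the logging and reporting below.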
        ips_loss = torch.mean(-impt_smpl * reward).cpu().detach().numpy()
        clamped_ips_loss = (
            torch.mean(-clamped_impt_smpl * reward).cpu().detach().numpy()
        )
        baseline_loss = baseline_loss.detach().cpu().numpy().item()
        advantage = (reward - b).detach().cpu().numpy()
        logged_slate_rank_probs = model_propensities.detach().cpu().numpy()

        if (self.all_batches_processed + 1) % self.print_interval == 0:
            logger.info(
                "{} batch: ips_loss={}, clamped_ips_loss={}, baseline_loss={}, max_ips={}, mean_ips={}, grad_update={}".format(
                    self.all_batches_processed + 1,
                    ips_loss,
                    clamped_ips_loss,
                    baseline_loss,
                    torch.max(impt_smpl),
                    torch.mean(impt_smpl),
                    (self.all_batches_processed + 1) % self.policy_gradient_interval
                    == 0,
                )
            )
        self.reporter.log(
            train_ips_score=torch.tensor(ips_loss).reshape(1),
            train_clamped_ips_score=torch.tensor(clamped_ips_loss).reshape(1),
            train_baseline_loss=torch.tensor(baseline_loss).reshape(1),
            train_logged_slate_rank_probs=torch.FloatTensor(logged_slate_rank_probs),
            train_ips_ratio=impt_smpl,
            train_clamped_ips_ratio=clamped_impt_smpl,
            train_advantages=advantage,
        )
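
The helper self._compute_impt_smpl(...) is called above but not defined in this section. Below is a minimal, illustrative sketch of what such a helper typically computes, assuming a standalone function and a hypothetical clamp_max cap (neither name is taken from this file): the raw ratio of model propensities to logged propensities, plus a clamped copy that bounds the ratio to reduce the variance of the off-policy estimate.

    from typing import Optional, Tuple

    import torch


    def compute_impt_smpl_sketch(
        model_propensities: torch.Tensor,
        logged_propensities: torch.Tensor,
        clamp_max: Optional[float] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return raw and clamped importance sampling ratios (illustrative only)."""
        # Ratio of the ranking model's probability of the logged slate to the
        # logging policy's probability of the same slate.
        impt_smpl = model_propensities / logged_propensities
        # Capping the ratio trades a small bias for a large variance reduction.
        if clamp_max is not None:
            clamped_impt_smpl = torch.clamp(impt_smpl, 0, clamp_max)
        else:
            clamped_impt_smpl = impt_smpl
        return impt_smpl, clamped_impt_smpl

With clamp_max=None the two returned tensors are identical, which recovers plain (unclamped) IPS; training_step above would then see no difference between ips_loss and clamped_ips_loss.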