def train_step_gen()

in reagent/training/slate_q_trainer.py [0:0]
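
This generator runs one SlateQ training step: it builds the TD target from the target Q-network, yields the MSE TD loss for the slate Q-network, and then yields the soft-update result for the target network, so the surrounding training loop can pair each yielded value with an optimizer. The body relies on the module-level imports of the file (torch, and presumably torch.nn.functional as F and ReAgent's core types module as rlt).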


    def train_step_gen(self, training_batch: rlt.SlateQInput, batch_idx: int):
        assert isinstance(
            training_batch, rlt.SlateQInput
        ), f"learning input is a {type(training_batch)}"

        reward = training_batch.reward
        reward_mask = training_batch.reward_mask
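        # reward_mask is used below to index Q-values and targets when
        # single_selection is enabled (presumably the slate positions whose
        # reward was actually observed).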

        discount_tensor = torch.full_like(reward, self.gamma)

        # When discount_time_scale is provided and the training_batch carries
        # time_diff, exponentiate the discount by time_diff / discount_time_scale
        # so that longer gaps between steps are discounted more heavily.
        if self.discount_time_scale and training_batch.time_diff is not None:
            discount_tensor = discount_tensor ** (
                training_batch.time_diff / self.discount_time_scale
            )
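            # e.g. gamma = 0.9, time_diff = 2, discount_time_scale = 1 gives an
            # effective per-transition discount of 0.9 ** 2 = 0.81.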

        next_action = (
            self._get_maxq_next_action(training_batch.next_state)
            if self.rl_parameters.maxq_learning
            else training_batch.next_action
        )
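        # Under maxq_learning the target uses a greedy next slate from
        # _get_maxq_next_action (standard Q-learning); otherwise the logged
        # next_action is used, giving a SARSA-style target.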

        terminal_mask = (~training_batch.not_terminal.to(torch.bool)).squeeze(1)
        next_action_docs = self._action_docs(
            training_batch.next_state,
            next_action,
            terminal_mask=terminal_mask,
        )
        next_q_values = torch.sum(
            self._get_unmasked_q_values(
                self.q_network_target,
                training_batch.next_state,
                next_action_docs,
            )
            * self._get_docs_value(next_action_docs),
            dim=1,
            keepdim=True,
        )
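        # The item-level Q-values from the target network are weighted by the
        # per-item values returned by _get_docs_value and summed over the slate,
        # in the spirit of the SlateQ decomposition of a slate's value into
        # item-level contributions.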

        # If not single selection, divide max-Q by the actual slate size.
        if not self.single_selection:
            next_q_values = next_q_values / self._get_avg_by_slate_size(training_batch)

        filtered_max_q_vals = next_q_values * training_batch.not_terminal.float()
        target_q_values = reward + (discount_tensor * filtered_max_q_vals)
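        # TD target: reward + discount * next-slate value, with the next-state
        # term zeroed out for terminal transitions by the not_terminal mask.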
        # Don't mask if not single selection
        if self.single_selection:
            target_q_values = target_q_values[reward_mask]

        # Get Q-value of action taken
        action_docs = self._action_docs(training_batch.state, training_batch.action)
        q_values = self._get_unmasked_q_values(
            self.q_network, training_batch.state, action_docs
        )
        if self.single_selection:
            q_values = q_values[reward_mask]

        all_action_scores = q_values.detach()

        value_loss = F.mse_loss(q_values, target_q_values)
        yield value_loss
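        # First yielded value: the TD loss for self.q_network. The surrounding
        # training loop is expected to step one optimizer per yielded value.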

        if not self.single_selection:
            all_action_scores = all_action_scores.sum(dim=1, keepdim=True)

        # Log at the end so that all the CUDA operations get scheduled first
        self.reporter.log(
            td_loss=value_loss,
            model_values_on_logged_actions=all_action_scores,
        )

        # Use the soft update rule to update the target network
        result = self.soft_update_result()
        self.log(
            "td_loss", value_loss, prog_bar=True, batch_size=training_batch.batch_size()
        )
        yield result
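
Below is a minimal, hypothetical sketch of how a generator-style train step like this can be driven by hand. In ReAgent the Lightning-based training loop consumes the generator and pairs each yielded value with its optimizer, so the names and wiring here are illustrative assumptions only.

    # Hypothetical driver (illustrative only): pull the yielded values in order.
    gen = trainer.train_step_gen(training_batch, batch_idx=0)
    td_loss = next(gen)      # MSE TD loss for the slate Q-network
    soft_update = next(gen)  # result used to soft-update the target Q-network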