def train_step_gen()

in reagent/training/discrete_crr_trainer.py


    def train_step_gen(self, training_batch: rlt.DiscreteDqnInput, batch_idx: int):
        """
        IMPORTANT: the input action here is preprocessed according to the
        training_batch type, which in this case is DiscreteDqnInput. Hence,
        the preprocessor in the DiscreteDqnInputMaker class in
        trainer_preprocessor.py is used, which converts the action taken
        into a one-hot representation.
        """
        self._check_input(training_batch)

        state = training_batch.state
        action = training_batch.action
        next_state = training_batch.next_state
        not_terminal = training_batch.not_terminal
        rewards = self.boost_rewards(training_batch.reward, training_batch.action)

        # Remember: training_batch.action is in the one-hot format
        logged_action_idxs = torch.argmax(action, dim=1, keepdim=True)
        discount_tensor = torch.full_like(rewards, self.gamma)

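        # Evaluate the target Q-network on next_state; compute_target_q_values
        # combines these Q-values with the rewards and the not_terminal mask to
        # form the TD target used by both Q-networks.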
        next_q_values = self.q1_network_target(next_state)
        target_q_values = self.compute_target_q_values(
            next_state, rewards, not_terminal, next_q_values
        )
        q1_loss = self.compute_td_loss(self.q1_network, state, action, target_q_values)

        # Show td_loss on the progress bar and in tensorboard graphs:
        self.log(
            "td_loss", q1_loss, prog_bar=True, batch_size=training_batch.batch_size()
        )
        yield q1_loss

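        # If a second Q-network is configured, train it against the same TD
        # target; its loss is yielded to the corresponding optimizer.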
        if self.q2_network:
            q2_loss = self.compute_td_loss(
                self.q2_network, state, action, target_q_values
            )
            yield q2_loss

        all_q_values = self.q1_network(state)  # Q-values of all actions

        # Note: action_dim (the length of each row of the actor_action
        # matrix obtained below) is assumed to be > 1.
        all_action_scores = self.actor_network(state).action
        logged_action_probs = training_batch.extras.action_probability

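        # compute_actor_loss returns the actor loss with and without the
        # regularization term; only actor_loss (the regularized version) is
        # yielded for optimization below, while both are logged.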
        actor_loss_without_reg, actor_loss = self.compute_actor_loss(
            batch_idx, action, logged_action_probs, all_q_values, all_action_scores
        )
        # self.reporter.log(
        #     actor_loss=actor_loss,
        #     actor_q1_value=actor_q1_values,
        # )

        # Show actor_loss on the progress bar and also in Tensorboard graphs
        self.log(
            "actor_loss_without_reg",
            actor_loss_without_reg,
            prog_bar=True,
            batch_size=training_batch.batch_size(),
        )
        self.log(
            "actor_loss",
            actor_loss,
            prog_bar=True,
            batch_size=training_batch.batch_size(),
        )
        yield actor_loss

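        # Yield the counterfactual policy evaluation (CPE) losses computed
        # from the logged batch, if CPE is enabled on this trainer.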
        yield from self._calculate_cpes(
            training_batch,
            state,
            next_state,
            all_action_scores,
            next_q_values.detach(),
            logged_action_idxs,
            discount_tensor,
            not_terminal.float(),
        )

        # TODO: rename underlying function to get_max_possible_values_and_idxs
        model_action_idxs = self.get_max_q_values(
            all_action_scores,
            training_batch.possible_actions_mask if self.maxq_learning else action,
        )[1]

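        # Report per-batch statistics (logged vs. model actions, TD loss,
        # propensities, rewards) to the reporter for aggregation.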
        self.reporter.log(
            logged_actions=logged_action_idxs,
            td_loss=q1_loss,
            logged_propensities=training_batch.extras.action_probability,
            logged_rewards=rewards,
            model_values=all_action_scores,
            model_action_idxs=model_action_idxs,
        )

        # Use the soft update rule to update the target networks.
        # Note: this yield has to be the last one, since SoftUpdate is the last
        # optimizer added in the configure_optimizers() function.
        result = self.soft_update_result()
        yield result
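
As a minimal sketch of the one-hot action convention that the docstring and the
logged_action_idxs computation rely on (the tensor names and num_actions below
are illustrative only, not part of the ReAgent API):

    import torch
    import torch.nn.functional as F

    num_actions = 4
    # Integer action ids as logged, shape (batch_size,)
    logged_ids = torch.tensor([2, 0, 3])
    # One-hot representation, shape (batch_size, num_actions); this is the
    # format train_step_gen() expects in training_batch.action
    action = F.one_hot(logged_ids, num_classes=num_actions).float()
    # argmax recovers the integer index, matching logged_action_idxs above
    logged_action_idxs = torch.argmax(action, dim=1, keepdim=True)
    assert torch.equal(logged_action_idxs.squeeze(1), logged_ids)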