def train_step_gen()

in reagent/training/sac_trainer.py [0:0]


    def train_step_gen(self, training_batch: rlt.PolicyNetworkInput, batch_idx: int):
        """
        IMPORTANT: the input action here is assumed to match the
        range of the output of the actor.
        """

        assert isinstance(training_batch, rlt.PolicyNetworkInput)

        state = training_batch.state
        action = training_batch.action
        reward = training_batch.reward
        discount = torch.full_like(reward, self.gamma)
        not_done_mask = training_batch.not_terminal

        #
        # First, optimize Q networks; minimizing MSE between
        # Q(s, a) & r + discount * V'(next_s)
        #

        if self.value_network is not None:
            next_state_value = self.value_network_target(training_batch.next_state)
        else:
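            # No value network: bootstrap the soft state value directly from the
            # target Q-networks, using an action sampled from the current actor:
            # V'(s') = min_i Q_i'(s', a') - entropy_temperature * log pi(a'|s').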
            next_state_actor_output = self.actor_network(training_batch.next_state)
            next_state_actor_action = (
                training_batch.next_state,
                rlt.FeatureData(next_state_actor_output.action),
            )
            next_state_value = self.q1_network_target(*next_state_actor_action)

            if self.q2_network is not None:
                target_q2_value = self.q2_network_target(*next_state_actor_action)
                next_state_value = torch.min(next_state_value, target_q2_value)

            log_prob_a = self.actor_network.get_log_prob(
                training_batch.next_state, next_state_actor_output.action
            ).clamp(LOG_PROB_MIN, LOG_PROB_MAX)
            next_state_value -= self.entropy_temperature * log_prob_a

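        # Bellman target: r + gamma * V'(s'), with the bootstrap term zeroed out
        # for terminal transitions via not_done_mask.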
        if self.gamma > 0.0:
            target_q_value = (
                reward + discount * next_state_value * not_done_mask.float()
            )
        else:
            # This is useful for debugging instability issues
            target_q_value = reward

        q1_value = self.q1_network(state, action)
        q1_loss = F.mse_loss(q1_value, target_q_value)
        yield q1_loss

        if self.q2_network:
            q2_value = self.q2_network(state, action)
            q2_loss = F.mse_loss(q2_value, target_q_value)
            yield q2_loss

        # Second, optimize the actor; minimizing the KL-divergence between the
        # policy and the softmax of the Q-values. Thanks to the reparameterization
        # trick, the loss reduces to
        # entropy_temperature * log_prob(actor_action) - Q(s, actor_action)

        actor_output = self.actor_network(state)

        state_actor_action = (state, rlt.FeatureData(actor_output.action))
        q1_actor_value = self.q1_network(*state_actor_action)
        min_q_actor_value = q1_actor_value
        if self.q2_network:
            q2_actor_value = self.q2_network(*state_actor_action)
            min_q_actor_value = torch.min(q1_actor_value, q2_actor_value)

        actor_log_prob = actor_output.log_prob.clamp(LOG_PROB_MIN, LOG_PROB_MAX)

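        # Optionally detach log_prob so gradients do not flow through the policy's
        # log-density; in the SAC branch below the actor is then updated only via
        # the Q term evaluated on the reparameterized action.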
        if not self.backprop_through_log_prob:
            actor_log_prob = actor_log_prob.detach()

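        # CRR-style actor update: weight the policy's log-probability by a function
        # of the advantage (min Q minus the state value), with the weight detached
        # so gradients flow only through the policy.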
        if self.crr_config is not None:
            cur_value = self.value_network(training_batch.state)
            advantage = (min_q_actor_value - cur_value).detach()
            # pyre-fixme[16]: `Optional` has no attribute `get_weight_from_advantage`.
            crr_weight = self.crr_config.get_weight_from_advantage(advantage)
            assert (
                actor_log_prob.shape == crr_weight.shape
            ), f"{actor_log_prob.shape} != {crr_weight.shape}"
            actor_loss = -(actor_log_prob * crr_weight.detach())
        else:
            actor_loss = self.entropy_temperature * actor_log_prob - min_q_actor_value
        # Do this in 2 steps so we can log histogram of actor loss
        actor_loss_mean = actor_loss.mean()

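        # Optional regularizer: closed-form KL divergence between a diagonal Gaussian
        # fit to the batch of actions (or squashed means) and a Gaussian prior with
        # mean action_emb_mean and variance action_emb_variance, summed over action
        # dimensions.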
        if self.add_kld_to_loss:
            if self.apply_kld_on_mean:
                action_batch_m = torch.mean(actor_output.squashed_mean, axis=0)
                action_batch_v = torch.var(actor_output.squashed_mean, axis=0)
            else:
                action_batch_m = torch.mean(actor_output.action, axis=0)
                action_batch_v = torch.var(actor_output.action, axis=0)
            kld = (
                0.5
                * (
                    (action_batch_v + (action_batch_m - self.action_emb_mean) ** 2)
                    / self.action_emb_variance
                    - 1
                    + self.action_emb_variance.log()
                    - action_batch_v.log()
                ).sum()
            )

            actor_loss_mean += self.kld_weight * kld

        yield actor_loss_mean

        # Optimize Alpha
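        # Automatic entropy-temperature tuning: alpha is adjusted so the policy's
        # entropy tracks target_entropy; when log_prob is too high (entropy too low),
        # the gradient on log_alpha increases alpha and strengthens the entropy bonus.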
        if self.alpha_optimizer is not None:
            alpha_loss = -(
                (
                    self.log_alpha
                    * (
                        actor_output.log_prob.clamp(LOG_PROB_MIN, LOG_PROB_MAX)
                        + self.target_entropy
                    ).detach()
                ).mean()
            )
            yield alpha_loss
            self.entropy_temperature = self.log_alpha.exp()

        #
        # Lastly, if applicable, optimize the value network; minimizing MSE between
        # V(s) & E_{a ~ pi(.|s)} [ Q(s, a) - entropy_temperature * log(pi(a|s)) ]
        #

        if self.value_network is not None:
            state_value = self.value_network(state)

            if self.logged_action_uniform_prior:
                log_prob_a = torch.zeros_like(min_q_actor_value)
                target_value = min_q_actor_value
            else:
                log_prob_a = actor_output.log_prob.clamp(LOG_PROB_MIN, LOG_PROB_MAX)
                target_value = min_q_actor_value - self.entropy_temperature * log_prob_a

            value_loss = F.mse_loss(state_value, target_value.detach())
            yield value_loss

        self.logger.log_metrics(
            {
                "td_loss": q1_loss,
                "logged_rewards": reward.mean(),
                "model_values_on_logged_actions": q1_value.mean(),
                "q1_value": q1_value.mean(),
                "entropy_temperature": self.entropy_temperature,
                "log_prob_a": log_prob_a.mean(),
                "next_state_value": next_state_value.mean(),
                "target_q_value": target_q_value.mean(),
                "min_q_actor_value": min_q_actor_value.mean(),
                "actor_output_log_prob": actor_output.log_prob.mean(),
                "actor_loss": actor_loss.mean(),
            },
            step=self.all_batches_processed,
        )
        if self.q2_network:
            self.logger.log_metrics(
                {"q2_value": q2_value.mean()},
                step=self.all_batches_processed,
            )

        if self.value_network:
            self.logger.log_metrics(
                {"target_state_value": target_value.mean()},
                step=self.all_batches_processed,
            )

        if self.add_kld_to_loss:
            self.logger.log_metrics(
                {
                    "action_batch_mean": action_batch_m.mean(),
                    "action_batch_var": action_batch_v.mean(),
                    # pyre-fixme[61]: `kld` may not be initialized here.
                    "kld": kld,
                },
                step=self.all_batches_processed,
            )

        # Use the soft update rule to update the target networks
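        # (Soft update here means Polyak averaging of the online networks into the
        # target networks; the update itself is applied when this result is consumed.)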
        result = self.soft_update_result()
        self.log("td_loss", q1_loss, prog_bar=True)
        yield result
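
A minimal usage sketch, assuming an already-constructed trainer and an
rlt.PolicyNetworkInput batch (both hypothetical names below, not defined in this
file); bracketed entries are yielded only when the corresponding network or
optimizer is configured:

    # Hypothetical sketch for illustration only, not ReAgent's actual training loop.
    losses = list(trainer.train_step_gen(training_batch, batch_idx=0))
    # Yield order: q1_loss, [q2_loss], actor_loss_mean, [alpha_loss], [value_loss],
    # followed by the soft-update result.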