def update_actor_and_alpha()

in svg/agent.py [0:0]


    def update_actor_and_alpha(self, xs, logger, step):
        assert xs.ndimension() == 2
        n_batch, _ = xs.size()

        # Use a model-free actor update when the model is still warming up,
        # model-based value expansion is disabled (horizon of 0 or actor_mve
        # off), or the rolling dynamics loss is above its threshold.
        do_model_free_update = (
            step < self.warmup_steps
            or self.horizon == 0
            or not self.actor_mve
            or (self.actor_dx_threshold is not None
                and self.rolling_dx_loss is not None
                and self.rolling_dx_loss > self.actor_dx_threshold)
        )

        if do_model_free_update:
            # Do vanilla SAC updates while the model warms up.
            # i.e., fit to just the Q function
            _, pi, first_log_p = self.actor(xs)
            actor_Q1, actor_Q2 = self.critic(xs, pi)
            actor_Q = torch.min(actor_Q1, actor_Q2)
            actor_loss = (self.temp.alpha.detach() * first_log_p - actor_Q).mean()
        else:
            # Switch to the model-based updates.
            # i.e., fit to the controller's sequence cost
            rewards, first_log_p, total_log_p_us = self.expand_Q(
                xs, self.critic, sample=True, discount=True)
            assert total_log_p_us.size() == rewards.size()
            # Negate so that minimizing the loss maximizes the per-step expanded value.
            actor_loss = -(rewards / self.horizon).mean()

        logger.log('train_actor/loss', actor_loss, step)
        logger.log('train_actor/entropy', -first_log_p.mean(), step)

        self.actor_opt.zero_grad()
        actor_loss.backward()
        self.actor_opt.step()

        self.actor.log(logger, step)
        # Update the temperature alpha from the fresh action log-probabilities.
        self.temp.update(first_log_p, logger, step)

        logger.log('train_alpha/value', self.temp.alpha, step)
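
For reference, below is a self-contained sketch of the model-free branch's SAC-style actor loss. The `ToyActor` and `ToyCritic` classes are hypothetical stand-ins, not the repository's actor and critic modules; the sketch only illustrates the shapes and the min-over-critics loss used in the `do_model_free_update` branch above.

    # Minimal sketch of the model-free actor loss; ToyActor/ToyCritic are
    # placeholders and not part of svg/agent.py.
    import torch
    import torch.nn as nn

    class ToyActor(nn.Module):
        def __init__(self, obs_dim, act_dim):
            super().__init__()
            self.net = nn.Linear(obs_dim, act_dim)

        def forward(self, xs):
            mu = self.net(xs)
            dist = torch.distributions.Normal(mu, torch.ones_like(mu))
            pi = dist.rsample()                             # reparameterized sample
            log_p = dist.log_prob(pi).sum(-1, keepdim=True)  # [n_batch, 1]
            return mu, pi, log_p

    class ToyCritic(nn.Module):
        def __init__(self, obs_dim, act_dim):
            super().__init__()
            self.q1 = nn.Linear(obs_dim + act_dim, 1)
            self.q2 = nn.Linear(obs_dim + act_dim, 1)

        def forward(self, xs, acts):
            xu = torch.cat([xs, acts], dim=-1)
            return self.q1(xu), self.q2(xu)

    obs_dim, act_dim, n_batch = 8, 2, 32
    actor, critic = ToyActor(obs_dim, act_dim), ToyCritic(obs_dim, act_dim)
    alpha = torch.tensor(0.1)                # fixed temperature for the sketch
    xs = torch.randn(n_batch, obs_dim)

    _, pi, log_p = actor(xs)
    q1, q2 = critic(xs, pi)
    q = torch.min(q1, q2)                    # pessimistic value estimate
    actor_loss = (alpha.detach() * log_p - q).mean()
    actor_loss.backward()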