def update_critic()

in svg/agent.py [0:0]
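
Computes the critic update for one replay batch: builds a soft Bellman target (the standard one-step SAC target, or a model-based value-expansion target from expand_Q once warmup has passed) and regresses the twin Q networks onto it.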


    def update_critic(self, xs, xps, us, rs, not_done, logger, step):
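        # xs: current states, xps: next states, us: actions, rs: rewards,
        # not_done: continuation mask (zero where the episode terminated).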
        assert xs.ndimension() == 2
        n_batch, _ = xs.size()
        rs = rs.squeeze()
        not_done = not_done.squeeze()

        with torch.no_grad():
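            # Standard one-step SAC target: used when MVE critic targets are
            # disabled or while still inside the warmup period.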
            if not self.critic_target_mve or step < self.warmup_steps:
                mu, target_us, log_pi = self.actor.forward(
                    xps, compute_pi=True, compute_log_pi=True)
                log_pi = log_pi.squeeze(1)

                target_Q1, target_Q2 = [
                    Q.squeeze(1) for Q in self.critic_target(xps, target_us)]
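                # Soft next-state value: min of the twin target Qs minus the
                # entropy term scaled by the (detached) temperature alpha.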
                target_Q = torch.min(target_Q1, target_Q2) - self.temp.alpha.detach() * log_pi
                assert target_Q.size() == rs.size()
                assert target_Q.ndimension() == 1
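                # Bellman backup; not_done zeroes the bootstrap at terminals.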
                target_Q = rs + not_done * self.discount * target_Q
                target_Q = target_Q.detach()
            else:
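                # Model-based value expansion (MVE): expand_Q returns a
                # discounted multi-step target estimate under critic_target,
                # plus first-step and total action log-probabilities.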
                target_Q, first_log_p, total_log_p_us = self.expand_Q(
                    xps, self.critic_target, sample=True, discount=True)
                target_Q = target_Q - self.temp.alpha.detach() * first_log_p
                target_Q = rs + not_done * self.discount * target_Q
                target_Q = target_Q.detach()

        # Twin Q estimates for the sampled state-action pairs.
        current_Q1, current_Q2 = [Q.squeeze(1) for Q in self.critic(xs, us)]
        assert current_Q1.size() == target_Q.size()
        assert current_Q2.size() == target_Q.size()
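        # Regress both critics onto the same Bellman target.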
        Q_loss = F.mse_loss(current_Q1, target_Q) + \
            F.mse_loss(current_Q2, target_Q)

        logger.log('train_critic/Q_loss', Q_loss, step)
        current_Q = torch.min(current_Q1, current_Q2)
        logger.log('train_critic/value', current_Q.mean(), step)

        # Gradient step on the critic parameters.
        self.critic_opt.zero_grad()
        Q_loss.backward()
        self.critic_opt.step()

        self.critic.log(logger, step)
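
A minimal call sketch for reference; `replay_buffer`, `agent`, `batch_size`, and the `sample` helper below are illustrative assumptions, not part of svg/agent.py. It only shows the argument order update_critic expects:

    # Hypothetical sampling helper; names are illustrative only.
    xs, us, rs, xps, not_done = replay_buffer.sample(batch_size)
    agent.update_critic(xs, xps, us, rs, not_done, logger, step)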