def update_critic_mve()

in svg/agent.py


    def update_critic_mve(self, first_xs, first_us, first_rs, next_xs, first_not_dones, logger, step):
        """ MVE critic loss from Feinberg et al (2015) """
        assert first_xs.dim() == 2
        assert first_us.dim() == 2
        assert first_rs.dim() == 2
        assert next_xs.dim() == 2
        assert first_not_dones.dim() == 2
        n_batch = next_xs.size(0)

        # unroll policy, concatenate obs and actions
        pred_us, log_p_us, pred_xs = self.dx.unroll_policy(
            next_xs, self.actor, sample=True, detach_xt=self.actor_detach_rho)
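        # Inferred shapes (H = model rollout horizon): pred_us [H, n_batch, act_dim],
        # log_p_us [H, n_batch], pred_xs [H - 1, n_batch, obs_dim].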
        all_obs = torch.cat((first_xs.unsqueeze(0), next_xs.unsqueeze(0), pred_xs))
        all_us = torch.cat([first_us.unsqueeze(0), pred_us])
        xu = torch.cat([all_obs, all_us], dim=2)
        horizon_len = all_obs.size(0) - 1  # H
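        # all_obs and all_us are aligned along dim 0 so (all_obs[i], all_us[i]) is the
        # i-th state-action pair of the rollout, i = 0..H, with all_obs[0] = first_xs
        # and all_obs[1] = next_xs; xu is [H + 1, n_batch, obs_dim + act_dim].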

        # get immediate rewards
        pred_rs = self.rew(xu[1:-1])  # t from 0 to H - 1
        rewards = torch.cat([first_rs.unsqueeze(0), pred_rs]).squeeze(2)
        rewards = rewards.unsqueeze(1).expand(-1, horizon_len, -1)
        log_p_us = log_p_us.unsqueeze(1).expand(-1, horizon_len, -1)
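        # rewards and log_p_us are now [H, H, n_batch]: dim 0 indexes the rollout step,
        # dim 1 repeats that value once per critic target (one target per start step).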

        # get not dones factor matrix, rows --> t, cols --> k
        first_not_dones = first_not_dones.unsqueeze(0)
        init_not_dones = torch.ones_like(first_not_dones)  # we know the first states are not terminal
        pred_not_dones = 1. - self.done(xu[2:]).sigmoid()  # t from 1 to H
        not_dones = torch.cat([init_not_dones, first_not_dones, pred_not_dones]).squeeze(2)
        not_dones = not_dones.unsqueeze(1).repeat(1, horizon_len, 1)
        triu_rows, triu_cols = torch.triu_indices(row=horizon_len + 1, col=horizon_len, offset=1, device=not_dones.device)
        not_dones[triu_rows, triu_cols, :] = 1.
        not_dones = not_dones.cumprod(dim=0).detach()
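        # With the upper triangle overwritten by ones before the cumulative product,
        # not_dones[i, j] is the product of the not-done flags from step j through step i
        # (and 1 for i < j): the probability that the rollout behind target j is still
        # alive at step i.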

        # get lower-triangular reward discount factor matrix
        discount = torch.tensor(self.discount, device=rewards.device)
        discount_exps = torch.stack([torch.arange(-i, -i + horizon_len) for i in range(horizon_len)], dim=1)
        r_discounts = discount ** discount_exps.to(rewards.device)
        r_discounts = r_discounts.tril().unsqueeze(-1)
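        # r_discounts[k, t] = discount ** (k - t) for k >= t and 0 above the diagonal,
        # so column t discounts rewards relative to its own start step.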

        # get discounted sums of soft rewards (t from -1 to H - 1 (k from t to H - 1))
        alpha = self.temp.alpha.detach()
        soft_rewards = (not_dones[:-1] * rewards) - (discount * alpha * not_dones[1:] * log_p_us)
        soft_rewards = (r_discounts * soft_rewards).sum(0)
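        # soft_rewards is now [H, n_batch]: for each target, the discounted sum of
        # entropy-regularized rewards from its own step up to step H - 1.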

        # get target q-values, final critic targets
        target_q1, target_q2 = self.critic_target(all_obs[-1], all_us[-1])
        target_qs = torch.min(target_q1, target_q2).squeeze(-1).expand(horizon_len, -1)
        q_discounts = discount ** torch.arange(horizon_len, 0, step=-1).to(target_qs.device)
        target_qs = target_qs * (not_dones[-1] * q_discounts.unsqueeze(-1))
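        # target_qs[t] = discount ** (H - t) * (survival probability through step H)
        #                * min(Q1, Q2) at the final rollout state-action pair.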
        critic_targets = (soft_rewards + target_qs).detach()

        # get predicted q-values
        with utils.eval_mode(self.critic):
            q1, q2 = self.critic(all_obs[:-1].flatten(end_dim=-2),
                                 all_us[:-1].flatten(end_dim=-2))
            q1, q2 = q1.reshape(horizon_len, n_batch), q2.reshape(horizon_len, n_batch)
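        # q1, q2 hold one prediction per (state, action) pair along the rollout except
        # the last, matching the H rows of critic_targets.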
        assert q1.size() == critic_targets.size()
        assert q2.size() == critic_targets.size()

        # update critics
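        # The squared errors are weighted by not_dones[:-1, 0], the survival probability
        # of each step starting from the true initial state, so steps past a predicted
        # termination contribute nothing to the loss.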
        q1_loss = (not_dones[:-1, 0] * (q1 - critic_targets).pow(2)).mean()
        q2_loss = (not_dones[:-1, 0] * (q2 - critic_targets).pow(2)).mean()
        Q_loss = q1_loss + q2_loss

        logger.log('train_critic/Q_loss', Q_loss, step)
        current_Q = torch.min(q1, q2)
        logger.log('train_critic/value', current_Q.mean(), step)

        self.critic_opt.zero_grad()
        Q_loss.backward()
        self.critic_opt.step()

        self.critic.log(logger, step)
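
For reference, the vectorized target construction above can be read as an unvectorized per-target loop. The sketch below is illustrative only and not part of the agent: the function and argument names are made up, and it assumes the shapes noted in the comments (rewards [H, n_batch], policy log-probs [H, n_batch], not-done flags [H + 1, n_batch] with the flag for the true initial state fixed to one, and the min of the two target critics at the final rollout state-action pair as a [n_batch] vector).

    import torch

    def mve_critic_targets_naive(rewards, log_p_us, not_dones, q_min, discount, alpha):
        """Unvectorized sketch of the critic targets built in update_critic_mve.

        rewards:   [H, n_batch]     rewards along the rollout
        log_p_us:  [H, n_batch]     log-probs of the sampled rollout actions
        not_dones: [H + 1, n_batch] not-done flags per step (not_dones[0] is all ones)
        q_min:     [n_batch]        min(Q1, Q2) of the target critic at the final step
        """
        H, n_batch = rewards.shape
        targets = torch.zeros(H, n_batch)
        for t in range(H):                   # one target per rollout step
            alive = not_dones[t].clone()     # includes the flag at the target's own step
            acc = torch.zeros(n_batch)
            for k in range(t, H):
                # discounted reward at step k, weighted by survival from step t to step k
                acc = acc + discount ** (k - t) * alive * rewards[k]
                alive = alive * not_dones[k + 1]
                # soft (entropy) term, weighted by survival through step k + 1
                acc = acc - discount ** (k - t + 1) * alpha * alive * log_p_us[k]
            # bootstrap with the target critic value at the end of the rollout
            acc = acc + discount ** (H - t) * alive * q_min
            targets[t] = acc
        return targets.detach()

The matrix form in the method computes all H of these sums at once: the triu_indices masking plus cumprod produces the per-column survival products, and the lower-triangular r_discounts matrix carries the discount ** (k - t) factors.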