def expand_Q()

in svg/agent.py [0:0]


    def expand_Q(self, xs, critic, sample=True, discount=False):
        """Estimate a soft Q-value for start states `xs` by unrolling the
        learned dynamics model with the current policy for `self.horizon`
        steps, summing model-predicted rewards (entropy-regularized), and
        optionally bootstrapping the final step with `critic`.

        Args:
            xs: batch of start states, shape (n_batch, obs_dim).
            critic: twin-Q critic used to bootstrap the terminal value, or
                None to use only model-predicted rewards.
            sample: whether the policy is sampled stochastically during the
                model rollout (forwarded to `self.dx.unroll_policy`).
            discount: if True, weight per-step rewards by
                `self.discount_horizon`.

        Returns:
            total_rewards: shape (n_batch,), the Q estimate per start state.
            first_log_p: shape (n_batch,), log-prob of the first action.
            total_log_p_us: shape (n_batch,), sum of log-probs over the
                horizon.
        """
        assert xs.dim() == 2
        n_batch = xs.size(0)

        # Roll the policy through the learned dynamics model.
        # Assumes time-major outputs: us (horizon, n_batch, act_dim),
        # log_p_us (horizon, n_batch), pred_obs (horizon-1 steps of predicted
        # next states) — TODO confirm against dx.unroll_policy.
        us, log_p_us, pred_obs = self.dx.unroll_policy(
            xs, self.actor, sample=sample, detach_xt=self.actor_detach_rho)

        # Prepend the real start states so all_obs aligns step-by-step with us.
        all_obs = torch.cat((xs.unsqueeze(0), pred_obs), dim=0)
        xu = torch.cat((all_obs, us), dim=2)

        # Per-step termination probabilities; accumulate the survival
        # probability over time so reward after a predicted termination is
        # masked out.
        dones = self.done(xu).sigmoid().squeeze(dim=2)
        not_dones = 1. - dones
        not_dones = utils.accum_prod(not_dones)
        last_not_dones = not_dones[-1]

        rewards = not_dones * self.rew(xu).squeeze(2)
        if critic is not None:
            # Replace the last-step reward with a terminal value estimate
            # (clipped double-Q), weighted by the survival probability.
            with utils.eval_mode(critic):
                q1, q2 = critic(all_obs[-1], us[-1])
            q = torch.min(q1, q2).reshape(n_batch)
            rewards[-1] = last_not_dones * q

        assert rewards.size() == (self.horizon, n_batch)
        assert log_p_us.size() == (self.horizon, n_batch)

        # Soft actor-critic entropy bonus (alpha detached: temperature is
        # trained separately).
        rewards -= self.temp.alpha.detach() * log_p_us

        if discount:
            rewards *= self.discount_horizon.unsqueeze(1)

        total_rewards = rewards.sum(dim=0)

        first_log_p = log_p_us[0]
        # BUG FIX: the previous `.squeeze()` collapsed the (n_batch,) sum to
        # a 0-dim scalar when n_batch == 1; keep the batch dim explicit so
        # the shape always matches total_rewards.
        total_log_p_us = log_p_us.sum(dim=0).reshape(n_batch)
        return total_rewards, first_log_p, total_log_p_us