in svg/agent.py [0:0]
def expand_Q(self, xs, critic, sample=True, discount=False):
    assert xs.dim() == 2
    n_batch = xs.size(0)

    # Unroll the policy through the learned dynamics model, collecting the
    # actions, their log-probabilities, and the predicted observations.
    us, log_p_us, pred_obs = self.dx.unroll_policy(
        xs, self.actor, sample=sample, detach_xt=self.actor_detach_rho)

    # Prepend the initial observations so the rollout covers every step of the
    # horizon, then join observations and actions along the feature dimension.
    all_obs = torch.cat((xs.unsqueeze(0), pred_obs), dim=0)
    xu = torch.cat((all_obs, us), dim=2)

    # Predict per-step termination probabilities and accumulate them into a
    # cumulative "still alive" mask over the horizon.
    dones = self.done(xu).sigmoid().squeeze(dim=2)
    not_dones = 1. - dones
    not_dones = utils.accum_prod(not_dones)
    last_not_dones = not_dones[-1]

    # Predicted rewards, masked by the termination probabilities.
    rewards = not_dones * self.rew(xu).squeeze(2)

    # Bootstrap the final step with the critic's clipped double-Q estimate
    # when a critic is provided.
    if critic is not None:
        with utils.eval_mode(critic):
            q1, q2 = critic(all_obs[-1], us[-1])
        q = torch.min(q1, q2).reshape(n_batch)
        rewards[-1] = last_not_dones * q

    assert rewards.size() == (self.horizon, n_batch)
    assert log_p_us.size() == (self.horizon, n_batch)

    # Soft value: subtract the entropy temperature times the action log-probs.
    rewards -= self.temp.alpha.detach() * log_p_us

    # Optionally apply the per-step discount factors.
    if discount:
        rewards *= self.discount_horizon.unsqueeze(1)

    total_rewards = rewards.sum(dim=0)
    first_log_p = log_p_us[0]
    total_log_p_us = log_p_us.sum(dim=0).squeeze()
    return total_rewards, first_log_p, total_log_p_us
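
# A minimal sketch of how utils.accum_prod is assumed to behave here: it turns
# the per-step survival probabilities into a cumulative mask along the time
# (first) dimension, out[t] = x[0] * x[1] * ... * x[t]. The actual helper lives
# in the repo's utils module and may differ in detail.
import torch

def accum_prod(x):
    # Cumulative product along dim 0 (time).
    return torch.cumprod(x, dim=0)


# Hypothetical call site for expand_Q inside an actor update, shown only to
# illustrate the returned quantities; `agent`, `obs_batch`, and `actor_opt`
# are placeholder names, not taken from the repository.
total_rewards, first_log_p, total_log_p_us = agent.expand_Q(
    obs_batch, agent.critic, sample=True, discount=True)
actor_loss = -total_rewards.mean()  # ascend the short-horizon soft value estimate
agent.actor_opt.zero_grad()
actor_loss.backward()
agent.actor_opt.step()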