in svg/dx.py [0:0]
def unroll_policy(self, init_x, policy, sample=True,
                  last_u=True, detach_xt=False):
    """Roll ``policy`` forward through the learned dynamics model.

    Starting from ``init_x``, each step does three things:
      1. Query ``policy`` for an action.
      2. Pass the concatenated (state, action) through ``self.xu_enc``,
         then ``self.rec`` (only when ``self.rec_num_layers > 0``), then
         ``self.x_dec``.
      3. Add the decoded output to the current state to form the next
         predicted state.

    ``horizon - 1`` such transitions are simulated.

    Args:
        init_x: 2-D tensor of initial states whose dim 0 is the batch
            (presumably ``(n_batch, obs_dim)`` -- confirm against callers).
        policy: callable ``policy(x)`` returning a 3-tuple.
            - With ``sample=True`` the second element is used as the
              action; otherwise the first element is.
            - The two are presumably the sampled and deterministic
              actions -- confirm against the policy class.
            - The third element is collected as the action log-probability.
        sample: selects which of the policy's first two outputs is the action.
        last_u: if True, the policy is also queried at the final predicted
            state, so one more action than transitions is returned.
        detach_xt: if True, each state is detached from the autograd graph
            after the policy has been queried on it.
            - The detach happens before the state is concatenated with the
              action for the dynamics model.
            - The detached state is also the base that the decoded output
              is added to.

    Returns:
        Tuple ``(us, log_p_us, pred_xs)``:
          * ``us`` -- actions stacked along a new leading time dimension,
            of length ``max(horizon - 1, 0) + int(last_u)``.
          * ``log_p_us`` -- log-probs stacked the same way, with dim 2
            squeezed out. This assumes the policy returns ``(n_batch, 1)``
            log-probs -- TODO confirm.
          * ``pred_xs`` -- predicted next states stacked along a new leading
            time dimension; ``init_x`` itself is not included. When
            ``horizon <= 1`` this is an empty ``(0, n_batch, obs_dim)``
            tensor with ``init_x``'s dtype and device.

    Raises:
        AssertionError: if ``init_x`` is not 2-D.
        RuntimeError: if no action would be produced
            (``horizon <= 1`` together with ``last_u=False``).
    """
    assert init_x.dim() == 2, 'init_x must be a 2-D tensor'
    n_batch = init_x.size(0)

    if self.freeze_dims is not None:
        # Columns of init_x that are copied back into every predicted state.
        obs_frozen = init_x[:, self.freeze_dims]

    if self.rec_num_layers > 0:
        h = self.init_hidden_state(init_x)

    def act(x):
        # Pick the action slot of the policy output according to `sample`.
        if sample:
            _, u, log_p = policy(x)
        else:
            u, _, log_p = policy(x)
        return u, log_p

    pred_xs = []
    us = []
    log_p_us = []
    xt = init_x
    for _ in range(self.horizon - 1):
        ut, log_p_ut = act(xt)
        us.append(ut)
        log_p_us.append(log_p_ut)

        # The policy above saw the non-detached state. Only the dynamics
        # input and the residual base below use the detached one.
        if detach_xt:
            xt = xt.detach()

        xut = torch.cat((xt, ut), dim=1)
        # Leading length-1 dim, presumably the sequence dim self.rec expects.
        xu_emb = self.xu_enc(xut).unsqueeze(0)
        if self.rec_num_layers > 0:
            xtp1_emb, h = self.rec(xu_emb, h)
        else:
            xtp1_emb = xu_emb
        # Next state = current state + decoded output.
        xtp1 = xt + self.x_dec(xtp1_emb.squeeze(0))
        if self.freeze_dims is not None:
            xtp1[:, self.freeze_dims] = obs_frozen

        pred_xs.append(xtp1)
        xt = xtp1

    if last_u:
        ut, log_p_ut = act(xt)
        us.append(ut)
        log_p_us.append(log_p_ut)

    if not us:
        raise RuntimeError(
            'unroll_policy produced no actions: horizon <= 1 and last_u=False')

    us = torch.stack(us)
    log_p_us = torch.stack(log_p_us).squeeze(2)
    if pred_xs:
        pred_xs = torch.stack(pred_xs)
    else:
        # new_empty inherits init_x's dtype and device, matching the
        # non-empty case (torch.empty(...).to(device) was always float32).
        pred_xs = init_x.new_empty((0, n_batch, self.obs_dim))

    return us, log_p_us, pred_xs