in svg/dx.py [0:0]
def unroll(self, x, us, detach_xt=False):
    """Roll the learned dynamics forward over a sequence of controls.

    Starting from ``x``, each step concatenates the current state with the
    control ``us[t]``, encodes it with ``self.xu_enc``, optionally passes the
    embedding through the recurrent core ``self.rec`` and decodes a residual
    with ``self.x_dec`` that is added to the current state.

    Args:
        x: 2-D tensor of initial states; ``x.size(0)`` is the batch size
            (presumably ``(n_batch, x_dim)`` -- TODO confirm).
        us: 3-D tensor of controls, iterated along dim 0, with
            ``us.size(1) == n_batch`` (i.e. ``(T, n_batch, u_dim)``).
        detach_xt: if True, the state fed into each step is detached, so
            gradients from a step do not flow back through its input state.

    Returns:
        Tensor stacking the ``T`` predicted states along dim 0, where
        ``pred_xs[t]`` is the state after applying ``us[t]``. When ``T == 0``
        an empty tensor of shape ``(0, *x.shape)`` (same dtype/device as
        ``x``) is returned.

    Raises:
        AssertionError: if ``x``/``us`` have the wrong rank or batch size.
    """
    assert x.dim() == 2
    assert us.dim() == 3
    n_batch = x.size(0)
    assert us.size(1) == n_batch

    horizon = us.size(0)
    if horizon == 0:
        # torch.stack([]) raises on an empty list; return an empty rollout
        # carrying x's dtype/device instead of crashing.
        return x.new_empty((0, *x.shape))

    if self.freeze_dims is not None:
        # Frozen dims are pinned to their *initial* values at every step.
        obs_frozen = x[:, self.freeze_dims]
    if self.rec_num_layers > 0:
        h = self.init_hidden_state(x)

    pred_xs = []
    xt = x
    for t in range(horizon):
        ut = us[t]
        if detach_xt:
            xt = xt.detach()
        xut = torch.cat((xt, ut), dim=1)
        # Add a leading length-1 axis before the recurrent core; self.rec
        # returns (output, new_hidden) -- looks like a seq-first
        # nn.GRU/nn.LSTM taking one step at a time (confirm).
        xu_emb = self.xu_enc(xut).unsqueeze(0)
        if self.rec_num_layers > 0:
            xtp1_emb, h = self.rec(xu_emb, h)
        else:
            xtp1_emb = xu_emb
        # Residual update: the decoder predicts a delta on the current state.
        xtp1 = xt + self.x_dec(xtp1_emb.squeeze(0))
        if self.freeze_dims is not None:
            # xtp1 is a fresh tensor from the addition above, so this
            # in-place write does not alias x or earlier predictions.
            xtp1[:, self.freeze_dims] = obs_frozen
        pred_xs.append(xtp1)
        xt = xtp1
    return torch.stack(pred_xs)