in lib/xf.py [0:0]
def forward(self, x):
    """Apply each sub-module to its own slice of the sequence axis and rejoin.

    Per the reshape patterns below, the second-to-last axis of ``x`` has
    length ``z * tl`` where ``tl = sum(self.seqlens)``. That axis is viewed
    as ``(z, tl)``, and the ``tl`` axis is split into consecutive chunks of
    lengths ``self.seqlens``.

    Chunk ``i`` (length ``l_i``) is handled in four steps:

    1. Flatten it to ``(..., z * l_i, e)``.
    2. Pass it through the matching module in ``self.mods``.
    3. Restore it to ``(..., z, l_i, e)``.
    4. Concatenate all chunks back along the ``tl`` axis, then undo the
       outer reshape.

    Args:
        x: tensor whose last two axes follow the pattern ``(z*tl, e)``.

    Returns:
        Tensor with the same layout as ``x``. This presumes every module in
        ``self.mods`` returns a tensor shaped like its input, since each
        chunk's flattening is undone on the module's output (TODO confirm).
    """
    tl = sum(self.seqlens)
    x, undo = misc.reshape_undo(x, "..., z*tl, e", "..., z, tl, e", tl=tl)
    # After the reshape, dim -2 is the tl axis; split it per self.seqlens.
    chunks = list(th.split(x, self.seqlens, dim=-2))
    outputs = []
    # presumably safezip raises if len(chunks) != len(self.mods) — verify in misc
    for chunk, mod in misc.safezip(chunks, self.mods):
        flat, chunk_undo = misc.reshape_undo(chunk, "..., z, l, e", "..., z*l, e")
        outputs.append(chunk_undo(mod(flat)))
    out = th.cat(outputs, dim=-2)
    return undo(out)