in lib/xf.py [0:0]
def _preproc(self, x, name, Q_t=None, Q_pad=None):
    """Pad and reshape one attention operand into ``self.maxlen``-sized chunks.

    The operand is first reshaped per ``"b, t*stride, e" -> "b, 1, t, stride*e"``,
    padded with ``SENTINEL`` along the time axis (dim 2), and finally flattened to
    ``"b*t*h*stride, pad_shift*maxlen, q"``.

    The query (``name == "Q"``) must be processed first: its call computes the
    padded length ``Q_t`` and the padding amount ``Q_pad``. Any other operand
    (presumably keys/values -- confirm at call sites) must then be called with
    those two values.

    Args:
        x: Operand tensor. The first reshape pattern requires it to be laid
            out as ``(b, t*stride, e)``.
        name: ``"Q"`` selects the query path; any other value selects the
            key/value path.
        Q_t: Padded query length returned by the ``"Q"`` call. Only read (and
            required) when ``name != "Q"``.
        Q_pad: Trailing padding returned by the ``"Q"`` call. Only read (and
            required) when ``name != "Q"``.

    Returns:
        For ``name == "Q"``: ``(x, undo, Q_t, Q_pad)``. ``undo`` composes the
        inverse reshapes/transpose with a slice that removes the trailing
        padding again.
        Otherwise: ``x`` only. Its ``pad_shift`` axis has size 2: index 0 is
        the ``Q_t``-step window ending ``maxlen`` steps before the end
        ("back"), index 1 is the last ``Q_t`` steps ("front").

    Raises:
        ValueError: If ``name != "Q"`` and ``Q_t`` or ``Q_pad`` is missing.
    """
    is_query = name == "Q"
    if not is_query and (Q_t is None or Q_pad is None):
        # Fail early with a clear message instead of an opaque error in F.pad.
        raise ValueError(
            f"_preproc({name!r}) requires Q_t and Q_pad from the 'Q' call; "
            f"got Q_t={Q_t}, Q_pad={Q_pad}"
        )

    x, undo = misc.reshape_undo(x, "b, t*stride, e", "b, 1, t, stride*e", stride=self.stride)
    if is_query:
        # Amount of padding needed for the query's time axis (helper defined elsewhere).
        Q_pad = _required_padding(x.shape[2], self.maxlen)
    original_t = x.shape[2]
    # F.pad pairs run from the last dim backwards: (0, 0) leaves the feature
    # dim alone, (0, Q_pad) pads the END of the time axis (dim 2). Non-query
    # operands reuse the query's Q_pad so their trailing steps stay aligned.
    x = F.pad(x, (0, 0, 0, Q_pad), value=SENTINEL)
    # Inverse step: slice the time axis back to its unpadded length.
    undo = misc.compose_undo(undo, lambda x: x[:, :, :original_t])

    if is_query:
        Q_t = x.shape[2]
        assert Q_t % self.maxlen == 0, f"{Q_t} % {self.maxlen} != 0"
    else:
        required_len = Q_t + self.maxlen
        if x.shape[2] < required_len:
            # Left-pad so at least maxlen extra steps precede the last Q_t steps.
            x = F.pad(x, (0, 0, required_len - x.shape[2], 0), value=SENTINEL)
        kv_len = x.shape[2]
        assert kv_len >= required_len
        # Use explicit non-negative bounds: with the negative form
        # (``-Q_t``), Q_t == 0 becomes ``0`` and would select the whole axis,
        # making ``front`` and ``back`` different lengths.
        back = x[:, :, kv_len - required_len : kv_len - self.maxlen]
        front = x[:, :, kv_len - Q_t :]
        # Stack the two windows along dim 1 (the "pad_shift" axis below).
        x = th.cat([back, front], dim=1)

    _, _, t, _ = x.shape
    assert t == Q_t, f"{t} != {Q_t}"
    # Split time into (t, maxlen) chunks and features into (stride, h, q).
    x, undo = misc.reshape_undo(
        x,
        "b, pad_shift, t*maxlen, stride*h*q",
        "b, pad_shift, t, maxlen, stride, h, q",
        maxlen=self.maxlen,
        h=self.h,
        stride=self.stride,
        undo=undo,
    )
    # Move pad_shift next to maxlen so the final reshape can merge them into
    # one sequence axis, with b, t, h, stride merged into the leading axis.
    x, undo = misc.transpose_undo(x, "bptmshq", "bthspmq", undo=undo)
    x, undo = misc.reshape_undo(
        x,
        "b, t, h, stride, pad_shift, maxlen, q",
        "b*t*h*stride, pad_shift*maxlen, q",
        undo=undo,
    )
    if is_query:
        return x, undo, Q_t, Q_pad
    return x