in codes/models.py [0:0]
def forward(self, x):
    # x: (batch, dim, seq_len)
    b, _, n = x.shape
    remainder = n % self.pool_size
    needs_padding = remainder > 0

    if needs_padding:
        # right-pad the sequence to a multiple of pool_size and build a boolean
        # mask that marks the padded positions
        x = F.pad(x, (0, self.pool_size - remainder), value = 0)
        mask = torch.zeros((b, 1, n), dtype = torch.bool, device = x.device)
        mask = F.pad(mask, (0, self.pool_size - remainder), value = True)

    # reshape into non-overlapping windows of length pool_size, then compute
    # per-position attention logits
    x = self.pool_fn(x)
    logits = self.to_attn_logits(x)

    if needs_padding:
        # give padded positions a large negative logit so the softmax ignores them
        mask_value = -torch.finfo(logits.dtype).max
        logits = logits.masked_fill(self.pool_fn(mask), mask_value)

    # softmax within each window, then take the attention-weighted sum over the window
    attn = logits.softmax(dim = -1)
    return (x * attn).sum(dim = -1)
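
The method depends on self.pool_size, self.pool_fn, and self.to_attn_logits, which are defined elsewhere in the class and not shown here. Below is a minimal, self-contained sketch of the same attention-pooling computation in functional form: the window reshape stands in for self.pool_fn and the pointwise (1x1) convolution for self.to_attn_logits. Both stand-ins, and the function/parameter names, are illustrative assumptions rather than the actual definitions in codes/models.py.

import torch
import torch.nn as nn
import torch.nn.functional as F

def attention_pool(x, to_attn_logits, pool_size = 2):
    # x: (batch, dim, seq_len); to_attn_logits: a 1x1 nn.Conv2d(dim, dim)
    # producing one attention logit per channel and position (illustrative stand-in).
    b, d, n = x.shape
    remainder = n % pool_size

    if remainder > 0:
        # right-pad to a multiple of pool_size and track the padded positions
        pad = pool_size - remainder
        x = F.pad(x, (0, pad), value = 0)
        mask = torch.zeros((b, 1, n), dtype = torch.bool, device = x.device)
        mask = F.pad(mask, (0, pad), value = True)

    # (b, d, n_windows, pool_size): non-overlapping windows along the sequence
    x = x.view(b, d, -1, pool_size)
    logits = to_attn_logits(x)

    if remainder > 0:
        mask = mask.view(b, 1, -1, pool_size)
        logits = logits.masked_fill(mask, -torch.finfo(logits.dtype).max)

    attn = logits.softmax(dim = -1)
    return (x * attn).sum(dim = -1)

# usage: a length-131 sequence is padded to 132 and pooled into 66 windows
conv = nn.Conv2d(64, 64, 1, bias = False)
out = attention_pool(torch.randn(2, 64, 131), conv, pool_size = 2)
print(out.shape)  # torch.Size([2, 64, 66])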