in src/model.py [0:0]
def forward(self, input, query):
    """Attend over grouped memory slots in `input` using `query`.

    The query is projected through `self.W`, scored against each memory slot
    by dot product, softmax-normalized, and the weighted sum of slots is
    returned as the context.

    Args:
        input: memory tensor, 3-D or 4-D; it is reshaped so its last two
            dimensions become `(md_group_size, md_dim)`. A 4-D `input`
            assumes a leading sequence dimension matching a 3-D `query`.
            NOTE(review): exact caller-side shapes are not visible here.
        query: query tensor, reshaped to `(..., 1, query_dim)`. A 3-D query
            means attention is being precomputed over a whole sequence.

    Returns:
        Context tensor of attended slots. `.squeeze()` removes ALL singleton
        dimensions (so a batch of size 1 is also squeezed) — preserved from
        the original behavior.

    Raises:
        ValueError: if `input` is not 3- or 4-dimensional, or if a 4-D
            `input` is paired with a non-3-D `query` (previously this case
            crashed with UnboundLocalError on `seq_len`).
    """
    seq_len = None
    if query.dim() == 3:
        # 3-dimensional query means we are precomputing attention
        seq_len = query.size(0)
        query = query.view(seq_len, -1, 1, self.query_dim)
    else:
        query = query.view(-1, 1, self.query_dim)

    input_dim = input.dim()
    if input_dim == 3:
        input = input.view(-1, self.md_group_size, self.md_dim)
    elif input_dim == 4:
        if seq_len is None:
            # Original code hit an UnboundLocalError here; fail clearly instead.
            raise ValueError("4-dimensional input requires a 3-dimensional query")
        input = input.view(seq_len, -1, self.md_group_size, self.md_dim)
    else:
        raise ValueError(f"Invalid number of input dimension: {input_dim}")

    if self.use_null_token:
        # Append a null slot so attention can "attend to nothing".
        # NOTE(review): `device` is resolved from module scope — confirm it
        # matches `input.device`.
        if input_dim == 3:
            zeros = self.zeros.repeat(input.size(0), 1, 1).to(device)
            input = torch.cat((input, zeros), dim=1)
        else:
            zeros = self.zeros.repeat(input.size(0), input.size(1), 1, 1).to(device)
            input = torch.cat((input, zeros), dim=2)

    # Project the query, then score it against every slot.
    scores = torch.matmul(query, self.W)
    scores = torch.matmul(scores, input.transpose(-1, -2))
    # Functional softmax avoids building a new nn.Softmax module each call.
    alphas = torch.softmax(scores, dim=-1)
    context = torch.matmul(alphas, input).squeeze()
    return context