# src/model.py
def forward(self, input, query):
    """Additive (Bahdanau-style) attention over groups of input vectors.

    Projects the query and the grouped keys, scores each key with
    ``energy_projection(tanh(Wq·q + Wk·k))``, softmax-normalizes the scores,
    and returns the attention-weighted sum of ``input``.

    Args:
        input: 3-D tensor reshaped to ``(-1, md_group_size, md_dim)``, or
            4-D tensor reshaped to ``(seq_len, -1, md_group_size, md_dim)``.
        query: 2-D tensor reshaped to ``(-1, 1, query_dim)``, or 3-D tensor
            (precomputed attention over a sequence) reshaped to
            ``(seq_len, -1, 1, query_dim)``.

    Returns:
        The attention context, ``matmul(alphas, input)`` with all size-1
        dimensions squeezed out.

    Raises:
        Exception: if ``input`` is neither 3-D nor 4-D.
    """
    if query.dim() == 3:
        # 3-dimensional query means we are precomputing attention over a
        # whole sequence; keep the sequence axis out front.
        seq_len = query.size(0)
        query = query.view(seq_len, -1, 1, self.query_dim)
    else:
        query = query.view(-1, 1, self.query_dim)
    input_dim = input.dim()
    if input_dim == 3:
        input = input.view(-1, self.md_group_size, self.md_dim)
    elif input_dim == 4:
        # NOTE(review): seq_len is only bound when query.dim() == 3, so a
        # 4-D input paired with a 2-D query raises NameError here — confirm
        # callers always pair 4-D input with a 3-D query.
        input = input.view(seq_len, -1, self.md_group_size, self.md_dim)
    else:
        raise Exception(f"Invalid number of input dimension: {input_dim}")
    if self.use_null_token:
        # Append a "null" slot along the group axis so attention can assign
        # mass to nothing.
        if input_dim == 3:
            zeros = self.zeros.repeat(input.size(0), 1, 1).to(device)
            input = torch.cat((input, zeros), dim=1)
        else:
            # (Removed a dead `test = self.zeros.repeat(...)` assignment that
            # allocated and copied a tensor to `device` without ever using it.)
            zeros = self.zeros.repeat(input.size(0), input.size(1), 1, 1).to(device)
            input = torch.cat((input, zeros), dim=2)
    # Additive attention: score = energy(tanh(Wq·q + Wk·k)); the query
    # broadcasts across the group axis of the keys.
    hidden_keys = self.key_projection(input)
    hidden_query = self.query_projection(query)
    scores = self.energy_projection(torch.tanh(hidden_query + hidden_keys))
    # NOTE(review): if energy_projection outputs a single unit, dim=-1 here
    # normalizes a size-1 axis (every weight becomes 1.0) and the group axis
    # is dim=-2 — confirm the intended softmax axis against the projection's
    # output width.
    alphas = nn.Softmax(dim=-1)(scores).transpose(-1, -2)
    # NOTE(review): bare .squeeze() also drops batch/seq dims of size 1,
    # which changes the output rank when batch size is 1 — verify callers.
    context = torch.matmul(alphas, input).squeeze()
    return context