def forward()

in src/model.py [0:0]


    def forward(self, input, query):
        """Attend over grouped ``input`` vectors using ``query``; return the context.

        Scores are computed as ``query @ self.W @ input^T``, normalised with a
        softmax over the group axis, and used to take a weighted sum of the
        (optionally zero-padded) ``input`` vectors.

        Args:
            input: 3-D or 4-D tensor. It is reshaped to
                ``(-1, md_group_size, md_dim)`` when 3-D and to
                ``(seq_len, -1, md_group_size, md_dim)`` when 4-D.
            query: tensor whose trailing size is ``self.query_dim``. A 3-D
                query is reshaped to ``(seq_len, -1, 1, query_dim)`` with
                ``seq_len = query.size(0)``; any other rank is reshaped to
                ``(-1, 1, query_dim)``.

        Returns:
            The attention-weighted sum of ``input`` with all size-1 axes
            squeezed out.

        Raises:
            ValueError: if ``input`` is neither 3-D nor 4-D.

        Reads ``self.query_dim``, ``self.md_group_size``, ``self.md_dim``,
        ``self.W``, ``self.use_null_token`` and ``self.zeros``.
        ``self.W`` is presumably ``(query_dim, md_dim)`` -- TODO confirm against
        the constructor.
        """
        seq_len = None
        if query.dim() == 3:
            # 3-dimensional query means we are precomputing attention
            seq_len = query.size(0)
            query = query.view(seq_len, -1, 1, self.query_dim)
        else:
            query = query.view(-1, 1, self.query_dim)

        input_dim = input.dim()
        if input_dim == 3:
            input = input.view(-1, self.md_group_size, self.md_dim)
        elif input_dim == 4:
            if seq_len is None:
                # The query carried no sequence axis; take it from the input
                # instead of failing on an unbound local.
                seq_len = input.size(0)
            input = input.view(seq_len, -1, self.md_group_size, self.md_dim)
        else:
            raise ValueError(f"Invalid number of input dimension: {input_dim}")

        if self.use_null_token:
            # Pad on the same device as ``input`` so the concatenation below
            # cannot hit a device mismatch; move before repeating so only the
            # small template tensor is transferred.
            zeros = self.zeros.to(input.device)
            if input_dim == 3:
                zeros = zeros.repeat(input.size(0), 1, 1)
            else:
                zeros = zeros.repeat(input.size(0), input.size(1), 1, 1)
            # dim=-2 is the group axis for both the 3-D and the 4-D layout.
            input = torch.cat((input, zeros), dim=-2)

        scores = torch.matmul(query, self.W)
        scores = torch.matmul(scores, input.transpose(-1, -2))

        alphas = torch.softmax(scores, dim=-1)
        # NOTE(review): squeeze() without a dim drops *every* size-1 axis, so a
        # batch of one (or md_dim == 1) also loses that axis. Kept as-is
        # because callers may rely on the current output shape.
        context = torch.matmul(alphas, input).squeeze()
        return context