src/model.py [34:57]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    def forward(self, input, query):
        if query.dim() == 3:
            # 3-dimensional query means we are precomputing attention
            seq_len = query.size(0)
            query = query.view(seq_len, -1, 1, self.query_dim)
        else:
            query = query.view(-1, 1, self.query_dim)

        input_dim = input.dim()
        if input_dim == 3:
            input = input.view(-1, self.md_group_size, self.md_dim)
        elif input_dim == 4:
            input = input.view(seq_len, -1, self.md_group_size, self.md_dim)
        else:
            raise Exception(f"Invalid number of input dimension: {input_dim}")

        if self.use_null_token:
            if input_dim == 3:
                zeros = self.zeros.repeat(input.size(0), 1, 1).to(device)
                input = torch.cat((input, zeros), dim=1)
            else:
                zeros = self.zeros.repeat(input.size(0), input.size(1), 1, 1).to(device)
                test = self.zeros.repeat(input.size(0), 1, 1).to(device)
                input = torch.cat((input, zeros), dim=2)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



src/model.py [87:110]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    def forward(self, input, query):
        if query.dim() == 3:
            # 3-dimensional query means we are precomputing attention
            seq_len = query.size(0)
            query = query.view(seq_len, -1, 1, self.query_dim)
        else:
            query = query.view(-1, 1, self.query_dim)

        input_dim = input.dim()
        if input_dim == 3:
            input = input.view(-1, self.md_group_size, self.md_dim)
        elif input_dim == 4:
            input = input.view(seq_len, -1, self.md_group_size, self.md_dim)
        else:
            raise Exception(f"Invalid number of input dimension: {input_dim}")

        if self.use_null_token:
            if input_dim == 3:
                zeros = self.zeros.repeat(input.size(0), 1, 1).to(device)
                input = torch.cat((input, zeros), dim=1)
            else:
                zeros = self.zeros.repeat(input.size(0), input.size(1), 1, 1).to(device)
                test = self.zeros.repeat(input.size(0), 1, 1).to(device)
                input = torch.cat((input, zeros), dim=2)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



