def forward()

in xlm/model/memory/memory.py [0:0]
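
For context, this excerpt assumes the usual module-level imports of memory.py (not shown above):

    import numpy as np
    import torch
    from torch.nn import functional as F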


    def forward(self, input):
        """
        Read from the memory.
        """
        # detach input
        if self.query_detach_input:
            input = input.detach()

        # input dimensions
        if self.input2d:
            assert input.shape[1] == self.input_dim
            n_images, _, height, width = input.shape
            prefix_shape = (n_images, width, height)
        else:
            assert input.shape[-1] == self.input_dim
            prefix_shape = input.shape[:-1]

        # compute query / store it
        bs = np.prod(prefix_shape)
        input = F.dropout(input, p=self.input_dropout, training=self.training)    # same shape as input
        query = self.query_proj(input)                                            # (bs * heads, k_dim)
        query = F.dropout(query, p=self.query_dropout, training=self.training)    # (bs * heads, k_dim)
        assert query.shape == (bs * self.heads, self.k_dim)

        # get indices
        scores, indices = self.get_indices(query, self.knn)                       # both (bs * heads, knn)
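        # (self.get_indices is defined elsewhere in this class; for each head it is
        # expected to return the top-knn scores and the matching flat value indices)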

        # optionally shuffle indices for different heads
        if self.shuffle_indices:
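            # self.head_permutations is expected to hold one fixed index permutation per head;
            # p[idx] remaps each head's retrieved indices through its own permutation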
            indices = indices.view(bs, self.heads, -1).chunk(self.heads, 1)
            indices = [p[idx] for p, idx in zip(self.head_permutations, indices)]
            indices = torch.cat(indices, 1).view(bs * self.heads, -1)

        # take indices modulo the memory size
        if self.modulo_size != -1:
            indices = indices % self.modulo_size

        # re-scoring
        if self.temperature != 1:
            scores = scores / self.temperature                                    # (bs * heads, knn)
        if self.score_softmax:
            scores = F.softmax(scores.float(), dim=-1).type_as(scores)            # (bs * heads, knn)
        if self.score_subtract != '':
            if self.score_subtract == 'min':
                to_sub = scores.min(1, keepdim=True)[0]                           # (bs * heads, 1)
            elif self.score_subtract == 'mean':
                to_sub = scores.mean(1, keepdim=True)                             # (bs * heads, 1)
            elif self.score_subtract == 'median':
                to_sub = scores.median(1, keepdim=True)[0]                        # (bs * heads, 1)
            scores = scores - to_sub                                              # (bs * heads, knn)
        if self.score_normalize:
            scores = scores / scores.norm(p=1, dim=1, keepdim=True)               # (bs * heads, knn)

        # merge heads / knn (since we sum heads)
        indices = indices.view(bs, self.heads * self.knn)                         # (bs, heads * knn)
        scores = scores.view(bs, self.heads * self.knn)                           # (bs, heads * knn)

        # weighted sum of values
        # output = self.values(indices) * scores.unsqueeze(-1)                    # (bs * heads, knn, v_dim)
        # output = output.sum(1)                                                  # (bs * heads, v_dim)
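        # self.values is expected to be an nn.EmbeddingBag with mode='sum': passing
        # per_sample_weights computes the same weighted sum in one fused lookup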
        output = self.values(
            indices,
            per_sample_weights=scores.to(self.values.weight.data)
        ).to(scores)                                                              # (bs, v_dim)
        output = F.dropout(output, p=self.value_dropout, training=self.training)  # (bs, v_dim)

        # reshape output
        if self.input2d:
            output = output.view(n_images, width, height, self.v_dim)             # (n_images, width, height, v_dim)
            output = output.transpose(1, 3)                                       # (n_images, v_dim, height, width)
        else:
            if len(prefix_shape) >= 2:
                output = output.view(prefix_shape + (self.v_dim,))                # (..., v_dim)

        # store indices / scores (eval mode only - for usage statistics)
        if not self.training and HashingMemory.EVAL_MEMORY:
            self.last_indices = indices.view(bs, self.heads, self.knn).detach().cpu()
            self.last_scores = scores.view(bs, self.heads, self.knn).detach().cpu().float()

        return output
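
The weighted-sum read can be illustrated in isolation. The sketch below (standalone, with illustrative sizes) shows that an nn.EmbeddingBag in 'sum' mode with per_sample_weights reproduces the commented-out "lookup, scale, sum" variant:

    import torch
    from torch import nn

    bs, heads, knn, v_dim, mem_size = 4, 2, 3, 8, 100
    values = nn.EmbeddingBag(mem_size, v_dim, mode='sum')

    indices = torch.randint(0, mem_size, (bs, heads * knn))    # (bs, heads * knn)
    scores = torch.rand(bs, heads * knn)                       # (bs, heads * knn)

    # fused weighted sum, as in forward()
    out = values(indices, per_sample_weights=scores)           # (bs, v_dim)

    # explicit equivalent (the commented-out variant in forward())
    ref = (values.weight[indices] * scores.unsqueeze(-1)).sum(1)
    assert torch.allclose(out, ref, atol=1e-6)

The fused call avoids materializing the intermediate (bs, heads * knn, v_dim) tensor, which matters when knn and v_dim are large.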