in xlm/model/memory/memory.py [0:0]
def forward(self, input):
    """
    Read from the memory.
    """
    # detach input
    if self.query_detach_input:
        input = input.detach()

    # input dimensions
    if self.input2d:
        assert input.shape[1] == self.input_dim
        n_images, _, height, width = input.shape
        prefix_shape = (n_images, width, height)
    else:
        assert input.shape[-1] == self.input_dim
        prefix_shape = input.shape[:-1]

    # compute query / store it
    bs = np.prod(prefix_shape)
    input = F.dropout(input, p=self.input_dropout, training=self.training)  # input shape
    query = self.query_proj(input)                                          # (bs * heads, k_dim)
    query = F.dropout(query, p=self.query_dropout, training=self.training)  # (bs * heads, k_dim)
    assert query.shape == (bs * self.heads, self.k_dim)

    # get indices
    scores, indices = self.get_indices(query, self.knn)                     # both of shape (bs * heads, knn)
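    # (get_indices implements the product-key lookup: each head's query is split in two,
    #  scored against two small sets of sub-keys, and the candidate combinations are
    #  merged to keep the knn highest-scoring memory slots)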

    # optionally shuffle indices for different heads
    if self.shuffle_indices:
        indices = indices.view(bs, self.heads, -1).chunk(self.heads, 1)
        indices = [p[idx] for p, idx in zip(self.head_permutations, indices)]
        indices = torch.cat(indices, 1).view(bs * self.heads, -1)

    # take indices modulo the memory size
    if self.modulo_size != -1:
        indices = indices % self.modulo_size

    # re-scoring
    if self.temperature != 1:
        scores = scores / self.temperature                                  # (bs * heads, knn)
    if self.score_softmax:
        scores = F.softmax(scores.float(), dim=-1).type_as(scores)         # (bs * heads, knn)
    if self.score_subtract != '':
        if self.score_subtract == 'min':
            to_sub = scores.min(1, keepdim=True)[0]                        # (bs * heads, 1)
        elif self.score_subtract == 'mean':
            to_sub = scores.mean(1, keepdim=True)                          # (bs * heads, 1)
        elif self.score_subtract == 'median':
            to_sub = scores.median(1, keepdim=True)[0]                     # (bs * heads, 1)
        scores = scores - to_sub                                           # (bs * heads, knn)
    if self.score_normalize:
        scores = scores / scores.norm(p=1, dim=1, keepdim=True)            # (bs * heads, knn)

    # merge heads / knn (since we sum heads)
    indices = indices.view(bs, self.heads * self.knn)                      # (bs, heads * knn)
    scores = scores.view(bs, self.heads * self.knn)                        # (bs, heads * knn)

    # weighted sum of values
    # output = self.values(indices) * scores.unsqueeze(-1)                 # (bs * heads, knn, v_dim)
    # output = output.sum(1)                                               # (bs * heads, v_dim)
    output = self.values(
        indices,
        per_sample_weights=scores.to(self.values.weight.data)              # match the dtype/device of the value table
    ).to(scores)                                                            # cast back to the score dtype -> (bs, v_dim)
    output = F.dropout(output, p=self.value_dropout, training=self.training)  # (bs, v_dim)

    # reshape output
    if self.input2d:
        output = output.view(n_images, width, height, self.v_dim)          # (n_images, width, height, v_dim)
        output = output.transpose(1, 3)                                    # (n_images, v_dim, height, width)
    else:
        if len(prefix_shape) >= 2:
            output = output.view(prefix_shape + (self.v_dim,))             # (..., v_dim)

    # store indices / scores (eval mode only - for usage statistics)
    if not self.training and HashingMemory.EVAL_MEMORY:
        self.last_indices = indices.view(bs, self.heads, self.knn).detach().cpu()
        self.last_scores = scores.view(bs, self.heads, self.knn).detach().cpu().float()

    return output
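
Note: the weighted read treats self.values as an nn.EmbeddingBag with mode='sum', where per_sample_weights makes one call return the score-weighted sum of the selected value rows, without materializing the (bs, heads * knn, v_dim) gather of the commented-out variant. Below is a minimal standalone sketch of just that step; all sizes and names are illustrative, not taken from the module.

import torch
import torch.nn as nn

bs, heads, knn, v_dim, mem_size = 4, 2, 8, 16, 1024           # illustrative sizes

# value table analogous to self.values
values = nn.EmbeddingBag(mem_size, v_dim, mode='sum')

# fake retrieval results: heads * knn slot indices and scores per input position
indices = torch.randint(0, mem_size, (bs, heads * knn))
scores = torch.rand(bs, heads * knn)

# one call computes the score-weighted sum of the selected value rows
output = values(indices, per_sample_weights=scores)            # (bs, v_dim)

# same result via an explicit gather, as in the commented-out lines above
expected = (values.weight[indices] * scores.unsqueeze(-1)).sum(1)
assert torch.allclose(output, expected, atol=1e-5)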