in XLM/src/model/memory/memory.py [0:0]
def __init__(self, input_dim, output_dim, params):
    """Build a large key-value memory layer (product-key style memory).

    The memory stores ``params.mem_size`` learnable value vectors that are
    addressed through a key lookup performed by ``params.mem_heads`` heads,
    each retrieving the ``params.mem_knn`` nearest keys. Keys may be
    product-quantized, and queries may come from the raw input (identity),
    an MLP, or a convolutional network depending on the configuration.

    Args:
        input_dim: feature size of the hidden states fed to the memory.
        output_dim: feature size of the memory output; also used as the
            value dimension unless ``params.mem_v_dim > 0`` overrides it.
        params: experiment namespace carrying every ``mem_*``
            hyper-parameter read below.
    """
    super().__init__()
    # unique identifier for this memory instance (class-level counter)
    self.id = next(self._ids)
    # global parameters
    self.input2d = params.mem_input2d          # True -> 2D feature-map input (conv query net)
    self.input_dim = input_dim
    self.output_dim = output_dim
    self.size = params.mem_size                # number of stored value vectors
    self.modulo_size = params.mem_modulo_size  # -1 disables modulo index folding
    self.n_indices = params.n_indices          # number of addressable indices
    self.k_dim = params.mem_k_dim              # key / query dimension
    self.v_dim = params.mem_v_dim if params.mem_v_dim > 0 else output_dim
    self.heads = params.mem_heads
    self.knn = params.mem_knn                  # retrieved nearest neighbours per head
    self.shuffle_indices = params.mem_shuffle_indices
    self.keys_normalized_init = params.mem_keys_normalized_init
    self.product_quantization = params.mem_product_quantization
    # either every index maps to its own value slot (no modulo, size == n_indices),
    # or there are more indices than values and they are folded modulo `size`
    assert self.modulo_size == -1 and self.size == self.n_indices or self.n_indices > self.size == self.modulo_size >= 1
    # keys / queries
    self.keys_type = params.mem_keys_type      # key initialization distribution
    self.learn_keys = params.mem_keys_learn    # whether keys are trainable
    self.use_different_keys = params.mem_use_different_keys
    self.query_detach_input = params.mem_query_detach_input
    self.query_net_learn = params.mem_query_net_learn
    self.multi_query_net = params.mem_multi_query_net
    self.shuffle_query = params.mem_shuffle_query
    # per-head keys only make sense with randomly initialized key types
    assert self.use_different_keys is False or self.keys_type in ['gaussian', 'uniform']
    # distinct keys / multiple sub-queries require several heads or product keys
    assert self.use_different_keys is False or self.heads >= 2 or self.product_quantization
    assert self.multi_query_net is False or self.heads >= 2 or self.product_quantization
    # query shuffling requires multiple heads and no query network.
    # NOTE(review): compared against '' here but via len(...) at the query-net
    # branch below, while list(...) later expects integer entries — presumably
    # mem_query_layer_sizes is parsed upstream; confirm its type at this point.
    assert self.shuffle_query is False or self.heads > 1 and params.mem_query_layer_sizes == ''
    # the input must split evenly across the 2**heads shuffled chunks
    assert self.shuffle_query is False or self.input_dim % (2 ** self.heads) == 0
    # scoring / re-scoring
    self.normalize_query = params.mem_normalize_query
    self.temperature = params.mem_temperature
    self.score_softmax = params.mem_score_softmax
    self.score_subtract = params.mem_score_subtract    # one of '', 'min', 'mean', 'median'
    self.score_normalize = params.mem_score_normalize
    assert self.score_subtract in ['', 'min', 'mean', 'median']
    # subtracting a statistic requires at least two retrieved scores
    assert self.score_subtract == '' or self.knn >= 2
    # forbid normalize + softmax without subtraction (normalizing softmax
    # output would be redundant)
    assert not (self.score_normalize and self.score_softmax and self.score_subtract == '')
    # dropout
    self.input_dropout = params.mem_input_dropout
    self.query_dropout = params.mem_query_dropout
    self.value_dropout = params.mem_value_dropout
    # initialize keys
    self.init_keys()
    # values: EmbeddingBag sums the selected value vectors in a single call
    # (fused lookup + weighted sum), instead of a plain Embedding lookup:
    # self.values = nn.Embedding(self.size, self.v_dim, sparse=params.mem_sparse)
    self.values = nn.EmbeddingBag(self.size, self.v_dim, mode='sum', sparse=params.mem_sparse)
    # optionally use the same values for all memories: the first memory
    # created registers its weight as the shared parameter, later ones alias it
    if params.mem_share_values:
        if HashingMemory.VALUES is None:
            HashingMemory.VALUES = self.values.weight
        else:
            self.values.weight = HashingMemory.VALUES
    # values initialization: zeros, or Gaussian scaled by 1/sqrt(v_dim)
    if params.mem_value_zero_init:
        nn.init.zeros_(self.values.weight)
    else:
        nn.init.normal_(self.values.weight, mean=0, std=self.v_dim ** -0.5)
    # no query network: the input itself is used as the query
    if len(params.mem_query_layer_sizes) == 0:
        # the raw input can only feed several heads via per-head keys or shuffling
        assert self.heads == 1 or self.use_different_keys or self.shuffle_query
        assert self.input_dim == self.k_dim
        self.query_proj = QueryIdentity(self.input_dim, self.heads, self.shuffle_query)
    # query network
    if len(params.mem_query_layer_sizes) > 0:
        assert not self.shuffle_query
        # layer sizes / number of features: the sentinel 0s at both ends are
        # replaced by the actual input / output feature counts
        l_sizes = list(params.mem_query_layer_sizes)
        assert len(l_sizes) >= 2 and l_sizes[0] == l_sizes[-1] == 0
        l_sizes[0] = self.input_dim
        # with multi_query_net the net outputs k_dim // 2 features (sub-query
        # halves, presumably recombined via product quantization — confirm);
        # otherwise it emits one full k_dim query per head
        l_sizes[-1] = (self.k_dim // 2) if self.multi_query_net else (self.heads * self.k_dim)
        # convolutional or feedforward query network
        if self.input2d:
            self.query_proj = QueryConv(
                self.input_dim, self.heads, self.k_dim, self.product_quantization,
                self.multi_query_net, l_sizes, params.mem_query_kernel_sizes,
                bias=params.mem_query_bias, batchnorm=params.mem_query_batchnorm,
                grouped_conv=params.mem_grouped_conv
            )
        else:
            # kernel sizes / residual connections only apply to the conv variant
            assert params.mem_query_kernel_sizes == ''
            assert not params.mem_query_residual
            self.query_proj = QueryMLP(
                self.input_dim, self.heads, self.k_dim, self.product_quantization,
                self.multi_query_net, l_sizes,
                bias=params.mem_query_bias, batchnorm=params.mem_query_batchnorm,
                grouped_conv=params.mem_grouped_conv
            )
    # shuffle indices for different heads: each head gets its own fixed random
    # permutation of the indices, registered as a (non-trainable) buffer so it
    # is saved with the model and moved across devices
    if self.shuffle_indices:
        head_permutations = [torch.randperm(self.n_indices).unsqueeze(0) for i in range(self.heads)]
        self.register_buffer('head_permutations', torch.cat(head_permutations, 0))
    # do not learn the query network: freeze all of its parameters
    if self.query_net_learn is False:
        for p in self.query_proj.parameters():
            p.requires_grad = False