def __init__()

in XLM/src/model/memory/memory.py [0:0]


    def __init__(self, input_dim, output_dim, params):

        super().__init__()
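        # unique ID for this memory instance, drawn from the shared _ids iterator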
        self.id = next(self._ids)

        # global parameters
        self.input2d = params.mem_input2d
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.size = params.mem_size
        self.modulo_size = params.mem_modulo_size
        self.n_indices = params.n_indices
        self.k_dim = params.mem_k_dim
        self.v_dim = params.mem_v_dim if params.mem_v_dim > 0 else output_dim
        self.heads = params.mem_heads
        self.knn = params.mem_knn
        self.shuffle_indices = params.mem_shuffle_indices
        self.keys_normalized_init = params.mem_keys_normalized_init
        self.product_quantization = params.mem_product_quantization
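        # either there is exactly one value slot per index (no modulo), or there are
        # more indices than slots and indices are folded modulo the memory size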
        assert (self.modulo_size == -1 and self.size == self.n_indices) or (self.n_indices > self.size == self.modulo_size >= 1)

        # keys / queries
        self.keys_type = params.mem_keys_type
        self.learn_keys = params.mem_keys_learn
        self.use_different_keys = params.mem_use_different_keys
        self.query_detach_input = params.mem_query_detach_input
        self.query_net_learn = params.mem_query_net_learn
        self.multi_query_net = params.mem_multi_query_net
        self.shuffle_query = params.mem_shuffle_query
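        # different keys only make sense for randomly initialized (gaussian / uniform)
        # keys, and require at least two key sets (multiple heads or product quantization)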
        assert self.use_different_keys is False or self.keys_type in ['gaussian', 'uniform']
        assert self.use_different_keys is False or self.heads >= 2 or self.product_quantization
        assert self.multi_query_net is False or self.heads >= 2 or self.product_quantization
        assert self.shuffle_query is False or (self.heads > 1 and params.mem_query_layer_sizes == '')
        assert self.shuffle_query is False or self.input_dim % (2 ** self.heads) == 0

        # scoring / re-scoring
        self.normalize_query = params.mem_normalize_query
        self.temperature = params.mem_temperature
        self.score_softmax = params.mem_score_softmax
        self.score_subtract = params.mem_score_subtract
        self.score_normalize = params.mem_score_normalize
        assert self.score_subtract in ['', 'min', 'mean', 'median']
        assert self.score_subtract == '' or self.knn >= 2
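        # re-normalizing plain softmax scores is forbidden: without subtraction they already sum to 1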
        assert not (self.score_normalize and self.score_softmax and self.score_subtract == '')

        # dropout
        self.input_dropout = params.mem_input_dropout
        self.query_dropout = params.mem_query_dropout
        self.value_dropout = params.mem_value_dropout

        # initialize keys
        self.init_keys()

        # self.values = nn.Embedding(self.size, self.v_dim, sparse=params.mem_sparse)
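        # EmbeddingBag with mode='sum' gathers the selected value vectors and reduces
        # them with per_sample_weights in a single call (see the sketch after this listing)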
        self.values = nn.EmbeddingBag(self.size, self.v_dim, mode='sum', sparse=params.mem_sparse)

        # optionally use the same values for all memories
        if params.mem_share_values:
            if HashingMemory.VALUES is None:
                HashingMemory.VALUES = self.values.weight
            else:
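                # tie this instance's weight to the shared tensor: all memories update one set of values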
                self.values.weight = HashingMemory.VALUES

        # values initialization
        if params.mem_value_zero_init:
            nn.init.zeros_(self.values.weight)
        else:
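            # std = v_dim ** -0.5 gives each value vector an expected L2 norm of roughly 1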
            nn.init.normal_(self.values.weight, mean=0, std=self.v_dim ** -0.5)

        # no query network
        if len(params.mem_query_layer_sizes) == 0:
            assert self.heads == 1 or self.use_different_keys or self.shuffle_query
            assert self.input_dim == self.k_dim
            self.query_proj = QueryIdentity(self.input_dim, self.heads, self.shuffle_query)

        # query network
        if len(params.mem_query_layer_sizes) > 0:
            assert not self.shuffle_query

            # layer sizes / number of features
            l_sizes = list(params.mem_query_layer_sizes)
            assert len(l_sizes) >= 2 and l_sizes[0] == l_sizes[-1] == 0
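            # the 0 placeholders at both ends are replaced by the real input / output widths;
            # with multi_query_net each sub-network emits half a key (k_dim // 2),
            # otherwise the network emits one full k_dim query per head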
            l_sizes[0] = self.input_dim
            l_sizes[-1] = (self.k_dim // 2) if self.multi_query_net else (self.heads * self.k_dim)

            # convolutional or feedforward
            if self.input2d:
                self.query_proj = QueryConv(
                    self.input_dim, self.heads, self.k_dim, self.product_quantization,
                    self.multi_query_net, l_sizes, params.mem_query_kernel_sizes,
                    bias=params.mem_query_bias, batchnorm=params.mem_query_batchnorm,
                    grouped_conv=params.mem_grouped_conv
                )
            else:
                assert params.mem_query_kernel_sizes == ''
                assert not params.mem_query_residual
                self.query_proj = QueryMLP(
                    self.input_dim, self.heads, self.k_dim, self.product_quantization,
                    self.multi_query_net, l_sizes,
                    bias=params.mem_query_bias, batchnorm=params.mem_query_batchnorm,
                    grouped_conv=params.mem_grouped_conv
                )

        # shuffle indices for different heads
        if self.shuffle_indices:
            head_permutations = [torch.randperm(self.n_indices).unsqueeze(0) for _ in range(self.heads)]
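            # registered as a buffer: saved in the state dict and moved across devices, but never trained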
            self.register_buffer('head_permutations', torch.cat(head_permutations, 0))

        # do not learn the query network
        if self.query_net_learn is False:
            for p in self.query_proj.parameters():
                p.requires_grad = False
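
For context, a minimal standalone sketch (not from the source) of the EmbeddingBag
pattern used for self.values above: with mode='sum', a single call gathers the knn
selected value vectors per query and reduces them with their scores via
per_sample_weights. All names and sizes below are illustrative.

    import torch
    import torch.nn as nn

    size, v_dim, knn, bs = 1024, 16, 4, 8              # illustrative sizes
    values = nn.EmbeddingBag(size, v_dim, mode='sum')

    indices = torch.randint(0, size, (bs, knn))        # top-knn slots per query
    scores = torch.softmax(torch.randn(bs, knn), -1)   # their weights, summing to 1

    out = values(indices, per_sample_weights=scores)   # (bs, v_dim) weighted sums
    assert out.shape == (bs, v_dim)

The mem_share_values branch relies on nn.Module parameter assignment for weight
tying; a hedged sketch of the same pattern in isolation:

    import torch.nn as nn

    a = nn.EmbeddingBag(10, 4, mode='sum')
    b = nn.EmbeddingBag(10, 4, mode='sum')
    b.weight = a.weight        # both modules now point at the same Parameter
    assert b.weight is a.weight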