def check_params(params)

in XLM/src/model/memory/memory.py [0:0]


    def check_params(params):
        """
        Check and initialize memory parameters.
        """
        # memory
        assert params.mem_implementation in ['flat', 'pq_default', 'pq_fast']
        params.mem_product_quantization = params.mem_implementation != 'flat'

        # optimization
        # grouped convolutions are only available with the multi-query network
        assert params.mem_grouped_conv is False or params.mem_multi_query_net
        if params.mem_values_optimizer == '':
            params.mem_values_optimizer = params.optimizer
        if params.mem_sparse:
            params.mem_values_optimizer = params.mem_values_optimizer.replace('adam', 'sparseadam')
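        # (torch.optim.SparseAdam only updates the embedding rows that receive
        # gradients in a batch, which is what a sparse, very large value table
        # requires)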

        # even number of key dimensions for product quantization
        assert params.mem_k_dim >= 2
        assert params.mem_product_quantization is False or params.mem_k_dim % 2 == 0
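        # (product keys split each query into two halves of dimension
        # mem_k_dim // 2, hence the even-dimension requirement)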

        # memory type
        assert params.mem_keys_type in ['binary', 'gaussian', 'uniform']

        # number of indices
        if params.mem_keys_type == 'binary':
            # binary keys enumerate every bit pattern of length mem_k_dim,
            # so the number of keys is fixed at 2 ** mem_k_dim
            assert params.mem_keys_normalized_init is False
            assert 1 << params.mem_k_dim == params.mem_n_keys
        if params.mem_product_quantization:
            params.n_indices = params.mem_n_keys ** 2
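            # e.g. two codebooks of 512 sub-keys address 512 ** 2 = 262,144 slots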
        else:
            params.n_indices = params.mem_n_keys

        # actual memory size: -1 keeps one value slot per index; otherwise the
        # value table is capped at mem_modulo_size slots
        if params.mem_modulo_size == -1:
            params.mem_size = params.n_indices
        else:
            assert 1 <= params.mem_modulo_size < params.n_indices
            params.mem_size = params.mem_modulo_size

        # different keys / different query MLP / shuffle hidden dimensions when no query network
        assert not params.mem_use_different_keys or params.mem_keys_type in ['gaussian', 'uniform']
        assert not params.mem_use_different_keys or params.mem_heads >= 2 or params.mem_product_quantization
        assert not params.mem_multi_query_net or params.mem_heads >= 2 or params.mem_product_quantization
        assert not params.mem_multi_query_net or params.mem_query_layer_sizes not in ['', '0,0']
        assert not params.mem_shuffle_query or (params.mem_heads > 1 and params.mem_query_layer_sizes == '')
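        # (per the two asserts above, separate key sets are only meaningful
        # with several heads or with the two product-key query halves)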

        # query network
        if params.mem_query_layer_sizes == '':
            # without a query network, multiple heads need another way to
            # produce distinct queries (different keys or query shuffling)
            assert params.mem_heads == 1 or params.mem_use_different_keys or params.mem_shuffle_query
        else:
            # parse the layer sizes; the 0 at each end is a placeholder for
            # the input / output dimension
            s = [int(x) for x in filter(None, params.mem_query_layer_sizes.split(','))]
            assert len(s) >= 2 and s[0] == s[-1] == 0
            params.mem_query_layer_sizes = s
            # residual query connections are only supported with 2D inputs
            assert not params.mem_query_residual or params.mem_input2d

        # convolutional query network kernel sizes
        if params.mem_query_kernel_sizes == '':
            # 2D inputs may only omit kernel sizes when there is no query network
            assert not params.mem_input2d or params.mem_query_layer_sizes == ''
        else:
            assert params.mem_input2d
            s = [int(x) for x in filter(None, params.mem_query_kernel_sizes.split(','))]
            params.mem_query_kernel_sizes = s
            # odd kernels only, so convolutions can pad symmetrically and keep
            # the spatial size; one kernel per consecutive pair of layer sizes
            assert all(ks % 2 == 1 for ks in s)
            assert len(params.mem_query_kernel_sizes) == len(params.mem_query_layer_sizes) - 1 >= 1

        # scoring
        assert params.mem_score_subtract in ['', 'min', 'mean', 'median']
        assert params.mem_score_subtract == '' or params.mem_knn >= 2
        assert not (params.mem_score_normalize and params.mem_score_softmax and params.mem_score_subtract == '')

        # dropout
        assert 0 <= params.mem_input_dropout < 1
        assert 0 <= params.mem_query_dropout < 1
        assert 0 <= params.mem_value_dropout < 1

        # query batchnorm
        if params.mem_query_batchnorm:
            logger.warning("WARNING: with batch normalization, make sure that all sentences in a training batch have the same length. Otherwise, padding tokens will result in incorrect mean/variance estimates in the BatchNorm layer.")
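
For reference, here is a minimal sketch of how this validator could be driven. It is not from the repository: the `argparse.Namespace` values below are illustrative assumptions rather than XLM's defaults, and `check_params` is called directly as in the excerpt above.

    from argparse import Namespace

    # illustrative configuration (values are assumptions, not XLM defaults)
    params = Namespace(
        optimizer='adam,lr=0.0001',
        mem_implementation='pq_fast',   # product-key memory
        mem_grouped_conv=False,
        mem_multi_query_net=False,
        mem_sparse=False,
        mem_values_optimizer='',        # falls back to params.optimizer
        mem_k_dim=256,                  # even, as product keys require
        mem_n_keys=512,
        mem_keys_type='uniform',
        mem_keys_normalized_init=False,
        mem_modulo_size=-1,             # one value slot per index
        mem_use_different_keys=True,    # satisfies the no-query-net check
        mem_heads=4,
        mem_query_layer_sizes='',       # no query network
        mem_shuffle_query=False,
        mem_query_residual=False,
        mem_input2d=False,
        mem_query_kernel_sizes='',
        mem_knn=32,
        mem_score_subtract='',
        mem_score_normalize=False,
        mem_score_softmax=True,
        mem_input_dropout=0.0,
        mem_query_dropout=0.0,
        mem_value_dropout=0.0,
        mem_query_batchnorm=False,
    )

    check_params(params)
    assert params.n_indices == 512 ** 2             # product keys: n_keys ** 2
    assert params.mem_size == params.n_indices      # no modulo cap
    assert params.mem_values_optimizer == 'adam,lr=0.0001'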