in xlm/model/memory/memory.py [0:0]
def register_args(parser):
    """
    Register memory parameters.

    Declares every ``--mem_*`` command-line option on ``parser`` by calling
    ``parser.add_argument`` once per option. Nothing is parsed here and
    nothing is returned.

    Args:
        parser: object exposing an ``add_argument`` method with the
            ``argparse.ArgumentParser`` keyword interface (``type``,
            ``default``, ``help``).

    Note:
        Boolean options are converted with ``bool_flag``, which is defined
        outside this function and must be in scope when it is called.
    """
    # memory implementation
    parser.add_argument("--mem_implementation", type=str, default="pq_fast",
                        help="Memory implementation (flat, pq_default, pq_fast)")

    # optimization
    parser.add_argument("--mem_grouped_conv", type=bool_flag, default=False,
                        help="Use grouped convolutions in the query network")
    # The help text is single-quoted on purpose: inside a double-quoted
    # literal the inner "" ends and restarts the string, and the adjacent
    # literals are concatenated, silently dropping the "" from the message.
    parser.add_argument("--mem_values_optimizer", type=str, default="adam,lr=0.001",
                        help='Memory values optimizer ("" for the same optimizer as the rest of the model)')
    parser.add_argument("--mem_sparse", type=bool_flag, default=False,
                        help="Perform sparse updates for the values")

    # global parameters
    parser.add_argument("--mem_input2d", type=bool_flag, default=False,
                        help="Convolutional query network")
    parser.add_argument("--mem_k_dim", type=int, default=256,
                        help="Memory keys dimension")
    parser.add_argument("--mem_v_dim", type=int, default=-1,
                        help="Memory values dimension (-1 for automatic output dimension)")
    parser.add_argument("--mem_heads", type=int, default=4,
                        help="Number of memory reading heads")
    parser.add_argument("--mem_knn", type=int, default=32,
                        help="Number of memory slots to read / update - k-NN to the query")
    parser.add_argument("--mem_share_values", type=bool_flag, default=False,
                        help="Share values across memories")
    parser.add_argument("--mem_shuffle_indices", type=bool_flag, default=False,
                        help="Shuffle indices for different heads")
    parser.add_argument("--mem_shuffle_query", type=bool_flag, default=False,
                        help="Shuffle query dimensions (when the query network is the identity and there are multiple heads)")
    parser.add_argument("--mem_modulo_size", type=int, default=-1,
                        help="Effective memory size: indices are taken modulo this parameter. -1 to disable.")

    # keys
    parser.add_argument("--mem_keys_type", type=str, default="uniform",
                        help="Memory keys type (binary,gaussian,uniform)")
    parser.add_argument("--mem_n_keys", type=int, default=512,
                        help="Number of keys")
    parser.add_argument("--mem_keys_normalized_init", type=bool_flag, default=False,
                        help="Normalize keys at initialization")
    parser.add_argument("--mem_keys_learn", type=bool_flag, default=True,
                        help="Learn keys")
    parser.add_argument("--mem_use_different_keys", type=bool_flag, default=True,
                        help="Use different keys for each head / product quantization")

    # queries
    parser.add_argument("--mem_query_detach_input", type=bool_flag, default=False,
                        help="Detach input")
    parser.add_argument("--mem_query_layer_sizes", type=str, default="0,0",
                        help="Query MLP layer sizes ('', '0,0', '0,512,0')")
    parser.add_argument("--mem_query_kernel_sizes", type=str, default="",
                        help="Query MLP kernel sizes (2D inputs only)")
    parser.add_argument("--mem_query_bias", type=bool_flag, default=True,
                        help="Query MLP bias")
    parser.add_argument("--mem_query_batchnorm", type=bool_flag, default=False,
                        help="Query MLP batch norm")
    parser.add_argument("--mem_query_net_learn", type=bool_flag, default=True,
                        help="Query MLP learn")
    parser.add_argument("--mem_query_residual", type=bool_flag, default=False,
                        help="Use a bottleneck with a residual layer in the query MLP")
    parser.add_argument("--mem_multi_query_net", type=bool_flag, default=False,
                        help="Use multiple query MLP (one for each head)")

    # values initialization
    parser.add_argument("--mem_value_zero_init", type=bool_flag, default=False,
                        help="Initialize values with zeros")

    # scoring
    parser.add_argument("--mem_normalize_query", type=bool_flag, default=False,
                        help="Normalize queries")
    parser.add_argument("--mem_temperature", type=float, default=1,
                        help="Divide scores by a temperature")
    parser.add_argument("--mem_score_softmax", type=bool_flag, default=True,
                        help="Apply softmax on scores")
    parser.add_argument("--mem_score_subtract", type=str, default="",
                        help="Subtract scores ('', min, mean, median)")
    parser.add_argument("--mem_score_normalize", type=bool_flag, default=False,
                        help="L1 normalization of the scores")

    # dropout
    parser.add_argument("--mem_input_dropout", type=float, default=0,
                        help="Input dropout")
    parser.add_argument("--mem_query_dropout", type=float, default=0,
                        help="Query dropout")
    parser.add_argument("--mem_value_dropout", type=float, default=0,
                        help="Value dropout")