in XLM/src/model/__init__.py [0:0]
import os

from .memory import HashingMemory  # product-key memory, defined in src/model/memory.py


def check_model_params(params):
    """
    Check model parameters.
    """
    # masked language modeling task parameters
    assert params.bptt >= 1
    assert 0 <= params.word_pred < 1
    assert 0 <= params.sample_alpha < 1
    s = params.word_mask_keep_rand.split(',')
    assert len(s) == 3
    s = [float(x) for x in s]
    assert all([0 <= x <= 1 for x in s]) and sum(s) == 1
    params.word_mask = s[0]
    params.word_keep = s[1]
    params.word_rand = s[2]
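    # e.g. word_mask_keep_rand = "0.8,0.1,0.1" yields word_mask=0.8, word_keep=0.1,
    # word_rand=0.1, i.e. the BERT-style 80% [MASK] / 10% keep / 10% random-word split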

    # input sentence noise for DAE (noise must be disabled when there are no auto-encoding steps)
    if len(params.ae_steps) == 0:
        assert params.word_shuffle == 0
        assert params.word_dropout == 0
        assert params.word_blank == 0
    else:
        assert params.word_shuffle == 0 or params.word_shuffle > 1
        assert 0 <= params.word_dropout < 1
        assert 0 <= params.word_blank < 1

    # model dimensions
    assert params.emb_dim % params.n_heads == 0

    # sharing input and output embeddings is incompatible with adaptive softmax
    assert params.share_inout_emb is False or params.asm is False

    # adaptive softmax (cutoffs are comma-separated vocabulary thresholds, e.g. "8000,20000")
    if params.asm:
        assert params.asm_div_value > 1
        s = params.asm_cutoffs.split(',')
        assert all([x.isdigit() for x in s])
        params.asm_cutoffs = [int(x) for x in s]
        assert params.max_vocab == -1 or params.asm_cutoffs[-1] < params.max_vocab

    # memory
    if params.use_memory:
        HashingMemory.check_params(params)
        s_enc = [x for x in params.mem_enc_positions.split(',') if x != '']
        s_dec = [x for x in params.mem_dec_positions.split(',') if x != '']
        assert len(s_enc) == len(set(s_enc))
        assert len(s_dec) == len(set(s_dec))
        assert all(x.isdigit() or (x[-1] == '+' and x[:-1].isdigit()) for x in s_enc)
        assert all(x.isdigit() or (x[-1] == '+' and x[:-1].isdigit()) for x in s_dec)
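        # position "i" places a memory inside layer i ('in'); "i+" places it after layer i ('after')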
        params.mem_enc_positions = [(int(x[:-1]), 'after') if x[-1] == '+' else (int(x), 'in') for x in s_enc]
        params.mem_dec_positions = [(int(x[:-1]), 'after') if x[-1] == '+' else (int(x), 'in') for x in s_dec]
        assert len(params.mem_enc_positions) + len(params.mem_dec_positions) > 0
        assert len(params.mem_enc_positions) == 0 or 0 <= min([x[0] for x in params.mem_enc_positions]) <= max([x[0] for x in params.mem_enc_positions]) <= params.n_layers - 1
        assert len(params.mem_dec_positions) == 0 or 0 <= min([x[0] for x in params.mem_dec_positions]) <= max([x[0] for x in params.mem_dec_positions]) <= params.n_layers - 1

    # reload pretrained word embeddings
    if params.reload_emb != '':
        assert os.path.isfile(params.reload_emb)

    # reload a pretrained model
    if params.reload_model != '':
        if params.encoder_only:
            assert os.path.isfile(params.reload_model)
        else:
            s = params.reload_model.split(',')
            assert len(s) == 2
            assert all([x == '' or os.path.isfile(x) for x in s])
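
A minimal usage sketch, not from the XLM source: the attribute names match those checked above, but the concrete values are illustrative and assume an MLM-only configuration with DAE noise, adaptive softmax, memory layers, and model reloading all disabled. In XLM itself, params is the argparse.Namespace produced by the training script's parser.

from argparse import Namespace

params = Namespace(
    bptt=256, word_pred=0.15, sample_alpha=0,  # MLM settings
    word_mask_keep_rand='0.8,0.1,0.1',         # BERT-style 80/10/10 split
    ae_steps=[],                               # no DAE steps, so noise must be off
    word_shuffle=0, word_dropout=0, word_blank=0,
    emb_dim=1024, n_heads=8,                   # emb_dim must be divisible by n_heads
    share_inout_emb=True, asm=False,           # shared embeddings exclude adaptive softmax
    use_memory=False,
    reload_emb='', reload_model='',
)
check_model_params(params)

# the validated split probabilities are now available as separate attributes
assert (params.word_mask, params.word_keep, params.word_rand) == (0.8, 0.1, 0.1)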