in XLM/src/model/__init__.py [0:0]
def check_model_params(params):
    """
    Check models parameters.

    Validate the model-related parameters and fill in the values derived
    from them. Every check is an ``assert``, so an invalid configuration
    raises ``AssertionError``. A malformed ``word_mask_keep_rand`` entry
    makes ``float`` raise ``ValueError``.

    Attributes written back onto ``params``:
      - ``word_mask``, ``word_keep``, ``word_rand``: the three
        comma-separated probabilities parsed from ``word_mask_keep_rand``.
      - ``emb_dim_encoder`` / ``emb_dim_decoder``: copied from ``emb_dim``
        when neither was set (both 0).
      - ``emb_dim``: set when the encoder and decoder sizes were given and
        are equal.
      - ``n_layers_encoder`` / ``n_layers_decoder``: copied from
        ``n_layers`` when neither was set (both 0).

    Args:
        params: parameter namespace; it is modified in place.
    """
    import math  # local import: only needed for the tolerant float-sum check

    # masked language modeling task parameters
    assert params.bptt >= 1
    assert 0 <= params.word_pred < 1
    assert 0 <= params.sample_alpha < 1
    s = params.word_mask_keep_rand.split(',')
    assert len(s) == 3
    s = [float(x) for x in s]
    # Compare the sum with a tolerance: in floating point 0.7 + 0.2 + 0.1
    # evaluates to 0.9999999999999999, so an exact `== 1` would reject it.
    assert all(0 <= x <= 1 for x in s) and math.isclose(sum(s), 1.0, rel_tol=1e-6)
    params.word_mask, params.word_keep, params.word_rand = s

    # input sentence noise for DAE: only allowed when auto-encoding steps exist
    if len(params.ae_steps) == 0:
        assert params.word_shuffle == 0
        assert params.word_dropout == 0
        assert params.word_blank == 0
    else:
        assert params.word_shuffle == 0 or params.word_shuffle > 1
        assert 0 <= params.word_dropout < 1
        assert 0 <= params.word_blank < 1

    # model dimensions: either the shared `emb_dim` is given, or both the
    # encoder- and decoder-specific sizes are given (never a mix)
    if params.emb_dim_encoder == 0 and params.emb_dim_decoder == 0:
        assert params.emb_dim > 0
        params.emb_dim_encoder = params.emb_dim
        params.emb_dim_decoder = params.emb_dim
    else:
        assert params.emb_dim == 0
        assert params.emb_dim_encoder > 0 and params.emb_dim_decoder > 0
        if params.emb_dim_encoder == params.emb_dim_decoder:
            params.emb_dim = params.emb_dim_decoder
        else:
            assert params.reload_emb == "", 'Pre-trained embeddings are not supported when the embedding size of the ' \
                                            'encoder and the decoder do not match '
    # each embedding size must split evenly across the attention heads
    assert params.emb_dim_encoder % params.n_heads == 0
    assert params.emb_dim_decoder % params.n_heads == 0

    # number of layers: same "shared vs. per-side" rule as the dimensions
    if params.n_layers_encoder == 0 and params.n_layers_decoder == 0:
        assert params.n_layers > 0
        params.n_layers_encoder = params.n_layers
        params.n_layers_decoder = params.n_layers
    else:
        assert params.n_layers == 0
        assert params.n_layers_encoder > 0 and params.n_layers_decoder > 0

    # reload pretrained word embeddings
    if params.reload_emb != '':
        assert os.path.isfile(params.reload_emb)

    # reload a pretrained model
    if params.reload_model != '':
        if params.encoder_only:
            assert os.path.isfile(params.reload_model)
        else:
            # two comma-separated paths (presumably encoder,decoder — confirm
            # against the loader); either entry may be left empty
            s = params.reload_model.split(',')
            assert len(s) == 2
            assert all(x == '' or os.path.isfile(x) for x in s)

    # evaluation-time sampling
    assert not (params.beam_size > 1 and params.number_samples > 1), \
        'Cannot sample when already doing beam search'
    assert (params.eval_temperature is None) == (params.number_samples <= 1), \
        'Eval temperature should be set if and only if taking several samples at eval time'