def check_mt_model_params()

in NMT/src/model/__init__.py


import os


def check_mt_model_params(params):
    """
    Check model parameters.
    """

    # shared layers
    assert 0 <= params.dropout < 1
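    # share up to n_enc_layers encoder layers, plus one extra layer when
    # (attention and not transformer) or (no attention and proj_mode == 'proj')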
    assert 0 <= params.share_enc <= params.n_enc_layers + int(
        params.attention and not params.transformer or
        not params.attention and params.proj_mode == 'proj'
    )
    assert 0 <= params.share_dec <= params.n_dec_layers
    assert (not params.share_decpro_emb or params.lstm_proj or
            getattr(params, 'transformer', False) or params.emb_dim == params.hidden_dim)
    assert not params.share_output_emb or params.share_lang_emb
    assert (not (params.share_decpro_emb and params.share_lang_emb)) or params.share_output_emb
    assert not params.lstm_proj or not (params.attention and params.transformer)
    assert not params.share_lstm_proj or params.lstm_proj

    # attention model
    if params.attention:
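        # a non-transformer attention decoder with a single layer requires input feeding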
        assert params.transformer or params.n_dec_layers > 1 or (params.n_dec_layers == 1 and params.input_feeding)
        assert params.transformer is False or params.emb_dim % params.encoder_attention_heads == 0
        assert params.transformer is False or params.emb_dim % params.decoder_attention_heads == 0
    # seq2seq model
    else:
        assert params.enc_dim == params.hidden_dim or not params.init_encoded
        assert params.enc_dim == params.hidden_dim or params.proj_mode == 'proj'
        assert params.proj_mode in ['proj', 'pool', 'last']

    # language model
    if params.lm_before == params.lm_after == 0:
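        # language model disabled: all LM sharing options must be off and lambda_lm must be "0"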
        assert params.lm_share_enc == params.lm_share_dec == 0
        assert params.lm_share_emb is False and params.lm_share_proj is False
        assert params.lambda_lm == "0"
    else:
        assert not (params.attention and params.transformer)
        assert params.lm_share_enc <= 1 and params.lm_share_dec <= 1  # TODO: support more than one layer
        assert params.input_feeding is False or params.lm_share_dec == 0  # TODO: support input feeding mode
        assert 0 <= params.lm_share_enc <= params.n_enc_layers
        assert 0 <= params.lm_share_dec <= params.n_dec_layers
        assert (params.lm_share_enc + params.lm_share_dec > 0 or
                params.lm_share_emb or params.lm_share_proj)
        assert params.lambda_lm not in ["0", "-1"]
        assert params.lm_share_emb is False or not (params.freeze_enc_emb or params.freeze_dec_emb)

    # pretrained embeddings / freeze embeddings
    if params.pretrained_emb == '':
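        # without pretrained embeddings, embeddings can only be frozen when the
        # corresponding model component is reloaded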
        assert not params.freeze_enc_emb or params.reload_enc
        assert not params.freeze_dec_emb or params.reload_dec
        assert not params.pretrained_out
    else:
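        # pretrained_emb is either a single embedding file, or a comma-separated
        # list with one file per language (no shared language embeddings)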
        split = params.pretrained_emb.split(',')
        if len(split) == 1:
            assert os.path.isfile(params.pretrained_emb)
        else:
            assert len(split) == params.n_langs
            assert not params.share_lang_emb
            assert all(os.path.isfile(x) for x in split)
        if params.share_encdec_emb:
            assert params.freeze_enc_emb == params.freeze_dec_emb
        else:
            assert not (params.freeze_enc_emb and params.freeze_dec_emb)
        assert not (params.share_decpro_emb and params.freeze_dec_emb)
        assert not (params.share_decpro_emb and not params.pretrained_out)
        assert (not params.pretrained_out or params.lstm_proj or
                getattr(params, 'transformer', False) or params.emb_dim == params.hidden_dim)

    # discriminator parameters
    assert params.dis_layers >= 0
    assert params.dis_hidden_dim > 0
    assert 0 <= params.dis_dropout < 1
    assert params.dis_clip >= 0

    # reload MT model
    assert params.reload_model == '' or os.path.isfile(params.reload_model)
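    # a model reload path is given if and only if at least one of
    # encoder / decoder / discriminator is reloaded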
    assert not ((params.reload_model != '') ^ (params.reload_enc or params.reload_dec or params.reload_dis))
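
A minimal usage sketch (not from the repository): in this codebase the params object is typically an argparse.Namespace, and the assertions above run top to bottom, so a real configuration must define every field they reference. The values below are illustrative only.

from argparse import Namespace

# Illustrative only: dropout=1.5 violates the very first assertion,
# so none of the other (unset) fields are ever read.
invalid = Namespace(dropout=1.5)
try:
    check_mt_model_params(invalid)
except AssertionError:
    print("invalid parameters: dropout must be in [0, 1)")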