def init_fairseq_lm_globals()

in tseval/models/language_models.py [0:0]


def init_fairseq_lm_globals(device=None):
    """Load the fairseq wiki103 language model into module-level globals.

    If the ``wiki103_test_lm`` data directory under ``MODELS_DIR`` is missing,
    the ``fairseq_lm`` resource is prepared first via ``prepare_resource``.
    The model and dictionary are then loaded, the model is prepared for
    generation, and it is moved to the selected device.

    Sets the module globals ``FAIRSEQ_MODEL``, ``DICTIONARY`` and ``DEVICE``.

    Args:
        device: Optional device (anything ``torch.device`` accepts) to place
            the model on. When ``None`` (the default), CUDA is used if it is
            available, otherwise CPU. Previously CUDA was always used, which
            failed on machines without a GPU.
    """
    global FAIRSEQ_MODEL, DICTIONARY, DEVICE
    print('Loading fairseq language model...')
    fairseq_lm_dir = os.path.join(MODELS_DIR, 'language_models/wiki103_fconv_lm')
    checkpoint_path = os.path.join(fairseq_lm_dir, 'wiki103.pt')
    data_path = os.path.join(MODELS_DIR, 'language_models/wiki103_test_lm')
    if not os.path.exists(data_path):
        # Imported lazily so the preparation utilities are only loaded when
        # the resource actually needs to be fetched.
        from tseval.utils.prepare import prepare_resource
        prepare_resource('fairseq_lm')
    FAIRSEQ_MODEL, DICTIONARY = load_fairseq_lm_model_and_dict(checkpoint_path, data_path)
    FAIRSEQ_MODEL.make_generation_fast_()
    if device is None:
        # Fall back to CPU instead of crashing when no CUDA device is present.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    DEVICE = torch.device(device)
    FAIRSEQ_MODEL = FAIRSEQ_MODEL.to(DEVICE)
    print('Done.')