def initialize_embeddings()

in NMT/src/model/pretrain_embeddings.py [0:0]


def initialize_embeddings(encoder, decoder, params, data):
    """
    Initialize the model with pretrained embeddings.
    """
    if params.pretrained_emb == '':
        return

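    # a single path means one embedding file reused for every language;
    # comma-separated paths provide one embedding file per language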
    split = params.pretrained_emb.split(',')

    if len(split) == 1:
        assert os.path.isfile(params.pretrained_emb)
        pretrained_0, word2id_0 = reload_embeddings(params.pretrained_emb, params.emb_dim)
        pretrained = [pretrained_0 for _ in range(params.n_langs)]
        word2id = [word2id_0 for _ in range(params.n_langs)]
    else:
        assert len(split) == params.n_langs
        assert not params.share_lang_emb
        assert all(os.path.isfile(x) for x in split)
        pretrained = []
        word2id = []
        for path in split:
            pretrained_i, word2id_i = reload_embeddings(path, params.emb_dim)
            pretrained.append(pretrained_i)
            word2id.append(word2id_i)

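    # when embeddings are shared across languages, all languages must use the same dictionary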
    assert not params.share_lang_emb or all(
        data['dico'][params.langs[i]] == data['dico'][params.langs[0]]
        for i in range(1, params.n_langs)
    )

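    # per-language counts of initialized words (and of words matched only after lowercasing)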
    found = [0 for _ in range(params.n_langs)]
    lower = [0 for _ in range(params.n_langs)]

    # for every language
    for i, lang in enumerate(params.langs):

        # if everything is shared, we just need to do it for the first language
        if params.share_lang_emb and i > 0:
            break

        # define dictionary / parameters to update
        dico = data['dico'][lang]
        to_update = [encoder.embeddings[i].weight.data]
        if not params.share_encdec_emb:
            to_update.append(decoder.embeddings[i].weight.data)
        if not params.share_decpro_emb and params.pretrained_out:
            to_update.append(decoder.proj[i].weight.data)

        # for every word in that language
        for word_id in range(params.n_words[i]):
            word = dico[word_id]
            if word in word2id[i]:
                found[i] += 1
                vec = torch.from_numpy(pretrained[i][word2id[i][word]]).cuda()
                for x in to_update:
                    x[word_id] = vec
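            # fall back to the lowercased form when the exact word is missing
            # from the pretrained vocabulary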
            elif word.lower() in word2id[i]:
                found[i] += 1
                lower[i] += 1
                vec = torch.from_numpy(pretrained[i][word2id[i][word.lower()]]).cuda()
                for x in to_update:
                    x[word_id] = vec

    # print summary
    for i, lang in enumerate(params.langs):
        _found = found[0 if params.share_lang_emb else i]
        _lower = lower[0 if params.share_lang_emb else i]
        logger.info(
            "Initialized %i / %i word embeddings for \"%s\" (including %i "
            "after lowercasing)." % (_found, params.n_words[i], lang, _lower)
        )
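
initialize_embeddings() only depends on reload_embeddings(path, emb_dim) through its return value: a pair (pretrained, word2id), where pretrained is a float32 numpy matrix and word2id maps each word to its row index, so that pretrained[word2id[word]] can be copied into the model via torch.from_numpy. The sketch below is a hypothetical stand-in for that helper, not the repository's implementation; it assumes a fastText-style text format with one "word v1 ... vD" entry per line and an optional "count dim" header.

import numpy as np


def reload_embeddings_sketch(path, emb_dim):
    """
    Hypothetical reader fulfilling the (pretrained, word2id) contract used above.
    """
    vectors = []
    word2id = {}
    with open(path, 'r', encoding='utf-8') as f:
        for line_nb, line in enumerate(f):
            parts = line.rstrip().split(' ')
            # skip an optional "<n_words> <emb_dim>" header line
            if line_nb == 0 and len(parts) == 2:
                continue
            word, values = parts[0], parts[1:]
            # ignore duplicate words and vectors with an unexpected dimension
            if word in word2id or len(values) != emb_dim:
                continue
            word2id[word] = len(vectors)
            vectors.append(np.asarray(values, dtype=np.float32))
    pretrained = np.stack(vectors) if vectors else np.empty((0, emb_dim), dtype=np.float32)
    return pretrained, word2id

With such a helper, params.pretrained_emb would point either at one embedding file (reused for all languages) or at params.n_langs comma-separated files, and params.emb_dim must match the dimension of the stored vectors.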