# recipes/lexicon_free/utilities/convlm_utils.py
"""
Copyright (c) Facebook, Inc. and its affiliates.
All rights reserved.
This source code is licensed under the MIT-style license found in the
LICENSE file in the root directory of this source tree.
"""
from __future__ import absolute_import, division, print_function, unicode_literals
import torch
from fairseq import options
from fairseq.models.fconv import FConvDecoder
# Special-symbol tokens and their fixed indices, following fairseq's
# dictionary conventions (see the "follow fairseq" note in
# build_token_index_correspondence, where regular tokens start at index 4).
EOS = '</s>'
UNK = '<unk>'
EOSIDX = 2
UNKIDX = 3
def compute_new_state(model_state):
    """Rewrite a fairseq checkpoint's "model" state dict for direct loading.

    Every key has its leading component (e.g. the "decoder." prefix)
    stripped; adaptive-softmax keys containing "1.weight" are additionally
    remapped to "2.weight" to match FConvDecoder's parameter layout.

    Args:
        model_state: checkpoint dict with a "model" entry mapping
            parameter names to tensors.

    Returns:
        dict with rewritten parameter names mapped to the same tensors.
    """
    remapped = {}
    for full_key, tensor in model_state["model"].items():
        parts = full_key.split(".")
        stripped = ".".join(parts[1:])
        # Only adaptive-softmax weights need the 1 -> 2 index shift.
        if "1.weight" in full_key and "adaptive" in full_key:
            stripped = stripped.replace("1.weight", "2.weight")
        remapped[stripped] = tensor
    return remapped
def load_char_model_20B(pytorch_model_path, fairseq_dict, dataset_type):
    """Load the 20B character convolutional LM from a fairseq checkpoint.

    Args:
        pytorch_model_path: path to the saved fairseq checkpoint (.pt).
        fairseq_dict: fairseq dictionary used to size the embedding table.
        dataset_type: unused here; kept for signature parity with
            load_word_model / load_char_model_14B.

    Returns:
        FConvDecoder on GPU, in eval mode, with the checkpoint weights loaded.
    """
    # Convolution spec as (out_channels, kernel_size[, residual]) tuples.
    # Fix: the original built this with eval() on a string literal — an
    # eval anti-pattern; the plain expression below is value-identical.
    layers = (
        [(512, 5)]
        + [(128, 1, 0), (128, 5, 0), (256, 1, 3)] * 3
        + [(256, 1, 0), (256, 5, 0), (512, 1, 3)] * 3
        + [(512, 1, 0), (512, 5, 0), (1024, 1, 3)] * 3
        + [(1024, 1, 0), (1024, 5, 0), (2048, 1, 3)] * 9
        + [(1024, 1, 0), (1024, 5, 0), (4096, 1, 3)]
    )
    model_state = torch.load(pytorch_model_path)
    convLM_char = FConvDecoder(
        fairseq_dict,
        embed_dim=256,
        out_embed_dim=256,
        max_positions=1024,
        convolutions=layers,
        dropout=0.1,
        share_embed=False,
        attention=False,
        positional_embeddings=False,
        adaptive_softmax_cutoff=None,
        adaptive_softmax_dropout=0,
    ).cuda()
    convLM_char.load_state_dict(compute_new_state(model_state))
    convLM_char.eval()
    return convLM_char
def load_char_model_14B(pytorch_model_path, fairseq_dict, dataset_type):
    """Load the 14B character convolutional LM from a fairseq checkpoint.

    Args:
        pytorch_model_path: path to the saved fairseq checkpoint (.pt).
        fairseq_dict: fairseq dictionary used to size the embedding table.
        dataset_type: unused here; kept for signature parity with
            load_word_model / load_char_model_20B.

    Returns:
        FConvDecoder on GPU, in eval mode, with the checkpoint weights loaded.
    """
    # Convolution spec as (out_channels, kernel_size[, residual]) tuples.
    # Fix: the original built this with eval() on a string literal — an
    # eval anti-pattern; the plain expression below is value-identical.
    layers = (
        [(512, 5)]
        + [(128, 1, 0), (128, 5, 0), (512, 1, 3)] * 3
        + [(512, 1, 0), (512, 5, 0), (1024, 1, 3)] * 3
        + [(1024, 1, 0), (1024, 5, 0), (2048, 1, 3)] * 6
        + [(1024, 1, 0), (1024, 5, 0), (4096, 1, 3)]
    )
    model_state = torch.load(pytorch_model_path)
    convLM_char = FConvDecoder(
        fairseq_dict,
        embed_dim=128,
        out_embed_dim=128,
        max_positions=1024,
        convolutions=layers,
        dropout=0.1,
        share_embed=False,
        attention=False,
        positional_embeddings=False,
        adaptive_softmax_cutoff=None,
        adaptive_softmax_dropout=0,
    ).cuda()
    convLM_char.load_state_dict(compute_new_state(model_state))
    convLM_char.eval()
    return convLM_char
def load_word_model(pytorch_model_path, fairseq_dict, dataset_type):
    """Load the word convolutional LM (adaptive softmax) from a checkpoint.

    Args:
        pytorch_model_path: path to the saved fairseq checkpoint (.pt).
        fairseq_dict: fairseq dictionary used to size the embedding table.
        dataset_type: "wsj" or "ls"; selects the adaptive-softmax cutoffs
            (any other value yields an empty cutoff string, as before).

    Returns:
        FConvDecoder on GPU, in eval mode, with the checkpoint weights loaded.
    """
    # Convolution spec as (out_channels, kernel_size[, residual]) tuples.
    # Fix: the original built this with eval() on a string literal — an
    # eval anti-pattern; the plain expression below is value-identical.
    layers = (
        [(512, 5)]
        + [(128, 1, 0), (128, 5, 0), (512, 1, 3)] * 3
        + [(512, 1, 0), (512, 5, 0), (1024, 1, 3)] * 3
        + [(1024, 1, 0), (1024, 5, 0), (2048, 1, 3)] * 6
        + [(1024, 1, 0), (1024, 5, 0), (4096, 1, 3)]
    )
    model_state = torch.load(pytorch_model_path)
    # Adaptive-softmax vocabulary cutoffs per dataset.
    if dataset_type == "wsj":
        cutoff = "10000,50000,100000"
    elif dataset_type == "ls":
        cutoff = "10000,50000,200000"
    else:
        cutoff = ""
    convLM = FConvDecoder(
        fairseq_dict,
        embed_dim=128,
        out_embed_dim=128,
        max_positions=1024,
        convolutions=layers,
        dropout=0.1,
        share_embed=False,
        attention=False,
        positional_embeddings=False,
        adaptive_softmax_cutoff=(options.eval_str_list(cutoff, type=int)),
        adaptive_softmax_dropout=0,
    ).cuda()
    convLM.load_state_dict(compute_new_state(model_state))
    convLM.eval()
    convLM.adaptive_softmax.eval()
    return convLM
def decodeInputText(sentences, token_indices_dict):
    """Convert sentences into lists of token indices.

    Each sentence is split on single spaces; tokens missing from the
    vocabulary are mapped to the UNK index.

    Args:
        sentences: iterable of space-separated token strings.
        token_indices_dict: mapping from token string to integer index.

    Returns:
        list of lists of integer indices, one inner list per sentence.
    """
    decoded = []
    for sentence in sentences:
        indices = []
        for token in sentence.split(" "):
            if token in token_indices_dict:
                indices.append(token_indices_dict[token])
            else:
                indices.append(token_indices_dict[UNK])
        decoded.append(indices)
    return decoded
def build_token_index_correspondence(dict_fname):
    """Build token<->index mappings from a fairseq dictionary file.

    Regular tokens are numbered starting at 4 (following fairseq, which
    reserves the low indices for special symbols); EOS and UNK are then
    pinned to their fixed indices EOSIDX / UNKIDX.

    Args:
        dict_fname: path to a dictionary file with one entry per line,
            the token being the first space-separated field.

    Returns:
        (token_indices_dict, indices_token_dict): a dict mapping tokens
        to indices and its inverse.
    """
    # follow fairseq
    token_indices_dict = dict()
    indices_token_dict = dict()
    # Fix: extract the token once per line (was computed twice) and read
    # with an explicit UTF-8 encoding instead of the locale default.
    with open(dict_fname, "r", encoding="utf-8") as f:
        for index, line in enumerate(f):
            token = line.strip().split(" ")[0]
            token_indices_dict[token] = index + 4
            indices_token_dict[index + 4] = token
    token_indices_dict[EOS] = EOSIDX
    indices_token_dict[EOSIDX] = EOS
    token_indices_dict[UNK] = UNKIDX
    indices_token_dict[UNKIDX] = UNK
    return token_indices_dict, indices_token_dict