recipes/lexicon_free/utilities/convlm_utils.py (122 lines of code) (raw):

""" Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. This source code is licensed under the MIT-style license found in the LICENSE file in the root directory of this source tree. """ from __future__ import absolute_import, division, print_function, unicode_literals import torch from fairseq import options from fairseq.models.fconv import FConvDecoder EOS = '</s>' UNK = '<unk>' EOSIDX = 2 UNKIDX = 3 def compute_new_state(model_state): new_state = dict() for key, val in model_state["model"].items(): if "1.weight" in key and "adaptive" in key: new_state[ ".".join(key.split(".")[1:]).replace("1.weight", "2.weight") ] = val else: new_state[".".join(key.split(".")[1:])] = val return new_state def load_char_model_20B(pytorch_model_path, fairseq_dict, dataset_type): layer = eval( "[(512, 5)] + [(128, 1, 0), (128, 5, 0), (256, 1, 3)] * 3 + " "[(256, 1, 0), (256, 5, 0), (512, 1, 3)] * 3 + " "[(512, 1, 0), (512, 5, 0), (1024, 1, 3)] * 3 + " "[(1024, 1, 0), (1024, 5, 0), (2048, 1, 3)] * 9 + " "[(1024, 1, 0), (1024, 5, 0), (4096, 1, 3)]" ) model_state = torch.load(pytorch_model_path) convLM_char = FConvDecoder( fairseq_dict, embed_dim=256, out_embed_dim=256, max_positions=1024, convolutions=layer, dropout=0.1, share_embed=False, attention=False, positional_embeddings=False, adaptive_softmax_cutoff=None, adaptive_softmax_dropout=0, ).cuda() convLM_char.load_state_dict(compute_new_state(model_state)) convLM_char.eval() return convLM_char def load_char_model_14B(pytorch_model_path, fairseq_dict, dataset_type): layer = eval( "[(512, 5)] + [(128, 1, 0), (128, 5, 0), (512, 1, 3)] * 3 + " "[(512, 1, 0), (512, 5, 0), (1024, 1, 3)] * 3 + " "[(1024, 1, 0), (1024, 5, 0), (2048, 1, 3)] * 6 + " "[(1024, 1, 0), (1024, 5, 0), (4096, 1, 3)]" ) model_state = torch.load(pytorch_model_path) convLM_char = FConvDecoder( fairseq_dict, embed_dim=128, out_embed_dim=128, max_positions=1024, convolutions=layer, dropout=0.1, share_embed=False, attention=False, positional_embeddings=False, adaptive_softmax_cutoff=None, adaptive_softmax_dropout=0, ).cuda() convLM_char.load_state_dict(compute_new_state(model_state)) convLM_char.eval() return convLM_char def load_word_model(pytorch_model_path, fairseq_dict, dataset_type): layer = eval( "[(512, 5)] + [(128, 1, 0), (128, 5, 0), (512, 1, 3)] * 3 + " "[(512, 1, 0), (512, 5, 0), (1024, 1, 3)] * 3 + " "[(1024, 1, 0), (1024, 5, 0), (2048, 1, 3)] * 6 + " "[(1024, 1, 0), (1024, 5, 0), (4096, 1, 3)]" ) model_state = torch.load(pytorch_model_path) if dataset_type == "wsj": cutoff = "10000,50000,100000" elif dataset_type == "ls": cutoff = "10000,50000,200000" else: cutoff = "" convLM = FConvDecoder( fairseq_dict, embed_dim=128, out_embed_dim=128, max_positions=1024, convolutions=layer, dropout=0.1, share_embed=False, attention=False, positional_embeddings=False, adaptive_softmax_cutoff=(options.eval_str_list(cutoff, type=int)), adaptive_softmax_dropout=0, ).cuda() convLM.load_state_dict(compute_new_state(model_state)) convLM.eval() convLM.adaptive_softmax.eval() return convLM def decodeInputText(sentences, token_indices_dict): sentences_decoded = [] for line in sentences: sentences_decoded.append( [ token_indices_dict[UNK] if token not in token_indices_dict else token_indices_dict[token] for token in line.split(" ") ] ) return sentences_decoded def build_token_index_correspondence(dict_fname): # follow fairseq token_indices_dict = dict() indices_token_dict = dict() with open(dict_fname, "r") as f: for index, line in enumerate(f): token_indices_dict[line.strip().split(" ")[0]] = index + 4 indices_token_dict[index + 4] = line.strip().split(" ")[0] token_indices_dict[EOS] = EOSIDX indices_token_dict[EOSIDX] = EOS token_indices_dict[UNK] = UNKIDX indices_token_dict[UNKIDX] = UNK return token_indices_dict, indices_token_dict