in lama/modules/bert_connector.py [0:0]
def __init__(self, args, vocab_subset=None):
    super().__init__()

    bert_model_name = args.bert_model_name
    dict_file = bert_model_name

    if args.bert_model_dir is not None:
        # load the BERT model from a local directory
        bert_model_name = str(args.bert_model_dir) + "/"
        dict_file = bert_model_name + args.bert_vocab_name
        self.dict_file = dict_file
        print("loading BERT model from {}".format(bert_model_name))
    else:
        # load the BERT model from the huggingface cache
        pass

    # When using a cased model, make sure do_lower_case=False is passed to the basic tokenizer
    do_lower_case = False
    if 'uncased' in bert_model_name:
        do_lower_case = True

    # Load the pre-trained model tokenizer (vocabulary)
    self.tokenizer = BertTokenizer.from_pretrained(dict_file)

    # original vocab
    self.map_indices = None
    self.vocab = list(self.tokenizer.ids_to_tokens.values())
    self._init_inverse_vocab()

    # Swap in a custom basic tokenizer to avoid splitting the [MASK] token
    custom_basic_tokenizer = CustomBaseTokenizer(do_lower_case=do_lower_case)
    self.tokenizer.basic_tokenizer = custom_basic_tokenizer

    # Load the pre-trained model (weights)
    # ... to get prediction/generation
    self.masked_bert_model = BertForMaskedLM.from_pretrained(bert_model_name)
    self.masked_bert_model.eval()

    # ... to get hidden states
    self.bert_model = self.masked_bert_model.bert

    self.pad_id = self.inverse_vocab[BERT_PAD]
    self.unk_index = self.inverse_vocab[BERT_UNK]
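
For context, a minimal usage sketch (not part of the file above): it builds an argparse-style namespace with the three attributes this __init__ reads and instantiates the connector. The class name Bert and the vocab file name vocab.txt are assumptions based on the surrounding LAMA repo, not shown in this excerpt.

from argparse import Namespace

# Hypothetical arguments; bert_model_dir=None means the weights are pulled
# from the huggingface cache and bert_vocab_name is then ignored.
args = Namespace(
    bert_model_name="bert-base-cased",
    bert_model_dir=None,
    bert_vocab_name="vocab.txt",
)

connector = Bert(args)  # assumed class name wrapping the __init__ above
print(len(connector.vocab), connector.pad_id, connector.unk_index)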