in empchat/bert_local.py
def __init__(self, opt, dictionary):
    from parlai.agents.bert_ranker.helpers import BertWrapper

    try:
        from pytorch_pretrained_bert import BertModel
    except ImportError:
        raise Exception(
            "BERT rankers need pytorch-pretrained-BERT installed. "
            "\npip install pytorch-pretrained-bert"
        )
    super().__init__()
    self.opt = opt
    self.pad_idx = dictionary[PAD_TOKEN]
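    # Bi-encoder: two independent BERT towers, one for the dialogue context
    # and one for the response candidates. BertWrapper (from ParlAI) pools
    # BERT's output into a single vector and projects it to output_dim,
    # optionally through one extra transformer layer.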
    self.ctx_bert = BertWrapper(
        bert_model=BertModel.from_pretrained(BERT_ID),
        output_dim=opt.bert_dim,
        add_transformer_layer=opt.bert_add_transformer_layer,
    )
    self.cand_bert = BertWrapper(
        bert_model=BertModel.from_pretrained(BERT_ID),
        output_dim=opt.bert_dim,
        add_transformer_layer=opt.bert_add_transformer_layer,
    )
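    # A third, untouched copy of the pretrained weights is loaded below
    # purely to read off the embedding statistics used for re-initialization.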
    # Reset the embeddings for the until-now unused BERT tokens
    orig_embedding_weights = BertModel.from_pretrained(
        BERT_ID
    ).embeddings.word_embeddings.weight
    mean_val = orig_embedding_weights.mean().item()
    std_val = orig_embedding_weights.std().item()
    unused_tokens = [START_OF_COMMENT, PARLAI_PAD_TOKEN, EMPTYPERSONA_TOKEN]
    unused_token_idxes = [dictionary[token] for token in unused_tokens]
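    # Each replacement row is drawn from a normal distribution matching the
    # mean and standard deviation of the pretrained embedding table, so the
    # repurposed tokens start on the same scale as the rest of the vocabulary.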
    for token_idx in unused_token_idxes:
        rand_embedding = orig_embedding_weights.new_empty(
            (1, orig_embedding_weights.size(1))
        ).normal_(mean=mean_val, std=std_val)
        for embeddings in [
            self.ctx_bert.bert_model.embeddings.word_embeddings,
            self.cand_bert.bert_model.embeddings.word_embeddings,
        ]:
            # In-place surgery on a leaf parameter must be hidden from
            # autograd, or PyTorch raises a RuntimeError.
            with torch.no_grad():
                embeddings.weight[token_idx] = rand_embedding
    # Detach the embedding tables from the autograd graph, freezing them so
    # they are not updated during fine-tuning.
    self.ctx_bert.bert_model.embeddings.word_embeddings.weight.detach_()
    self.cand_bert.bert_model.embeddings.word_embeddings.weight.detach_()
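
For orientation, a minimal sketch of how a bi-encoder like this one is
typically used to score response candidates. score_candidates() is
hypothetical (not part of this file), and it assumes BertWrapper's forward
takes (token_ids, segment_ids, attention_mask):

import torch

def score_candidates(adapter, ctx_ids, cand_ids):
    # Assumptions: segment IDs are all zeros (single-segment inputs) and the
    # attention mask marks non-pad positions, mirroring common BERT usage.
    ctx_mask = (ctx_ids != adapter.pad_idx).long()
    cand_mask = (cand_ids != adapter.pad_idx).long()
    # Each tower maps token IDs to one bert_dim vector per input row.
    ctx_vec = adapter.ctx_bert(ctx_ids, torch.zeros_like(ctx_ids), ctx_mask)
    cand_vecs = adapter.cand_bert(cand_ids, torch.zeros_like(cand_ids), cand_mask)
    # One score per (context, candidate) pair via dot product.
    return ctx_vec @ cand_vecs.t()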