in data.py [0:0]
def extract_be(self):
    """
    Extract BERT embeddings (common for all types of datasets).

    Tokenizes every utterance of every dialog in ``self.dialogs`` with
    ``self.tokenizer``, runs each utterance through ``self.bert``, and
    appends one list of pooled utterance vectors (moved back to CPU)
    per dialog onto ``self.dial_vecs``.

    :return: None — results are accumulated in ``self.dial_vecs``.
    """
    dialogs = self.dialogs
    self.logbook.write_message_logs("Tokenizing ...")
    tdialogs = [[self.tokenizer.tokenize(utt) for utt in dl] for dl in dialogs]
    index_dial = [
        [self.tokenizer.convert_tokens_to_ids(utt) for utt in dl] for dl in tdialogs
    ]
    # All-zero segment ids: each utterance is treated as a single segment.
    segment_dial = [[[0 for tok in utt] for utt in dl] for dl in index_dial]
    # initialize bert model
    self.init_bert_model()
    self.logbook.write_message_logs("Extracting {} dialogs".format(len(dialogs)))
    # Fix: the original hard-coded "cuda", crashing on CPU-only hosts.
    # Honor an optional ``self.device`` attribute; otherwise pick the GPU
    # only when one is actually available (GPU hosts behave as before).
    device = getattr(self, "device", None) or (
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    pb = tqdm(total=len(dialogs))
    try:
        for di, dial in enumerate(index_dial):
            utt_vecs = []
            for uid, utt in enumerate(dial):
                tokens_tensor = torch.tensor([utt]).to(device)
                segments_tensor = torch.tensor([segment_dial[di][uid]]).to(device)
                # Inference only — skip autograd bookkeeping.
                with torch.no_grad():
                    outs = self.bert(tokens_tensor, token_type_ids=segments_tensor)
                # outs[1] is the pooled output; take the single batch element
                # and move it off-device so results don't pin GPU memory.
                utt_vecs.append(outs[1][0].to("cpu"))
            self.dial_vecs.append(utt_vecs)
            pb.update(1)
    finally:
        # Close the progress bar even if extraction fails partway.
        pb.close()