in lama/modules/roberta_connector.py [0:0]
def get_batch_generation(self, sentences_list, logger=None, try_cuda=True):
    if not sentences_list:
        return None
    if try_cuda:
        self.try_cuda()

    tensor_list = []
    masked_indices_list = []
    max_len = 0
    output_tokens_list = []

    for masked_inputs_list in sentences_list:
        tokens_list = []
        for idx, masked_input in enumerate(masked_inputs_list):
            # Substitute the generic [MASK] token with RoBERTa's <mask>.
            masked_input = masked_input.replace(MASK, ROBERTA_MASK)

            # BPE-encode the text around each mask, then re-join the spans
            # with the literal <mask> symbol so it maps to the mask index.
            text_spans = masked_input.split(ROBERTA_MASK)
            text_spans_bpe = (
                (" {0} ".format(ROBERTA_MASK))
                .join(
                    [
                        self.bpe.encode(text_span.rstrip())
                        for text_span in text_spans
                    ]
                )
                .strip()
            )

            # Prepend the start-of-sentence symbol only to the first segment.
            prefix = ""
            if idx == 0:
                prefix = ROBERTA_START_SENTENCE

            tokens_list.append(
                self.task.source_dictionary.encode_line(
                    str(prefix + " " + text_spans_bpe).strip(), append_eos=True
                )
            )

        # Concatenate the segments and truncate to the maximum sentence length.
        tokens = torch.cat(tokens_list)[: self.max_sentence_length]
        output_tokens_list.append(tokens.long().cpu().numpy())

        if len(tokens) > max_len:
            max_len = len(tokens)
        tensor_list.append(tokens)

        # Record the position of every <mask> token in this sentence.
        masked_index = (tokens == self.task.mask_idx).nonzero().numpy()
        for x in masked_index:
            masked_indices_list.append([x[0]])

    # Right-pad every sentence to the length of the longest one in the batch.
    pad_id = self.task.source_dictionary.pad()
    tokens_list = []
    for tokens in tensor_list:
        pad_length = max_len - len(tokens)
        if pad_length > 0:
            pad_tensor = torch.full([pad_length], pad_id, dtype=torch.int)
            tokens = torch.cat((tokens, pad_tensor))
        tokens_list.append(tokens)

    batch_tokens = torch.stack(tokens_list)

    with torch.no_grad():
        # with utils.eval(self.model.model):
        self.model.eval()
        self.model.model.eval()
        log_probs, extra = self.model.model(
            batch_tokens.long().to(device=self._model_device),
            features_only=False,
            return_all_hiddens=False,
        )

    return log_probs.cpu(), output_tokens_list, masked_indices_list
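
A minimal usage sketch of this method. Only the get_batch_generation signature and its return values come from the code above; the pre-built connector instance `model`, the example sentences, and the assumption of one [MASK] per sentence are illustrative, and decoding predicted vocabulary ids back to surface words via the connector's BPE/vocabulary helpers is omitted.

# Hypothetical usage sketch; `model` is assumed to be an already-initialized
# RoBERTa connector exposing get_batch_generation as defined above.
sentences_list = [
    ["The theory of relativity was developed by [MASK] ."],
    ["Paris is the capital of [MASK] ."],
]

log_probs, output_tokens_list, masked_indices_list = model.get_batch_generation(
    sentences_list, try_cuda=False
)

# log_probs has shape [batch, max_len, vocab]; with a single [MASK] per
# sentence, masked_indices_list[i] holds the mask position for sentence i.
for i, masked_indices in enumerate(masked_indices_list):
    mask_pos = masked_indices[0]
    top_id = int(log_probs[i, mask_pos].argmax().item())
    print("sentence %d: top vocab id %d at position %d" % (i, top_id, mask_pos))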