in lama/modules/bert_connector.py [0:0]
def __get_input_tensors_batch(self, sentences_list):
    tokens_tensors_list = []
    segments_tensors_list = []
    masked_indices_list = []
    tokenized_text_list = []
    max_tokens = 0
    # tokenize each entry separately and keep track of the longest sequence in the batch
    for sentences in sentences_list:
        tokens_tensor, segments_tensor, masked_indices, tokenized_text = self.__get_input_tensors(sentences)
        tokens_tensors_list.append(tokens_tensor)
        segments_tensors_list.append(segments_tensor)
        masked_indices_list.append(masked_indices)
        tokenized_text_list.append(tokenized_text)
        # assert tokens_tensor.shape[1] == segments_tensor.shape[1]
        if tokens_tensor.shape[1] > max_tokens:
            max_tokens = tokens_tensor.shape[1]
# print("MAX_TOKENS: {}".format(max_tokens))
    # pad every sequence to max_tokens and concatenate into batch tensors:
    # [PAD] (self.pad_id) for tokens, 0 for segment ids, 0 for the attention mask
    final_tokens_tensor = None
    final_segments_tensor = None
    final_attention_mask = None
    for tokens_tensor, segments_tensor in zip(tokens_tensors_list, segments_tensors_list):
        dim_tensor = tokens_tensor.shape[1]
        pad_length = max_tokens - dim_tensor
        attention_tensor = torch.full([1, dim_tensor], 1, dtype=torch.long)
        if pad_length > 0:
            pad_1 = torch.full([1, pad_length], self.pad_id, dtype=torch.long)
            pad_2 = torch.full([1, pad_length], 0, dtype=torch.long)
            attention_pad = torch.full([1, pad_length], 0, dtype=torch.long)
            tokens_tensor = torch.cat((tokens_tensor, pad_1), dim=1)
            segments_tensor = torch.cat((segments_tensor, pad_2), dim=1)
            attention_tensor = torch.cat((attention_tensor, attention_pad), dim=1)
        if final_tokens_tensor is None:
            final_tokens_tensor = tokens_tensor
            final_segments_tensor = segments_tensor
            final_attention_mask = attention_tensor
        else:
            final_tokens_tensor = torch.cat((final_tokens_tensor, tokens_tensor), dim=0)
            final_segments_tensor = torch.cat((final_segments_tensor, segments_tensor), dim=0)
            final_attention_mask = torch.cat((final_attention_mask, attention_tensor), dim=0)
    return final_tokens_tensor, final_segments_tensor, final_attention_mask, masked_indices_list, tokenized_text_list
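
The same pad-to-longest batching can be sketched outside the connector class. The snippet below is a minimal, self-contained illustration rather than the repository's code path: it assumes the Hugging Face transformers package and the bert-base-cased checkpoint (both are assumptions, not part of this module), and it uses a small pad_to helper in place of the manual torch.full padding above. The attention mask marks real tokens with 1 and padded positions with 0, so the model ignores the padding when predicting the [MASK] tokens.

import torch
from transformers import BertForMaskedLM, BertTokenizer

# Assumed checkpoint name; swap in whatever model the connector is configured with.
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
model = BertForMaskedLM.from_pretrained("bert-base-cased")
model.eval()

sentences = ["The capital of France is [MASK] .", "Dante was born in [MASK] ."]

# Per-sentence encoding, mirroring what __get_input_tensors is assumed to produce.
encoded = [tokenizer(s, return_tensors="pt") for s in sentences]
max_tokens = max(e["input_ids"].shape[1] for e in encoded)

# Pad every sequence to the batch maximum: pad token id for tokens, 0 for segments and mask.
pad_id = tokenizer.pad_token_id

def pad_to(t, length, value):
    return torch.cat([t, torch.full((1, length - t.shape[1]), value, dtype=torch.long)], dim=1)

tokens = torch.cat([pad_to(e["input_ids"], max_tokens, pad_id) for e in encoded], dim=0)
segments = torch.cat([pad_to(e["token_type_ids"], max_tokens, 0) for e in encoded], dim=0)
attention = torch.cat([pad_to(e["attention_mask"], max_tokens, 0) for e in encoded], dim=0)

with torch.no_grad():
    logits = model(input_ids=tokens, token_type_ids=segments, attention_mask=attention).logits

# Read the top prediction at each [MASK] position.
for row, ids in enumerate(tokens):
    for pos in (ids == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]:
        top_id = logits[row, pos].argmax().item()
        print(sentences[row], "->", tokenizer.convert_ids_to_tokens([top_id])[0])

The dim=1 concatenation in pad_to matches the per-sequence padding loop above, and the dim=0 concatenation matches how the final batch tensors are stacked before being returned.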