in longform-qa/lfqa_utils.py [0:0]
def embed_sentences_checkpointed(self, input_ids, attention_mask, checkpoint_batch_size=-1):
    """Return pooled sentence embeddings, optionally checkpointing per mini-batch.

    Reproduces the BERT forward pass of ``self.sent_encoder`` with
    checkpointing: the embedding layer runs on the whole batch at once, then
    the encoder and pooler run one mini-batch at a time through
    ``checkpoint.checkpoint``.

    Args:
        input_ids: Token id tensor; its first dimension is the batch.
        attention_mask: Attention mask forwarded to the encoder.
        checkpoint_batch_size: Number of rows per checkpointed mini-batch.
            A non-positive value, or a value larger than the batch, disables
            checkpointing and calls ``self.sent_encoder`` directly.

    Returns:
        The pooled output: element ``[1]`` of the encoder's output on the
        direct path, otherwise the per-mini-batch pooler outputs
        concatenated along dimension 0.
    """
    batch_size = input_ids.shape[0]
    # `<= 0` rather than `< 0`: a batch size of 0 would otherwise reach the
    # chunking loop below and fail with a division/step-by-zero error.
    if checkpoint_batch_size <= 0 or batch_size < checkpoint_batch_size:
        return self.sent_encoder(input_ids, attention_mask=attention_mask)[1]

    # prepare the implicit variables the full forward pass would build itself
    device = input_ids.device
    input_shape = input_ids.size()
    token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
    head_mask = [None] * self.sent_encoder.config.num_hidden_layers
    extended_attention_mask: torch.Tensor = self.sent_encoder.get_extended_attention_mask(
        attention_mask, input_shape, device
    )

    # function handed to the checkpointing helper; tensors arrive positionally
    def partial_encode(*inputs):
        encoder_outputs = self.sent_encoder.encoder(
            inputs[0], attention_mask=inputs[1], head_mask=head_mask
        )
        sequence_output = encoder_outputs[0]
        return self.sent_encoder.pooler(sequence_output)

    # run embedding layer on everything at once
    embedding_output = self.sent_encoder.embeddings(
        input_ids=input_ids, position_ids=None, token_type_ids=token_type_ids, inputs_embeds=None
    )

    # run encoding and pooling on one mini-batch at a time; the last chunk
    # may be shorter than checkpoint_batch_size
    pooled_output_list = []
    for start in range(0, batch_size, checkpoint_batch_size):
        stop = start + checkpoint_batch_size
        pooled_output_list.append(
            checkpoint.checkpoint(
                partial_encode,
                embedding_output[start:stop],
                extended_attention_mask[start:stop],
            )
        )
    return torch.cat(pooled_output_list, dim=0)