in curiosity/bert.py [0:0]
def forward(self, tokens: Dict[str, torch.LongTensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    # pylint: disable=arguments-differ
    input_ids = tokens[self._index]
    token_type_ids = tokens[f"{self._index}-type-ids"]
    # Wordpiece id 0 is padding, so non-zero ids define the attention mask
    input_mask = (input_ids != 0).long()
    # The transformers lib doesn't accept extra leading dimensions, and
    # TimeDistributed expects a tensor, so flatten everything down to
    # (batch * n_utterances, n_wordpieces) before encoding. This is safe
    # because we only need independent encodings of each piece of text.
    if input_ids.dim() > 2:
        shape = input_ids.shape
        word_dim = shape[-1]
        reshaped_input_ids = input_ids.view(-1, word_dim)
        reshaped_token_type_ids = token_type_ids.view(-1, word_dim)
        reshaped_input_mask = input_mask.view(-1, word_dim)
        # bert_model returns a (sequence_output, pooled_output) tuple here
        _, reshaped_pooled = self.bert_model(
            input_ids=reshaped_input_ids,
            token_type_ids=reshaped_token_type_ids,
            attention_mask=reshaped_input_mask,
        )
        # Restore the original leading dimensions: (..., hidden_size)
        pooled = reshaped_pooled.view(shape[:-1] + (-1,))
    else:
        _, pooled = self.bert_model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
        )
    # input_mask is a wordpiece-level mask; we want an utterance-level mask,
    # so mark an utterance as padding only when all of its wordpieces are padding
    utter_mask = (input_mask.sum(dim=-1) != 0).long()
    return pooled, utter_mask
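
For reference, a minimal standalone sketch of the two tricks used above: flattening extra leading dimensions before calling the transformers model, and deriving an utterance-level mask from the wordpiece-level mask. The shapes, the hidden size of 768, and the random tensor standing in for BERT's pooled output are illustrative assumptions, not part of curiosity/bert.py.

import torch

# Hypothetical batch: 2 dialogs, 3 utterances each, 5 wordpieces per utterance;
# the last utterance of the second dialog is entirely padding (all zeros).
input_ids = torch.randint(1, 100, (2, 3, 5))
input_ids[1, 2] = 0
input_mask = (input_ids != 0).long()

# Flatten the leading dims so a 2D (batch, seq) tensor reaches the model,
# then restore them on the pooled output, as forward() does.
shape = input_ids.shape
flat_ids = input_ids.view(-1, shape[-1])                   # (6, 5)
hidden_size = 768                                          # assumed BERT hidden size
flat_pooled = torch.randn(flat_ids.size(0), hidden_size)   # stand-in for pooled output
pooled = flat_pooled.view(shape[:-1] + (-1,))              # (2, 3, 768)

# Utterance-level mask: 0 wherever every wordpiece in the utterance is padding.
utter_mask = (input_mask.sum(dim=-1) != 0).long()          # (2, 3); utter_mask[1, 2] == 0
print(pooled.shape, utter_mask)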