def forward()

in curiosity/bert.py [0:0]


    def forward(self, tokens: Dict[str, torch.LongTensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        # pylint: disable=arguments-differ
        input_ids = tokens[self._index]
        token_type_ids = tokens[f"{self._index}-type-ids"]
        # Wordpiece id 0 is the padding index, so nonzero ids mark real tokens
        input_mask = (input_ids != 0).long()
        # The transformers lib doesn't accept extra leading dimensions, and
        # TimeDistributed expects a tensor, so flatten to (N, num_wordpieces).
        # This is safe since each piece of text only needs an independent encoding.
        if input_ids.dim() > 2:
            shape = input_ids.shape
            word_dim = shape[-1]
            reshaped_input_ids = input_ids.view(-1, word_dim)
            reshaped_token_type_ids = token_type_ids.view(-1, word_dim)
            reshaped_input_mask = input_mask.view(-1, word_dim)
            # bert_model returns (sequence_output, pooled_output); keep the pooled vector
            _, reshaped_pooled = self.bert_model(
                input_ids=reshaped_input_ids,
                token_type_ids=reshaped_token_type_ids,
                attention_mask=reshaped_input_mask,
            )
            pooled = reshaped_pooled.view(shape[:-1] + (-1,))
        else:
            _, pooled = self.bert_model(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                attention_mask=input_mask,
            )
        # input_mask is a wordpiece-level mask; we want an utterance-level mask,
        # so mark as padding any utterance whose wordpieces are all masked out
        utter_mask = (input_mask.sum(dim=-1) != 0).long()

        return pooled, utter_mask
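
The reshape-encode-restore pattern above can be exercised directly against a Hugging Face BertModel. The sketch below is illustrative only: the model name, tensor shapes, padding id 0, and return_dict=False (which restores the (sequence_output, pooled_output) tuple that the unpacking above relies on) are assumptions, not part of this module.

    import torch
    from transformers import BertModel

    # Assumed setup: return_dict=False makes calls return a
    # (sequence_output, pooled_output) tuple, matching the unpacking above
    bert = BertModel.from_pretrained("bert-base-uncased", return_dict=False)

    # Toy batch: 2 dialogs x 3 utterances x 8 wordpieces, padding id 0
    input_ids = torch.randint(1, 1000, (2, 3, 8))
    input_ids[0, 2] = 0                           # one entirely padded utterance
    token_type_ids = torch.zeros_like(input_ids)
    input_mask = (input_ids != 0).long()

    # Flatten the dialog/utterance dims so BERT sees a plain (N, seq_len) batch
    shape = input_ids.shape
    with torch.no_grad():
        _, pooled = bert(
            input_ids=input_ids.view(-1, shape[-1]),
            token_type_ids=token_type_ids.view(-1, shape[-1]),
            attention_mask=input_mask.view(-1, shape[-1]),
        )
    # Restore the leading dims and build the utterance-level mask
    pooled = pooled.view(shape[:-1] + (-1,))           # (2, 3, hidden_size)
    utter_mask = (input_mask.sum(dim=-1) != 0).long()  # (2, 3)

    print(pooled.shape)   # torch.Size([2, 3, 768])
    print(utter_mask)     # tensor([[1, 1, 0], [1, 1, 1]]): 0 marks the padded utterance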