def forward()

in torchtext/models/roberta/modules.py
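
Embeds a batch of token ids, zeroes out the embeddings at padding positions, and runs the result through the stacked encoder layers. It returns the final hidden state, or the list of all intermediate states when `return_all_layers` is set; either way the tensors are laid out T x B x C (sequence-first).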


    def forward(self, tokens: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> Union[torch.Tensor, List[torch.Tensor]]:
        # A bool attn_mask marks disallowed positions with True; a float mask is
        # added to the attention scores (the torch.nn.MultiheadAttention convention).
        if attn_mask is not None:
            torch._assert(
                attn_mask.is_floating_point() or attn_mask.dtype == torch.bool,
                f"Only float or bool types are supported for attn_mask, not {attn_mask.dtype}",
            )

        # True at every position that holds the padding token.
        padding_mask = tokens.eq(self.padding_idx)

        # Token and positional embeddings are both B x T x C and are summed.
        token_embeddings = self.token_embedding(tokens)
        embedded_positions = self.positional_embedding(tokens)

        embedded = token_embeddings + embedded_positions

        # Older serialized modules may predate `normalize_before`; default to
        # the original post-norm behavior if the attribute is missing.
        if not hasattr(self, "normalize_before"):
            self.normalize_before = False
        # Post-norm: LayerNorm is applied to the embeddings up front. Pre-norm
        # defers this normalization until after the encoder layers (see below).
        if not self.normalize_before:
            embedded = self.embedding_layer_norm(embedded)
        embedded = self.dropout(embedded)

        # Zero the embeddings at padding positions so pad tokens contribute
        # nothing to the encoder input.
        padded_embedded = embedded * (1 - padding_mask.unsqueeze(-1).type_as(embedded))

        # The encoder layers expect sequence-first input: B x T x C -> T x B x C.
        encoded = padded_embedded.transpose(0, 1)

        if self.return_all_layers:
            # Collect the embedding output plus every layer's hidden state.
            states = [encoded]

            for layer in self.layers:
                encoded = layer(encoded, padding_mask, attn_mask)
                states.append(encoded)

            # Pre-norm: the shared embedding_layer_norm acts as the final
            # LayerNorm, applied to every collected state.
            if self.normalize_before:
                for i, state in enumerate(states):
                    states[i] = self.embedding_layer_norm(state)

            # states are returned as T x B x C
            return states
        else:
            for layer in self.layers:
                encoded = layer(encoded, padding_mask, attn_mask)

            # Pre-norm: apply the final LayerNorm once to the last state.
            if self.normalize_before:
                encoded = self.embedding_layer_norm(encoded)

            # encoded is returned as T x B x C
            return encoded
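
For reference, a minimal, self-contained sketch of the masking and layout logic above. The sizes (B = 2, T = 4, C = 3) and the `padding_idx` value are made up for illustration; the real module stores its own `padding_idx` and produces `embedded` from its embedding layers:

    import torch

    padding_idx = 1  # hypothetical value, for illustration only
    tokens = torch.tensor([[5, 6, 7, 1],
                           [8, 9, 1, 1]])  # B x T, padded with padding_idx

    # True at padding positions, exactly as in forward() above.
    padding_mask = tokens.eq(padding_idx)  # B x T, bool

    # Stand-in for token_embedding(tokens) + positional_embedding(tokens).
    embedded = torch.randn(2, 4, 3)  # B x T x C

    # Zero the embeddings at padded positions...
    padded_embedded = embedded * (1 - padding_mask.unsqueeze(-1).type_as(embedded))

    # ...and switch to the sequence-first layout the encoder layers expect.
    encoded = padded_embedded.transpose(0, 1)
    print(encoded.shape)  # torch.Size([4, 2, 3]), i.e. T x B x C

The same transpose explains the T x B x C comments on the return values: outputs stay sequence-first, so callers who want batch-first tensors need a final `transpose(0, 1)` of their own.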