in torchtext/models/roberta/modules.py
def forward(self, tokens: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> Union[torch.Tensor, List[torch.Tensor]]:
    if attn_mask is not None:
        torch._assert(
            attn_mask.is_floating_point() or attn_mask.dtype == torch.bool,
            f"Only float or bool types are supported for attn_mask not {attn_mask.dtype}",
        )
    # True at positions holding the padding token.
    padding_mask = tokens.eq(self.padding_idx)

    # Sum token and positional embeddings.
    token_embeddings = self.token_embedding(tokens)
    embedded_positions = self.positional_embedding(tokens)
    embedded = token_embeddings + embedded_positions

    # Older serialized modules may lack this attribute; default to post-norm.
    if not hasattr(self, "normalize_before"):
        self.normalize_before = False
    if not self.normalize_before:
        embedded = self.embedding_layer_norm(embedded)
    embedded = self.dropout(embedded)

    # Zero out embeddings at padded positions, then switch to the
    # T x B x C layout expected by the transformer layers.
    padded_embedded = embedded * (1 - padding_mask.unsqueeze(-1).type_as(embedded))
    encoded = padded_embedded.transpose(0, 1)

    if self.return_all_layers:
        states = [encoded]

        for layer in self.layers:
            encoded = layer(encoded, padding_mask, attn_mask)
            states.append(encoded)

        if self.normalize_before:
            for i, state in enumerate(states):
                states[i] = self.embedding_layer_norm(state)

        # states are returned as T x B x C
        return states
    else:
        for layer in self.layers:
            encoded = layer(encoded, padding_mask, attn_mask)

        if self.normalize_before:
            encoded = self.embedding_layer_norm(encoded)

        # encoded is returned as T x B x C
        return encoded
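
For reference, a minimal standalone sketch of the masking and layout steps above, using dummy tensors. The shapes and the padding index are illustrative assumptions for this sketch only, not values taken from any particular TransformerEncoder configuration:

    import torch

    # Illustrative sizes and padding index (assumptions for the sketch).
    B, T, C = 2, 5, 8
    padding_idx = 1

    tokens = torch.tensor([[5, 6, 7, 1, 1],
                           [4, 3, 1, 1, 1]])      # B x T token ids
    padding_mask = tokens.eq(padding_idx)         # B x T, True at padded positions

    embedded = torch.randn(B, T, C)               # stand-in for token + positional embeddings
    padded_embedded = embedded * (1 - padding_mask.unsqueeze(-1).type_as(embedded))
    encoded = padded_embedded.transpose(0, 1)     # T x B x C, the layout the layers consume

    print(encoded.shape)                          # torch.Size([5, 2, 8])

Depending on return_all_layers, the forward method then returns either the final T x B x C tensor or a list of such tensors, one per layer plus the initial embedding output.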