in torchtext/models/roberta/modules.py [0:0]
def forward(self, query: torch.Tensor, key_padding_mask: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
    # query: (target_length, batch_size, embed_dim)
    # key_padding_mask: (batch_size, source_length), True at padding positions
    # attn_mask: optional (target_length, source_length), float (additive) or bool
    target_length, batch_size, embed_dim = query.size()
    mask_batch_size, source_length = key_padding_mask.size()
    torch._assert(embed_dim == self.embed_dim, "query embed dim doesn't match")
    torch._assert(
        batch_size == mask_batch_size,
        "query and key_padding_mask batch sizes differed",
    )

    # Fused projection producing q, k, v in a single matmul, then split on the last dim.
    projection = self.input_projection(query)
    q, k, v = projection.chunk(3, dim=-1)
    q = self.scaling * q

    # Fold the heads into the batch dimension: (batch_size * num_heads, seq_len, head_dim).
    batch_heads = batch_size * self.num_heads
    q = q.contiguous().view(-1, batch_heads, self.head_dim).transpose(0, 1)
    k = k.contiguous().view(-1, batch_heads, self.head_dim).transpose(0, 1)
    v = v.contiguous().view(-1, batch_heads, self.head_dim).transpose(0, 1)
    torch._assert(k.size(1) == source_length, "key size should be equal to source length")
    # Raw attention scores: (batch_size * num_heads, target_length, source_length).
    attn_weights = torch.bmm(q, k.transpose(1, 2))
    if attn_mask is not None:
        torch._assert(attn_mask.dim() == 2, "Expected attn_mask of dim 2 but got {}".format(attn_mask.dim()))
        torch._assert(attn_mask.size(0) == target_length, "attn_mask shape didn't match for target length {}".format(target_length))
        torch._assert(attn_mask.size(1) == source_length, "attn_mask shape didn't match for source length {}".format(source_length))
        torch._assert(attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, f"Only float or bool types are supported for attn_mask not {attn_mask.dtype}")
        # A boolean mask becomes an additive float mask: masked positions get a large
        # negative value, kept finite so fp16 does not overflow.
        if attn_mask.dtype == torch.bool:
            new_attn_mask = torch.zeros_like(attn_mask, dtype=query.dtype)
            new_attn_mask.masked_fill_(attn_mask, -1e8 if query.dtype == torch.float32 else -1e4)
            attn_mask = new_attn_mask
        attn_mask = attn_mask.unsqueeze(0)  # broadcast over the batch-heads dimension
        attn_weights += attn_mask
    torch._assert(attn_weights.dim() == 3, "Unexpected attn_weights dim")
    torch._assert(
        attn_weights.size(0) == batch_heads,
        "attn_weights shape didn't match for batch heads",
    )
    torch._assert(
        attn_weights.size(1) == target_length,
        "attn_weights shape didn't match for target length",
    )
    torch._assert(
        attn_weights.size(2) == source_length,
        "attn_weights shape didn't match for source length",
    )

    # Mask out padded key positions so they receive zero attention after softmax.
    attn_weights = attn_weights.view(
        batch_size, self.num_heads, target_length, source_length
    )
    attn_weights = attn_weights.masked_fill(
        key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf")
    )
    attn_weights = attn_weights.view(batch_heads, target_length, source_length)

    # Softmax in float32 for numerical stability, then cast back to the input dtype.
    attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(
        attn_weights
    )
    attn_weights = self.dropout(attn_weights)
    # Weighted sum of values: (batch_size * num_heads, target_length, head_dim).
    attn = torch.bmm(attn_weights, v)
    torch._assert(
        attn.dim() == 3,
        "unexpected attn dim size",
    )
    torch._assert(
        attn.size(0) == batch_heads,
        "attn shape didn't match for batch heads",
    )
    torch._assert(
        attn.size(1) == target_length,
        "attn shape didn't match for target length",
    )
    torch._assert(
        attn.size(2) == self.head_dim,
        "attn shape didn't match for head dim",
    )

    # Merge the heads back into the embedding dimension: (target_length, batch_size, embed_dim).
    attn = (
        attn.transpose(0, 1)
        .contiguous()
        .view(target_length, batch_size, self.head_dim * self.num_heads)
    )
    attn = self.output_projection(attn)
    return attn
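
Below is a minimal, self-contained sketch of the same computation using plain torch ops, intended only to illustrate the expected tensor layout: query is (target_length, batch_size, embed_dim), key_padding_mask is (batch_size, source_length) with True at padding positions, and the heads are folded into the batch dimension before the two bmm calls. The sizes and the stand-in Linear layers (in_proj, out_proj) are assumptions for illustration, not the module's actual parameters.

import torch
import torch.nn.functional as F

# Illustrative sizes -- assumptions, not values taken from the module.
target_length, batch_size, embed_dim, num_heads = 5, 2, 16, 4
head_dim = embed_dim // num_heads
scaling = head_dim ** -0.5

query = torch.randn(target_length, batch_size, embed_dim)            # (T, B, E)
key_padding_mask = torch.zeros(batch_size, target_length, dtype=torch.bool)
key_padding_mask[:, -1] = True                                        # last position is padding

# Stand-ins for self.input_projection / self.output_projection (hypothetical layers).
in_proj = torch.nn.Linear(embed_dim, 3 * embed_dim)
out_proj = torch.nn.Linear(embed_dim, embed_dim)

q, k, v = in_proj(query).chunk(3, dim=-1)                             # each (T, B, E)
q = q * scaling

# Fold heads into the batch dimension: (B * H, T, head_dim).
batch_heads = batch_size * num_heads
q = q.contiguous().view(-1, batch_heads, head_dim).transpose(0, 1)
k = k.contiguous().view(-1, batch_heads, head_dim).transpose(0, 1)
v = v.contiguous().view(-1, batch_heads, head_dim).transpose(0, 1)

attn_weights = torch.bmm(q, k.transpose(1, 2))                        # (B * H, T, S)

# Padded key positions get -inf before softmax, so they receive zero attention.
attn_weights = attn_weights.view(batch_size, num_heads, target_length, -1)
attn_weights = attn_weights.masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"))
attn_weights = attn_weights.view(batch_heads, target_length, -1)
attn_weights = F.softmax(attn_weights, dim=-1)

attn = torch.bmm(attn_weights, v)                                     # (B * H, T, head_dim)
attn = attn.transpose(0, 1).contiguous().view(target_length, batch_size, embed_dim)
attn = out_proj(attn)
print(attn.shape)                                                     # torch.Size([5, 2, 16])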