def forward()

in torchtext/models/roberta/modules.py


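This excerpt is the multi-head self-attention forward pass used by the RoBERTa encoder in torchtext. It assumes module-level imports of torch, torch.nn.functional as F, and typing.Optional, and that the enclosing module's __init__ has set embed_dim, num_heads, head_dim, scaling, dropout, input_projection, and output_projection.
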
    def forward(self, query: torch.Tensor, key_padding_mask: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
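        # Expected shapes (sequence-first layout):
        #   query:            (target_length, batch_size, embed_dim)
        #   key_padding_mask: (batch_size, source_length), True at padded positions
        #   attn_mask:        optional (target_length, source_length), float (additive) or bool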
        target_length, batch_size, embed_dim = query.size()
        mask_batch_size, source_length = key_padding_mask.size()

        torch._assert(embed_dim == self.embed_dim, "query embed dim doesn't match")
        torch._assert(
            batch_size == mask_batch_size,
            "query and key_padding_mask batch sizes differed",
        )

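        # input_projection is a single fused linear layer producing q, k and v
        # (3 * embed_dim outputs); the query is pre-scaled (typically 1/sqrt(head_dim))
        # so the bmm below yields scaled dot-product scores.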
        projection = self.input_projection(query)
        q, k, v = projection.chunk(3, dim=-1)
        q = self.scaling * q

        batch_heads = batch_size * self.num_heads

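        # Fold the head dimension into the batch dimension:
        # (seq_length, batch_size, embed_dim) -> (batch_size * num_heads, seq_length, head_dim)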
        q = q.contiguous().view(-1, batch_heads, self.head_dim).transpose(0, 1)
        k = k.contiguous().view(-1, batch_heads, self.head_dim).transpose(0, 1)
        v = v.contiguous().view(-1, batch_heads, self.head_dim).transpose(0, 1)

        torch._assert(
            k.size(1) == source_length, "key size should be equal to source length"
        )

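        # Raw attention scores: (batch_heads, target_length, source_length).
        # A float attn_mask is added to the scores; a bool attn_mask is first
        # converted to an additive mask below.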
        attn_weights = torch.bmm(q, k.transpose(1, 2))
        if attn_mask is not None:
            torch._assert(attn_mask.dim() == 2, "Expected attn_mask of dim 2 but got {}".format(attn_mask.dim()))
            torch._assert(attn_mask.size(0) == target_length, "attn_mask shape didn't match for target length {}".format(target_length))
            torch._assert(attn_mask.size(1) == source_length, "attn_mask shape didn't match for source length {}".format(source_length))
            torch._assert(attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, f"Only float or bool types are supported for attn_mask not {attn_mask.dtype}")
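            # Convert a boolean mask into an additive float mask: masked positions
            # get a large negative value so they vanish after softmax.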
            if attn_mask.dtype == torch.bool:
                new_attn_mask = torch.zeros_like(attn_mask, dtype=query.dtype)
                new_attn_mask.masked_fill_(attn_mask, -1e8 if query.dtype == torch.float32 else -1e4)
                attn_mask = new_attn_mask
            attn_mask = attn_mask.unsqueeze(0)
            attn_weights += attn_mask

        torch._assert(attn_weights.dim() == 3, "Unexpected attn_weights dim")
        torch._assert(
            attn_weights.size(0) == batch_heads,
            "attn_weights shape didn't match for batch heads",
        )
        torch._assert(
            attn_weights.size(1) == target_length,
            "attn_weights shape didn't match for target length",
        )
        torch._assert(
            attn_weights.size(2) == source_length,
            "attn_weights shape didn't match for source length",
        )

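        # Expose the head dimension so key_padding_mask (batch_size, source_length)
        # broadcasts over heads and query positions, fill padded keys with -inf,
        # then flatten back to (batch_heads, target_length, source_length).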
        attn_weights = attn_weights.view(
            batch_size, self.num_heads, target_length, source_length
        )
        attn_weights = attn_weights.masked_fill(
            key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf")
        )
        attn_weights = attn_weights.view(batch_heads, target_length, source_length)

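        # Softmax in float32 for numerical stability, cast back to the working
        # dtype, then apply attention dropout.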
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(
            attn_weights
        )
        attn_weights = self.dropout(attn_weights)

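        # Weighted sum of values: (batch_heads, target_length, head_dim).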
        attn = torch.bmm(attn_weights, v)

        torch._assert(
            attn.dim() == 3,
            "unexpected attn dim size",
        )
        torch._assert(
            attn.size(0) == batch_heads,
            "attn shape didn't match for batch heads",
        )
        torch._assert(
            attn.size(1) == target_length,
            "attn shape didn't match for target length",
        )
        torch._assert(
            attn.size(2) == self.head_dim,
            "attn shape didn't match for head dim",
        )
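        # Merge heads back into (target_length, batch_size, embed_dim) and apply
        # the final output projection.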
        attn = (
            attn.transpose(0, 1)
            .contiguous()
            .view(target_length, batch_size, self.head_dim * self.num_heads)
        )
        attn = self.output_projection(attn)

        return attn
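
A minimal usage sketch for context (shapes follow the assertions in forward above). The constructor call below assumes the enclosing class is MultiheadSelfAttention with an (embed_dim, num_heads, ...) signature as in torchtext; check the exact signature in your torchtext version before relying on it.

    import torch

    from torchtext.models.roberta.modules import MultiheadSelfAttention

    embed_dim, num_heads = 768, 12               # head_dim = 64
    target_length, batch_size = 16, 4

    # Assumed constructor: MultiheadSelfAttention(embed_dim, num_heads, ...)
    mha = MultiheadSelfAttention(embed_dim, num_heads)

    # Sequence-first query: (target_length, batch_size, embed_dim)
    query = torch.randn(target_length, batch_size, embed_dim)

    # Boolean padding mask: (batch_size, source_length); True marks padded tokens
    key_padding_mask = torch.zeros(batch_size, target_length, dtype=torch.bool)
    key_padding_mask[:, -2:] = True

    attn = mha(query, key_padding_mask)          # (target_length, batch_size, embed_dim)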