def forward()

in models/transformer_seq.py [0:0]


    def forward(self, query, key, value):
        # query = B x M x H
        # key, value = B x (M+L) x H
        aux_loss = 0  # auxiliary-loss placeholder; returned unchanged by this forward pass

        key_pe, val_pe = self.key_pe, self.val_pe  # position embeddings for keys and values
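        # [optional] trim the cached memory (and its position embeddings) down to
        # the current maximum learned span to avoid unnecessary computation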
        if self.args.adapt_span:
            key, value, key_pe, val_pe = self.adaptive_span.trim_memory(
                key, value, key_pe, val_pe
            )

        attn = 0

        # compute attention from context
        attn = torch.matmul(
            query, key.transpose(-1, -2)
        )  # B x M (dest) x (M+L) (src)
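        # unskew realigns the scores from absolute source positions into
        # per-query relative windows, so each row covers its own L source positions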
        attn = unskew(attn)  # B x M x L

        # compute the effect of position embedding
        attn = attn + torch.matmul(query, key_pe)  # B x M x L

        attn = attn / math.sqrt(self.args.head_dim)  # B x M x L
        attn = F.softmax(attn, dim=-1)
        if self.args.adapt_span:
            # soft-mask attention weights that fall outside the learned span,
            # then renormalize so each row still sums to 1
            attn = self.adaptive_span(attn)
            attn = attn / (attn.sum(-1, keepdim=True) + 1e-8)
        attn = self.dropout(attn)  # B x M x L

        out = 0
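        # skew shifts the per-query relative windows back to absolute source
        # positions (zero-padded) so the weights line up with the (M+L) rows of value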
        attn_cont = skew(attn, 0)  # B x M x (L+M)
        out = out + torch.matmul(attn_cont, value)  # B x M x H

        return out, aux_loss
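
The skew/unskew helpers used above are defined elsewhere in the codebase and are not shown in this section. A minimal sketch of how such helpers are commonly implemented for this relative-position scheme (shapes follow the comments above; the actual definitions may differ):

    import torch.nn.functional as F

    def skew(X, pad_value):
        # Shift row i right by i positions so per-query relative windows line
        # up with absolute source positions: B x M x L -> B x M x (M+L).
        B, M, L = X.size()
        X = F.pad(X, (0, M + 1), value=pad_value)  # B x M x (L+M+1)
        X = X.view(B, -1)[:, :-M]                  # drop the surplus padding
        return X.view(B, M, M + L)

    def unskew(X):
        # Reverse of skew: realign absolute source positions back into
        # per-query relative windows: B x M x (M+L) -> B x M x L.
        B, M, ML = X.size()
        L = ML - M
        X = F.pad(X.view(B, -1), (0, M))           # B x M*(M+L+1)
        return X.view(B, M, M + L + 1)[:, :, :L]

Both helpers perform the realignment with pad/view operations rather than explicit gather indexing, which keeps the reshuffling cheap.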