def forward()

in intermediate_source/named_tensor_tutorial.py [0:0]

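Named-tensor forward pass for the tutorial's multi-head attention layer
(adapted in the tutorial from ParlAI). The excerpt assumes module-level
imports of math and torch.nn.functional as F, and that __init__ defines
self.dim, self.n_heads, the q_lin/k_lin/v_lin/out_lin projections, and
self.attn_dropout.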

    def forward(self, query, key=None, value=None, mask=None):
        # (I)
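        # refine_names checks that the trailing dims can carry these names and
        # returns the tensor with the names attached; '...' spans any leading
        # (e.g. batch) dims.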
        query = query.refine_names(..., 'T', 'D')
        self_attn = key is None and value is None
        # Validate the mask before touching it: refine_names on None would
        # otherwise raise an unhelpful AttributeError.
        assert mask is not None, 'Mask is None, please specify a mask'
        if self_attn:
            mask = mask.refine_names(..., 'T')
        else:
            mask = mask.refine_names(..., 'T', 'T_key')  # enc attn

        dim = query.size('D')
        assert dim == self.dim, \
            f'Dimensions do not match: {dim} query vs {self.dim} configured'
        n_heads = self.n_heads
        dim_per_head = dim // n_heads
        scale = math.sqrt(dim_per_head)  # standard scaled dot-product factor

        # (II)
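        # Split the embedding dim D into (H, D_head) and move H in front of T,
        # so downstream ops treat each attention head as its own dim.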
        def prepare_head(tensor):
            tensor = tensor.refine_names(..., 'T', 'D')
            return (tensor.unflatten('D', [('H', n_heads), ('D_head', dim_per_head)])
                          .align_to(..., 'H', 'T', 'D_head'))

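        # Note: a separately-provided value tensor is not supported below.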
        assert value is None
        if self_attn:
            key = value = query
        elif value is None:
            # key and value are the same, but query differs
            key = key.refine_names(..., 'T', 'D')
            value = key
        dim = key.size('D')  # also verifies that key carries a 'D' dim

        # Distinguish between query_len (T) and key_len (T_key) dims.
        k = prepare_head(self.k_lin(key)).rename(T='T_key')
        v = prepare_head(self.v_lin(value)).rename(T='T_key')
        q = prepare_head(self.q_lin(query))

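        # align_to acts as a named transpose: k becomes (..., 'D_head', 'T_key'),
        # so matmul contracts q and k over D_head.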
        dot_prod = q.div_(scale).matmul(k.align_to(..., 'D_head', 'T_key'))
        dot_prod.refine_names(..., 'H', 'T', 'T_key')  # assertion only: raises on a name mismatch

        # (III)
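        # align_as lines the mask's dims up with dot_prod by name, inserting
        # size-1 dims that broadcast inside masked_fill_.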
        attn_mask = (mask == 0).align_as(dot_prod)
        dot_prod.masked_fill_(attn_mask, -1e20)

        # q was already divided by scale in the dot product above; dividing
        # dot_prod by scale again here would scale the logits twice.
        attn_weights = self.attn_dropout(F.softmax(dot_prod, dim='T_key'))

        # (IV)
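        # Invert prepare_head: move T ahead of the head dims, then flatten
        # (H, D_head) back into a single D for the output projection.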
        attentioned = (
            attn_weights.matmul(v).refine_names(..., 'H', 'T', 'D_head')
            .align_to(..., 'T', 'H', 'D_head')
            .flatten(['H', 'D_head'], 'D')
        )

        return self.out_lin(attentioned).refine_names(..., 'T', 'D')
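
The dimension bookkeeping above round-trips cleanly: prepare_head splits D
into (H, D_head), and step (IV) flattens it back. A minimal sketch on a toy
tensor (the sizes and variable names here are illustrative, not from the file):

    import torch

    t = torch.randn(2, 3, 4, names=('N', 'T', 'D'))
    heads = (t.unflatten('D', [('H', 2), ('D_head', 2)])
              .align_to(..., 'H', 'T', 'D_head'))
    print(heads.names)  # ('N', 'H', 'T', 'D_head')

    flat = (heads.align_to(..., 'T', 'H', 'D_head')
                 .flatten(['H', 'D_head'], 'D'))
    print(flat.names)   # ('N', 'T', 'D')

And a usage sketch for the self-attention path, assuming the enclosing module
is constructed as MultiHeadAttention(n_heads, dim) as in the rest of the
tutorial; the sizes are made up:

    import torch

    n, t, d, h = 7, 5, 6, 3
    query = torch.randn(n, t, d, names=('N', 'T', 'D'))
    mask = torch.ones(n, t, names=('N', 'T'))  # 1 = attend, 0 = masked out
    attn = MultiHeadAttention(h, d)
    print(attn(query, mask=mask).names)  # ('N', 'T', 'D')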