def forward()

in low_rank_comparisons/src/model.py [0:0]


    def forward(self, w, r, attn_mask=None, mems=None):
        # run through the memory component

        # w: [qlen, bsz, dim]; r: relative position embeddings (may be None)
        qlen, bsz = w.size(0), w.size(1)
        rlen = r.size(0) if r is not None else 0

        if self.cross_attn:
            if self.pre_cnorm:
                mems = self.c_layer_norm(mems)

            # qlen, bsz, dim
            w_head_q = self.q_net(w)

            w_head_k, w_head_v = torch.chunk(self.kv_net(mems), 2, dim=-1)
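            # cross-attention: queries are projected from w, keys/values from `mems`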

            if r is not None:
                r_head_k = self.r_net(r)
                if self.rel_pos == 'full':
                    r_head_q = self.r_q_net(r)
        else:
            # self-attention: w is qlen x bsz x dim, optionally prepended with cached mems
            if mems is not None:
                if mems.dtype != w.dtype:
                    # cached memories may be stored in a different precision; match w
                    mems = mems.to(w.dtype)

                cat = torch.cat([mems, w], 0)
                if self.pre_lnorm:
                    w_heads = self.qkv_net(self.layer_norm(cat))
                else:
                    w_heads = self.qkv_net(cat)
                r_head_k = self.r_net(r)
                if self.rel_pos == 'full':
                    r_head_q = self.r_q_net(r)
                w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
                w_head_q = w_head_q[-qlen:]
            else:
                if self.pre_lnorm:
                    w_heads = self.qkv_net(self.layer_norm(w))
                else:
                    w_heads = self.qkv_net(w)
                # position embedding: k_len x 1 x dim
                r_head_k = self.r_net(r)

                if self.rel_pos == 'full':
                    r_head_q = self.r_q_net(r)
                # q, k, v
                w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
        klen = w_head_k.size(0)

        w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)           # qlen x bsz x n_head x d_head
        w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)           # klen x bsz x n_head x d_head
        w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)           # klen x bsz x n_head x d_head

        if r is not None:
            r_head_k = r_head_k.view(rlen, self.n_head, self.d_head)                # rlen x n_head x d_head
            if self.rel_pos == 'full':
                r_head_q = r_head_q.view(rlen, self.n_head, self.d_head)     

            assert rlen == klen or rlen == klen + klen - 1 or rlen == qlen + klen
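            # rlen may equal klen (standard relative shift), klen + klen - 1 (a full range
            # of bidirectional offsets), or qlen + klen, matching the shift variants below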

        #### compute attention score
        rw_head_q = w_head_q + self.r_w_bias                                         # qlen x bsz x n_head x d_head
        AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q.float(), w_head_k.float()))  # qlen x klen x bsz x n_head
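        # AC: content-based term ((a)+(c) in the Transformer-XL decomposition);
        # r_w_bias acts as the global content bias u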

        if r is not None:
            rr_head_q = w_head_q + self.r_r_bias                                         # qlen x bsz x n_head x d_head
            BD = torch.einsum('ibnd,jnd->ijbn', (rr_head_q.float(), r_head_k.float()))   # qlen x rlen x bsz x n_head
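            # BD: position-based term ((b)+(d)); r_r_bias acts as the global position
            # bias v, and the relative shift below realigns offsets to key indices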

            if rlen == klen:
                BD = self._rel_shift(BD)
            elif rlen == klen + klen - 1:
                assert klen == qlen
                BD = self._rel_shift_bias(BD, klen)
            elif rlen == qlen + klen:
                BD = self._rel_shift_bias(BD, klen)

            if self.rel_pos == 'full':
                # BC: position-to-content term                                             # rlen x klen x bsz x n_head
                BC = torch.einsum('ind,jbnd->ijbn', (r_head_q.float(), w_head_k.float()))

                if rlen == klen:
                    BC = self._rel_shift_trans(BC, qlen)
                    BC = self._rel_shift(BC)                                               # qlen x klen x bsz x n_head
                elif rlen == klen + klen - 1:
                    BC = self._rel_shift_trans(BC, qlen)
                    BC = self._rel_shift(BC)[:, :klen]
                elif rlen == klen + qlen:
                    BC = self._rel_shift_trans_bias(BC, qlen)

                attn_score = AC + BD + BC

            else:
                # [qlen x klen x bsz x n_head]
                attn_score = AC + BD
        else:
            attn_score = AC

        attn_score.mul_(self.scale)

        #### compute attention probability
        # attn_mask : qlen, klen, 1
        if attn_mask is not None and attn_mask.any().item():
            # attn_mask : klen, bsz (used for masked LM)
            if attn_mask.dim() == 2:
                attn_score = attn_score.float().masked_fill(
                    attn_mask[None,:,:,None].bool(), -float('inf')).type_as(attn_score)
            # attn_mask : qlen, klen, 1 (used for LM)
            elif attn_mask.dim() == 3:
                attn_score = attn_score.float().masked_fill(
                    attn_mask[:,:,:,None].bool(), -float('inf')).type_as(attn_score)

        # attn_score: [qlen x klen x bsz x n_head]; softmax normalizes over the key axis
        attn_prob = F.softmax(attn_score, dim=1)
        attn_prob = self.dropatt(attn_prob)

        #### compute attention vector
        # [qlen x bsz x n_head x d_head]        
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v.type_as(attn_prob))).type_as(w)

        # [qlen x bsz x n_head x d_head]
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        ##### linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        if self.pre_lnorm:
            ##### residual connection
            output = w + attn_out
        else:
            ##### residual connection + layer normalization
            output = self.layer_norm(w + attn_out)

        return output
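
For reference, the score computation above follows the Transformer-XL relative-position decomposition: a content term AC (queries offset by r_w_bias, matched against the keys) plus a position term BD (queries offset by r_r_bias, matched against the projected relative embeddings), realigned by a relative shift before scaling and softmax. The sketch below reproduces that computation for the simple rlen == klen case. It assumes _rel_shift matches the standard Transformer-XL shift, and it uses synthetic tensors and an illustrative rel_shift helper rather than the repo's modules.

    import torch
    import torch.nn.functional as F

    def rel_shift(x):
        # standard Transformer-XL relative shift; x: [qlen, rlen, bsz, n_head]
        zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]), dtype=x.dtype, device=x.device)
        x_padded = torch.cat([zero_pad, x], dim=1)                       # qlen x (rlen+1) x bsz x n_head
        x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])
        return x_padded[1:].view_as(x)                                   # qlen x rlen x bsz x n_head

    qlen = klen = rlen = 4
    bsz, n_head, d_head = 2, 3, 8
    scale = 1.0 / d_head ** 0.5

    # per-head projections of the input (w) and of the relative position embeddings (r)
    w_head_q = torch.randn(qlen, bsz, n_head, d_head)
    w_head_k = torch.randn(klen, bsz, n_head, d_head)
    w_head_v = torch.randn(klen, bsz, n_head, d_head)
    r_head_k = torch.randn(rlen, n_head, d_head)

    r_w_bias = torch.zeros(n_head, d_head)   # global content bias (u)
    r_r_bias = torch.zeros(n_head, d_head)   # global position bias (v)

    AC = torch.einsum('ibnd,jbnd->ijbn', w_head_q + r_w_bias, w_head_k)  # qlen x klen x bsz x n_head
    BD = torch.einsum('ibnd,jnd->ijbn', w_head_q + r_r_bias, r_head_k)   # qlen x rlen x bsz x n_head
    BD = rel_shift(BD)                                                   # align relative offsets to key positions

    attn_score = (AC + BD) * scale
    attn_prob = F.softmax(attn_score, dim=1)                             # normalize over the key axis
    attn_vec = torch.einsum('ijbn,jbnd->ibnd', attn_prob, w_head_v)
    print(attn_vec.shape)                                                # torch.Size([4, 2, 3, 8])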