def sparse_attention_2d_light()

in modules/SwissArmyTransformer/sat/model/official/cuda2d_model.py


# Module-level imports needed by this function (the exact sat import paths below are
# assumptions for a self-contained excerpt; sqrt is SAT's integer-sqrt helper):
import math
import torch
import torch.nn.functional as F
from sat.mpu.utils import sqrt                      # int(math.sqrt(x) + eps)
from sat.mpu import get_cuda_rng_tracker

def sparse_attention_2d_light(q0, k0, v0, q1, k1, v1, attention_mask, n_head, text_len, kernel_size=9, kernel_size2=7, attention_dropout=None, log_attention_weights=None, **kwargs):
    '''
    Two-level sparse attention used by Cuda2dModel.
    q0, k0, v0: [batch_size, 1088, hidden_size]   (level 0: text_len text tokens + a low-res image grid)
    q1, k1, v1: [batch_size, 4096, h2]            (level 1: a high-res image grid)
    n_head: int
    attention_mask: [batch_size, 1088, 1088]
    Returns (context0, context1) of shapes [batch_size, 1088, hidden_size] and
    [batch_size, 4096, hidden_size] (the level-1 reshapes below assume h2 == hidden_size).
    '''
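    # Two-level scheme: level 0 (text + low-res image tokens) runs standard dense
    # masked attention; level 1 (high-res image tokens) attends locally to itself
    # and cross-attends to its neighbourhood in the level-0 grid, via the
    # f_similar / f_weighting local-attention CUDA kernels imported below.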
    from sat.ops.local_attention_function import f_similar, f_weighting

    b, s0, h0 = q0.shape
    b, s1, h1 = q1.shape
    h, l0, l1 = h0 // n_head, sqrt(s0-text_len), sqrt(s1)  # per-head dim and grid side lengths (sqrt returns an int), e.g. l0=32, l1=64

    # split heads for level 0: q0, v0 -> [b, n_head, s0, h]; k0 is laid out pre-transposed as k0T -> [b, n_head, h, s0]
    q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
    v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
    k0T = k0.reshape(b, s0, n_head, h).permute(0, 2, 3, 1)
    
    # standard attention for level 0
    attention_scores = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T)
    
    if log_attention_weights is not None:
        attention_scores += log_attention_weights

    # apply the attention mask: keep allowed scores, push masked positions to roughly -10000 so softmax zeroes them out
    attention_scores = torch.mul(attention_scores, attention_mask) - \
                    10000.0 * (1.0 - attention_mask)
    
    attention_probs0 = F.softmax(attention_scores, dim=-1) # [b, n_head, s0, s0]
    
    # local attention for level 1
    # split heads for level 1 and reshape each head to its [l1, l1] spatial grid (queries pre-scaled by 1/sqrt(d))
    q1 = (q1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1) / math.sqrt(h1//n_head)).contiguous().view(b*n_head, h1//n_head, l1, l1)
    k1 = k1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1)
    v1 = v1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1)
    scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, True) # causal local window, [b*n_head, l1, l1, field]

    # cross attention
    k0T = k0T[..., -l0**2:].reshape(b*n_head, h, l0, l0).contiguous() # low-res image keys only, as [b*n_head, h, l0, l0]
    scores_1_to_0 = f_similar(q1, k0T, kernel_size2, kernel_size2, False) # [b*n_head, l1, l1, field]
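    # cross-resolution field: each high-res (level-1) query is scored against a
    # kernel_size2 x kernel_size2 patch of low-res (level-0) keys around its
    # corresponding (downsampled) position in the l0 x l0 grid, non-causally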
    scores_1 = torch.cat(
        (
            scores_1_to_0.view(b*n_head, -1, scores_1_to_0.shape[3]),
            scores_1_to_1.view(b*n_head, -1, scores_1_to_1.shape[3])
        ),
        dim=-1)
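    # one joint softmax over the concatenated fields: each level-1 query normalizes
    # its cross (1->0) patch scores and its local causal (1->1) window scores together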
    attention_probs1 = F.softmax(scores_1, dim=-1)

    if attention_dropout is not None:
        with get_cuda_rng_tracker().fork():
            attention_probs0 = attention_dropout(attention_probs0)
            attention_probs1 = attention_dropout(attention_probs1)
        
    # weighting for level 0
    context0 = torch.matmul(attention_probs0, v0) # [b, n_head, s0, h]
    # weighting for level 1
    probs_1_to_1 = attention_probs1[:, :, -scores_1_to_1.shape[3]:].view_as(scores_1_to_1) # local-window slice of the joint softmax
    context1_to_1 = f_weighting(v1, probs_1_to_1.contiguous(), kernel_size*2-1, kernel_size, True)
    context1 = context1_to_1.view(b, n_head * h, l1**2) # note: this view assumes h2 == hidden_size, i.e. h1 // n_head == h
    # weighting for cross attention
    probs_1_to_0 = attention_probs1[:, :, :scores_1_to_0.shape[3]].view_as(scores_1_to_0) # cross-level slice of the joint softmax
    v0_part = v0[:, :, -l0**2:].transpose(-1, -2).contiguous().view(b*n_head, h, l0, l0) # low-res image values as [b*n_head, h, l0, l0]
    context1_to_0 = f_weighting(v0_part, probs_1_to_0.contiguous(), kernel_size2, kernel_size2, False)
    context1_to_0 = context1_to_0.view(b, n_head * h, l1**2)
    context1 = context1 + context1_to_0
    # merge heads back: context0 -> [b, s0, hidden_size]; context1 is returned as [b, s1, n_head*h]
    return context0.transpose(1, 2).reshape(b, s0, h0), context1.transpose(-1, -2)
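
A minimal call sketch for shape bookkeeping, under assumptions not stated in the file: the hidden size, head count, and text_len below are illustrative rather than the model's real config, the all-ones mask is a placeholder for the model's actual attention mask, and the f_similar / f_weighting kernels are CUDA ops, so the tensors are created on a GPU with fp16 assumed.

# hypothetical smoke test (not from the source file); requires a GPU with the sat local-attention kernels available
b, n_head, hidden, text_len = 2, 16, 1024, 64
s0, s1 = text_len + 32 * 32, 64 * 64                        # 1088 and 4096, matching the docstring
dtype, dev = torch.float16, 'cuda'                          # fp16/CUDA assumed for the local-attention ops
q0, k0, v0 = (torch.randn(b, s0, hidden, dtype=dtype, device=dev) for _ in range(3))
q1, k1, v1 = (torch.randn(b, s1, hidden, dtype=dtype, device=dev) for _ in range(3))
mask = torch.ones(1, 1, s0, s0, dtype=dtype, device=dev)    # placeholder mask; broadcasts over batch and heads

context0, context1 = sparse_attention_2d_light(
    q0, k0, v0, q1, k1, v1, mask,
    n_head=n_head, text_len=text_len,
)
# context0: [b, 1088, hidden]; context1: [b, 4096, hidden]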