def multi_head_attention_forward()

in nestedtensor/nn/mha.py


import torch

import nestedtensor


def multi_head_attention_forward(query,
                                 key,
                                 value,
                                 embed_dim_to_check,
                                 num_heads,
                                 in_proj_weight,
                                 in_proj_bias,
                                 bias_k,
                                 bias_v,
                                 add_zero_attn,
                                 dropout_p,
                                 out_proj_weight,
                                 out_proj_bias,
                                 training=True,
                                 key_padding_mask=None,
                                 need_weights=True,
                                 attn_mask=None,
                                 use_separate_proj_weight=False,
                                 q_proj_weight=None,
                                 k_proj_weight=None,
                                 v_proj_weight=None,
                                 static_k=None,
                                 static_v=None
                                 ):
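    # Query, key and value must already be NestedTensors; the projection weights and
    # biases are plain torch Tensors.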
    assert isinstance(query, nestedtensor.NestedTensor)
    assert isinstance(key, nestedtensor.NestedTensor)
    assert isinstance(value, nestedtensor.NestedTensor)
    assert torch.is_tensor(out_proj_weight)
    assert torch.is_tensor(out_proj_bias)

    # TODO: Explicitly unsupported flags
    assert not use_separate_proj_weight
    assert attn_mask is None
    assert key_padding_mask is None
    assert bias_k is None
    assert bias_v is None
    assert static_k is None
    assert static_v is None
    assert not add_zero_attn
    # need_weights is accepted but ignored: both return paths below return None in place
    # of the attention weights.
    # assert not need_weights

    # query is (bsz, tgt_len, embed_dim); with NestedTensor inputs the sequence length can
    # vary per example.
    bsz, tgt_len, embed_dim = query.size()
    assert embed_dim == embed_dim_to_check
    # allow MHA to have different sizes for the feature dimension
    assert key.size(0) == value.size(0) and key.size(1) == value.size(1)

    head_dim = embed_dim // num_heads
    assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
    scaling = float(head_dim) ** -0.5

    # Fast path: when query, key and value are the very same NestedTensor (self-attention)
    # and the projection weights live on CUDA, dispatch to the fused bt_min_mha kernel.
    # Note that the 0.5/False arguments are hard-coded here, unlike the dropout_p/training
    # passed to min_mha below, and the result is returned without _wrap_result.
    if query is key and key is value and in_proj_weight.is_cuda:
        return torch.ops.nestedtensor.bt_min_mha(num_heads,
                                                 head_dim,
                                                 0.5,
                                                 False,
                                                 query,
                                                 query,
                                                 query,
                                                 in_proj_weight,
                                                 in_proj_bias,
                                                 scaling,
                                                 out_proj_weight,
                                                 in_proj_bias), None

    # General path: run min_mha on the underlying tensor impls and wrap the result back
    # into a NestedTensor. Attention weights are never returned, hence the trailing None.
    return nestedtensor.nested.nested._wrap_result(
        torch.ops.nestedtensor.min_mha(num_heads,
                                       head_dim,
                                       dropout_p,
                                       training,
                                       query._impl,
                                       key._impl,
                                       value._impl,
                                       in_proj_weight,
                                       in_proj_bias,
                                       scaling,
                                       out_proj_weight,
                                       out_proj_bias)), None