def benchmark_mha_block()

in benchmark/mha_block.py [0:0]

# Imports needed to run this excerpt (MultiheadAttentionContainer and friends live in torchtext.nn)
import time

import torch
from torch.nn.functional import multi_head_attention_forward as mha_forward
from torchtext.nn import InProjContainer, MultiheadAttentionContainer, ScaledDotProduct


def benchmark_mha_block():

    def _run_benchmark(embed_dim, nhead, bsz, device, tgt_len, src_len=None):
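        # tgt_len is the query length. When src_len is None the query tensor doubles as
        # key/value (self-attention); otherwise a separate key/value batch of length
        # src_len is generated.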
        # Build torchtext MultiheadAttention module
        in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim),
                                            torch.nn.Linear(embed_dim, embed_dim),
                                            torch.nn.Linear(embed_dim, embed_dim))
        MHA = MultiheadAttentionContainer(nhead, in_proj_container,
                                          ScaledDotProduct(),
                                          torch.nn.Linear(embed_dim, embed_dim)).to(device)

        query = torch.rand((tgt_len, bsz, embed_dim)).to(device)
        if src_len is None:
            key = value = query
            src_len = tgt_len
        else:
            key = value = torch.rand((src_len, bsz, embed_dim)).to(device)
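        # ScaledDotProduct expects a 3D boolean mask of shape (bsz * nhead, tgt_len, src_len),
        # with True marking positions that are masked out.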
        attn_mask_2D = torch.randint(0, 2, (tgt_len, src_len)).to(torch.bool).to(device)
        attn_mask = torch.stack([attn_mask_2D] * (bsz * nhead))
        bias_k = bias_v = torch.rand((1, 1, embed_dim)).to(device)
        print("starting torchtext.modules.MultiheadAttentionContainer")
        if device == torch.device("cuda"):
            torch.cuda.synchronize()
        t0 = time.monotonic()
        for _ in range(100):
            mha_output, attn_weights = MHA(query, key, value,
                                           attn_mask=attn_mask,
                                           bias_k=bias_k.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1),
                                           bias_v=bias_v.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1))
        if device == torch.device("cuda"):
            torch.cuda.synchronize()
        print(time.monotonic() - t0)

        # Use torch.nn.functional.multi_head_attention_forward
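        # multi_head_attention_forward adds a float mask to the attention scores, so the
        # boolean mask is converted to an additive mask with -inf at the masked positions.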
        torch_attn_mask = torch.zeros((tgt_len, src_len)).to(device).masked_fill_(attn_mask_2D, float('-inf'))
        print("starting torch.nn.functional.multi_head_attention_forward")
        in_proj_weight = torch.cat([MHA.in_proj_container.query_proj.weight,
                                    MHA.in_proj_container.key_proj.weight,
                                    MHA.in_proj_container.value_proj.weight])
        if device == torch.device("cuda"):
            torch.cuda.synchronize()
        t0 = time.monotonic()
        for _ in range(100):
            torch_mha_output, torch_mha_weights = mha_forward(query, key, value,
                                                              embed_dim, nhead,
                                                              in_proj_weight, None,
                                                              bias_k, bias_v,
                                                              False, 0.0,
                                                              MHA.out_proj.weight,
                                                              MHA.out_proj.bias,
                                                              attn_mask=torch_attn_mask)
        if device == torch.device("cuda"):
            torch.cuda.synchronize()
        print(time.monotonic() - t0)

    # GPU test
    device = torch.device("cuda")
    for embed_dim in [64, 768]:
        for nhead in [2, 16]:
            for seq_len in [10, 128, 1000]:
                for bsz in [2, 72]:
                    if seq_len == 1000 and bsz == 72:
                        continue
                    print("*" * 80)
                    print("test case GPU with embed_dim, nhead, seq_len, bsz:",
                          embed_dim, nhead, seq_len, seq_len, bsz)
                    _run_benchmark(embed_dim, nhead, bsz, device, seq_len, seq_len)

    # GPU test for self-attention
    device = torch.device("cuda")
    for embed_dim in [64, 256]:
        for nhead in [2, 16]:
            for seq_len in [10, 128, 1000]:
                for bsz in [2, 72]:
                    if seq_len == 1000 and bsz == 72:
                        continue
                    print("*" * 80)
                    print("self-attention test case GPU with embed_dim, nhead, seq_len, bsz:",
                          embed_dim, nhead, seq_len, seq_len, bsz)
                    _run_benchmark(embed_dim, nhead, bsz, device, seq_len, None)

    # CPU test for self-attention
    device = torch.device("cpu")
    for embed_dim in [64, 768]:
        for nhead in [2, 16]:
            for seq_len in [10, 128, 1000]:
                for bsz in [2, 72]:
                    if seq_len == 1000 and bsz == 72:
                        continue
                    print("*" * 80)
                    print("test case CPU with embed_dim, nhead, seq_len, bsz:",
                          embed_dim, nhead, seq_len, seq_len, bsz)
                    _run_benchmark(embed_dim, nhead, bsz, device, seq_len, None)
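

# A minimal entry-point sketch (not part of the excerpt above), assuming the benchmark
# is meant to be run directly as a script:
if __name__ == "__main__":
    benchmark_mha_block()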