def test()

in fastchat/train/llama2_flash_attn_monkey_patch.py [0:0]


def test():
    """Manually compare this module's attention ``forward`` against two references.

    Two checks are run on a small randomly initialised ``LlamaAttention``
    layer in fp16 on a CUDA device:

    1. For several padding patterns, the output of ``forward`` (with a mask
       built by ``_prepare_decoder_attention_mask``) is compared with the
       layer's own ``attn.forward`` and with the ``forward`` imported from
       ``fastchat.train.llama_flash_attn_monkey_patch``.
    2. A full-sequence call to ``forward`` is compared with the same sequence
       fed in equal chunks while passing ``past_key_value`` along.

    Results are only printed; nothing is asserted apart from the chunk-size
    sanity check. The whole comparison runs under ``torch.no_grad()`` because
    no gradients are ever used.

    NOTE(review): ``torch``, ``LlamaModel``, ``LlamaAttention``, ``forward`` and
    ``_prepare_decoder_attention_mask`` are expected to be provided at module
    level by the surrounding file.
    """
    from fastchat.train.llama_flash_attn_monkey_patch import forward as fastchat_forward
    from transformers.models.llama.configuration_llama import LlamaConfig

    config = LlamaConfig(
        hidden_size=1024,
        intermediate_size=128,
        num_hidden_layers=1,
        num_attention_heads=8,
        max_position_embeddings=16,
    )
    device = torch.device("cuda")
    # `model` is only used as the receiver of the mask-preparation helpers;
    # the layer under test is the stand-alone `attn` module.
    model = LlamaModel(config)
    attn = LlamaAttention(config).to(device).half()
    bsz, hs, seqlen = 2, config.hidden_size, config.max_position_embeddings
    position_ids = torch.arange(seqlen, dtype=torch.long, device=device).view(
        -1, seqlen
    )

    # Number of mask patterns tried in part 1 and of chunks used in part 2.
    num_parts = 4

    # Nothing below back-propagates, so skip autograd bookkeeping throughout.
    with torch.no_grad():
        mask = torch.full((bsz, seqlen), True, dtype=torch.bool, device=device)
        for i in range(num_parts):
            hidden = torch.rand((bsz, seqlen, hs), dtype=torch.float16, device=device)
            if i:
                # Row 0 gets its last `i` positions masked out, row 1 its
                # first `i`. The mask tensor is reused across iterations, so
                # the masked region only grows.
                mask[0, -i:] = False
                mask[1, :i] = False

            # Reference: the layer's own forward with the mask produced by the
            # model's `_prepare_decoder_attention_mask` method.
            lmask = model._prepare_decoder_attention_mask(
                mask, hidden.shape[:2], hidden, 0
            )
            ref, _, _ = attn.forward(
                hidden, attention_mask=lmask, position_ids=position_ids
            )

            # Second reference: the other monkey-patch forward, which is given
            # the raw boolean mask.
            fast, _, _ = fastchat_forward(
                attn, hidden, attention_mask=mask, position_ids=position_ids
            )

            # Implementation under test, with the mask built by this module's
            # own `_prepare_decoder_attention_mask`.
            lmask = _prepare_decoder_attention_mask(
                model, mask, hidden.shape[:2], hidden, 0
            )
            test, _, _ = forward(
                attn, hidden, attention_mask=lmask, position_ids=position_ids
            )

            print(f"Mean(abs(ref)) = {torch.mean(torch.abs(ref))}")
            print(f"Mean(abs(ref - fast)) = {torch.mean(torch.abs(ref - fast))}")
            print(f"Mean(abs(ref - test)) = {torch.mean(torch.abs(ref - test))}")
            print(f"Mean(abs(fast - test)) = {torch.mean(torch.abs(fast - test))}")
            print(f"allclose(fast, test) = {torch.allclose(fast, test)}")

        # Also check that past_kv is handled properly
        hidden = torch.rand((bsz, seqlen, hs), dtype=torch.float16, device=device)
        part_len = seqlen // num_parts
        assert part_len * num_parts == seqlen
        mask = torch.full((bsz, seqlen), True, dtype=torch.bool, device=device)
        mask[0, -2:] = False
        lmask = _prepare_decoder_attention_mask(
            model, mask, hidden.shape[:2], hidden, 0
        )
        oneshot, _, _ = forward(
            attn, hidden, attention_mask=lmask, position_ids=position_ids
        )
        parts = []
        past_kv, past_kv_len = None, 0
        for i in range(num_parts):
            start = part_len * i
            end = start + part_len
            hidden_part = hidden[:, start:end, ...]
            lmask = _prepare_decoder_attention_mask(
                model,
                mask[:, start:end],
                hidden_part.shape[:2],
                hidden_part,
                past_kv_len,
            )
            part, _, past_kv = forward(
                attn,
                hidden_part.clone(),
                attention_mask=lmask,
                position_ids=position_ids[:, start:end],
                past_key_value=past_kv,
                use_cache=True,
            )
            parts.append(part)
            # Size of dim 2 of the first cached tensor, passed to the next
            # chunk's mask preparation as the past length.
            past_kv_len = past_kv[0].shape[2]

        # The printed label says `oneshot[:, 0]`, but the comparison covers the
        # whole first chunk (`oneshot[:, :part_len]`).
        print(
            f"allclose(oneshot[:, 0], parts[0]) = {torch.allclose(oneshot[:, :part_len], parts[0])}"
        )
        print(
            f"allclose(oneshot, parts) = {torch.allclose(oneshot, torch.cat(parts, dim=1))}"
        )