in fastchat/train/llama2_flash_attn_monkey_patch.py [0:0]
def test():
    """Print a numerical comparison of three attention forward paths on CUDA.

    A one-layer ``LlamaConfig`` (hidden size 1024, 8 heads, 16 positions) is
    built and a half-precision ``LlamaAttention`` is placed on the GPU.  Two
    experiments are then run; results are printed and nothing is returned.

    1. For four boolean masks with progressively more positions switched off,
       random hidden states go through ``attn.forward`` (the reference),
       through the ``forward`` imported from
       ``fastchat.train.llama_flash_attn_monkey_patch`` and through the
       module-level ``forward``.  Mean absolute values/differences and an
       ``allclose`` verdict are printed for each mask.
    2. Under ``torch.no_grad()``, the module-level ``forward`` is evaluated
       once on the full sequence and once in four equal chunks that thread
       ``past_key_value`` from one chunk to the next; the outputs are compared
       with ``torch.allclose``.

    The only assertion is that the sequence length splits evenly into four
    chunks.  Requires a CUDA device plus the ``fastchat`` and ``transformers``
    packages.
    """
    from fastchat.train.llama_flash_attn_monkey_patch import forward as fastchat_forward
    from transformers.models.llama.configuration_llama import LlamaConfig

    config = LlamaConfig(
        hidden_size=1024,
        intermediate_size=128,
        num_hidden_layers=1,
        num_attention_heads=8,
        max_position_embeddings=16,
    )
    device = torch.device("cuda")
    # The full model is only used as the owner of the mask-preparation
    # routines below; it is never moved to the GPU or run.
    decoder = LlamaModel(config)
    attention = LlamaAttention(config).to(device).half()

    batch_size = 2
    hidden_dim = config.hidden_size
    seq_len = config.max_position_embeddings
    positions = torch.arange(seq_len, dtype=torch.long, device=device).view(
        -1, seq_len
    )

    def random_hidden():
        # Fresh fp16 activations of shape (batch, sequence, hidden) on the GPU.
        return torch.rand(
            (batch_size, seq_len, hidden_dim), dtype=torch.float16, device=device
        )

    def all_true_mask():
        # Boolean (batch, sequence) mask with every position enabled.
        return torch.full(
            (batch_size, seq_len), True, dtype=torch.bool, device=device
        )

    # --- Experiment 1: reference vs. fastchat forward vs. module forward ---
    padding_mask = all_true_mask()
    for pad in range(4):
        hidden = random_hidden()
        if pad != 0:
            # Row 0 loses its last `pad` positions, row 1 its first `pad`.
            # The mask is reused across iterations, so the cleared region
            # only ever grows.
            padding_mask[0, -pad:] = False
            padding_mask[1, :pad] = False
        batch_and_seq = hidden.shape[:2]

        reference_mask = decoder._prepare_decoder_attention_mask(
            padding_mask, batch_and_seq, hidden, 0
        )
        ref, _, _ = attention.forward(
            hidden, attention_mask=reference_mask, position_ids=positions
        )

        # The fastchat variant receives the raw boolean mask directly.
        fast, _, _ = fastchat_forward(
            attention, hidden, attention_mask=padding_mask, position_ids=positions
        )

        patched_mask = _prepare_decoder_attention_mask(
            decoder, padding_mask, batch_and_seq, hidden, 0
        )
        test, _, _ = forward(
            attention, hidden, attention_mask=patched_mask, position_ids=positions
        )

        print(f"Mean(abs(ref)) = {torch.mean(torch.abs(ref))}")
        print(f"Mean(abs(ref - fast)) = {torch.mean(torch.abs(ref - fast))}")
        print(f"Mean(abs(ref - test)) = {torch.mean(torch.abs(ref - test))}")
        print(f"Mean(abs(fast - test)) = {torch.mean(torch.abs(fast - test))}")
        print(f"allclose(fast, test) = {torch.allclose(fast, test)}")

    # --- Experiment 2: chunked evaluation with a key/value cache ---
    with torch.no_grad():
        hidden = random_hidden()
        part_len = seq_len // 4
        assert part_len * 4 == seq_len
        padding_mask = all_true_mask()
        padding_mask[0, -2:] = False

        full_mask = _prepare_decoder_attention_mask(
            decoder, padding_mask, hidden.shape[:2], hidden, 0
        )
        oneshot, _, _ = forward(
            attention, hidden, attention_mask=full_mask, position_ids=positions
        )

        parts = []
        past_kv = None
        past_kv_len = 0
        for chunk_idx in range(4):
            window = slice(part_len * chunk_idx, part_len * (chunk_idx + 1))
            hidden_chunk = hidden[:, window, ...]
            chunk_mask = _prepare_decoder_attention_mask(
                decoder,
                padding_mask[:, window],
                hidden_chunk.shape[:2],
                hidden_chunk,
                past_kv_len,
            )
            # A clone is passed so the call gets its own copy of the slice.
            part, _, past_kv = forward(
                attention,
                hidden_chunk.clone(),
                attention_mask=chunk_mask,
                position_ids=positions[:, window],
                past_key_value=past_kv,
                use_cache=True,
            )
            parts.append(part)
            # Dimension 2 of the first cache entry is fed back as the
            # past-length argument for the next chunk's mask.
            past_kv_len = past_kv[0].shape[2]

        print(
            f"allclose(oneshot[:, 0], parts[0]) = {torch.allclose(oneshot[:, :part_len], parts[0])}"
        )
        print(
            f"allclose(oneshot, parts) = {torch.allclose(oneshot, torch.cat(parts, dim=1))}"
        )