# models/spatial/attention.py
# NOTE: L2MultiheadAttention is assumed to be defined earlier in this module.
import torch


def test_attn_mask():
    torch.set_default_dtype(torch.float64)
    T, N, D = 8, 1, 20

    # Causal mask: large negative entries above the diagonal block attention to future timesteps.
    attn_mask = torch.triu(torch.ones(T, T), diagonal=1) * -1e12

    x = torch.randn(T * N * D).requires_grad_(True)
    mha = L2MultiheadAttention(D, 1)

    # Forward passes with and without rm_nonself_grads; the printed norm measures how much
    # the flag changes the output itself (expected to be negligible if it only alters gradient flow).
    y = mha(x.reshape(T, N, D), attn_mask=attn_mask)
    yhat = mha(x.reshape(T, N, D), attn_mask=attn_mask, rm_nonself_grads=True)
    print(torch.norm(y - yhat))
    # Construct the full Jacobian of the flattened output w.r.t. the flattened input.
    def func(x):
        return mha(x.reshape(T, N, D), attn_mask=attn_mask).reshape(-1)

    jac = torch.autograd.functional.jacobian(func, x)

    # Exact block-diagonal of the Jacobian: keep only each timestep's derivative
    # w.r.t. its own input (N = 1, so the batch dimension folds away in the reshape).
    jac = jac.reshape(T, D, T, D)
    blocks = [jac[i, :, i, :] for i in range(T)]
    jac_block_diag = torch.block_diag(*blocks)
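
    # A vectorized extraction of the same diagonal blocks is sketched below for
    # reference (hypothetical alternative, not part of the original test):
    #   idx = torch.arange(T)
    #   blocks_vec = jac[idx, :, idx, :]   # shape (T, D, D)
    #   torch.testing.assert_close(torch.block_diag(*blocks_vec), jac_block_diag)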
    # Simulated block-diagonal of the Jacobian: with rm_nonself_grads=True, each timestep's
    # gradient should only flow through its own input, so the resulting Jacobian is compared
    # against the exact block-diagonal constructed above.
    def selfonly_func(x):
        return mha(x.reshape(T, N, D), attn_mask=attn_mask, rm_nonself_grads=True).reshape(-1)

    simulated_jac_block_diag = torch.autograd.functional.jacobian(selfonly_func, x)
    print(torch.norm(simulated_jac_block_diag - jac_block_diag))
    # Visualize the exact blocks, the simulated blocks, and their absolute difference.
    import matplotlib.pyplot as plt

    fig, axs = plt.subplots(1, 3)
    axs[0].imshow(jac_block_diag)
    axs[1].imshow(simulated_jac_block_diag)
    axs[2].imshow(torch.abs(simulated_jac_block_diag - jac_block_diag))
    plt.savefig("jacobian.png")
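

# Minimal sketch of a direct invocation, assuming this module is run as a script
# (how the original repo invokes the test is not shown here):
if __name__ == "__main__":
    test_attn_mask()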