def test_attn_mask()

in models/spatial/attention.py


import torch  # assumed to be imported at module level in attention.py


def test_attn_mask():
    # Use float64 so the Jacobian comparison below is not swamped by float32 round-off.
    torch.set_default_dtype(torch.float64)

    T, N, D = 8, 1, 20

    # Causal (upper-triangular) additive mask: large negative values above the
    # diagonal suppress attention to future timesteps.
    attn_mask = torch.triu(torch.ones(T, T), diagonal=1) * -1e12

    x = torch.randn(T * N * D).requires_grad_(True)
    mha = L2MultiheadAttention(D, 1)  # assumed to be defined in this module

    # The forward outputs should agree; rm_nonself_grads is only expected to
    # change the gradients, so this norm should be ~0.
    y = mha(x.reshape(T, N, D), attn_mask=attn_mask)
    yhat = mha(x.reshape(T, N, D), attn_mask=attn_mask, rm_nonself_grads=True)
    print(torch.norm(y - yhat))

    # Construct the full Jacobian of the standard forward pass.
    def func(x):
        return mha(x.reshape(T, N, D), attn_mask=attn_mask).reshape(-1)

    jac = torch.autograd.functional.jacobian(func, x)

    # Exact block-diagonal part of the Jacobian: only the derivatives of each
    # timestep's output with respect to its own input (N = 1).
    jac = jac.reshape(T, D, T, D)
    blocks = [jac[i, :, i, :] for i in range(T)]
    jac_block_diag = torch.block_diag(*blocks)

    # "Simulated" block-diagonal Jacobian: the Jacobian of the rm_nonself_grads
    # forward pass, which should reproduce the exact block-diagonal above.
    def selfonly_func(x):
        return mha(x.reshape(T, N, D), attn_mask=attn_mask, rm_nonself_grads=True).reshape(-1)

    simulated_jac_block_diag = torch.autograd.functional.jacobian(selfonly_func, x)

    # This norm should be numerically zero if rm_nonself_grads works as intended.
    print(torch.norm(simulated_jac_block_diag - jac_block_diag))

    # Visualize both Jacobians and their absolute difference side by side.
    import matplotlib.pyplot as plt

    fig, axs = plt.subplots(1, 3)
    axs[0].imshow(jac_block_diag)
    axs[1].imshow(simulated_jac_block_diag)
    axs[2].imshow(torch.abs(simulated_jac_block_diag - jac_block_diag))
    plt.savefig("jacobian.png")
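

A minimal, hypothetical driver is sketched below; it is not part of the original file, and the invocation in the comment is an assumption about how the package is laid out. It simply calls the test, relying on the two printed norms being close to zero and on the comparison plot being written to jacobian.png.

# Hypothetical driver (assumption, not in the original file): run the check
# directly, e.g. `python models/spatial/attention.py`, or collect it with pytest.
# Both printed norms should be close to zero (up to float64 round-off), and the
# side-by-side Jacobian comparison is saved to jacobian.png.
if __name__ == "__main__":
    test_attn_mask()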