def test_einsum_op()

in models/spatial/attention.py [0:0]


def test_einsum_op():
    """Check einsum-based pairwise attention scores against a loop reference.

    For query/key tensors of shape ``(T, N, H, D)`` this computes, for every
    pair of positions ``(t, s)`` and every ``(n, h)`` index:

    * the dot product ``<q[t], k[s]>``, and
    * the squared Euclidean distance ``||q[t] - k[s]||^2``, both directly and
      via the expansion ``<q, q> + <k, k> - 2 <q, k>``,

    first with explicit Python loops and then with ``torch.einsum``.  The
    results are printed for inspection and asserted to agree within a
    float32 tolerance.

    Raises:
        AssertionError: if an einsum result disagrees with its loop reference.
    """

    def dot_product_fn(x, y):
        # Inner product of two 1-D vectors.
        return torch.sum(x * y)

    def squared_dist_fn(x, y):
        # Squared Euclidean distance, computed directly from the difference.
        return torch.norm(x - y) ** 2

    def squared_dist_fn2(x, y):
        # Same quantity via the expansion ||x||^2 + ||y||^2 - 2 <x, y>.
        return dot_product_fn(x, x) + dot_product_fn(y, y) - 2 * dot_product_fn(x, y)

    T, N, H, D = 4, 1, 1, 10

    # Local seeded generator: reproducible inputs without touching the
    # global RNG state that other tests may rely on.
    gen = torch.Generator().manual_seed(0)
    q = torch.randn(T, N, H, D, generator=gen)
    k = torch.randn(T, N, H, D, generator=gen)

    squared_dist = torch.zeros(T, T, N, H)
    squared_dist2 = torch.zeros(T, T, N, H)
    dot_product = torch.zeros(T, T, N, H)
    for t in range(T):
        for s in range(T):
            for n in range(N):
                for h in range(H):
                    # Index all four axes; writing to [t, s] alone would
                    # broadcast one scalar over the whole (N, H) slice and
                    # give wrong references as soon as N > 1 or H > 1.
                    squared_dist[t, s, n, h] = squared_dist_fn(q[t, n, h], k[s, n, h])
                    dot_product[t, s, n, h] = dot_product_fn(q[t, n, h], k[s, n, h])
                    squared_dist2[t, s, n, h] = squared_dist_fn2(q[t, n, h], k[s, n, h])

    # Vectorised equivalents, all of shape (T, T, N, H).
    einsum_dotproduct = torch.einsum('tbhd,sbhd->tsbh', q, k)
    q_sq_norm = torch.einsum('tbhd,tbhd->tbh', q, q).unsqueeze(1)  # (T, 1, N, H)
    k_sq_norm = torch.einsum('sbhd,sbhd->sbh', k, k).unsqueeze(0)  # (1, T, N, H)
    einsum_sqdist = q_sq_norm + k_sq_norm - 2 * einsum_dotproduct

    # The reshape to (T, T) is only valid because N == H == 1 here.
    print("squared dist", squared_dist.reshape(T, T))
    print("squared dist 2", squared_dist2.reshape(T, T))
    print("einsum squared dist", einsum_sqdist.reshape(T, T))

    print("dot product", dot_product.reshape(T, T))
    print("einsum dot product", einsum_dotproduct.reshape(T, T))

    # Tolerances allow for float32 rounding in the expanded-distance formula.
    assert torch.allclose(squared_dist, squared_dist2, rtol=1e-4, atol=1e-4)
    assert torch.allclose(squared_dist, einsum_sqdist, rtol=1e-4, atol=1e-4)
    assert torch.allclose(dot_product, einsum_dotproduct, rtol=1e-4, atol=1e-4)