# in models/spatial/attention.py [0:0]
def test_einsum_op():
    """Cross-check einsum pairwise ops against explicit per-pair references.

    Draws two random tensors ``q`` and ``k`` of shape ``(T, N, H, D)`` and, for
    every index combination ``(t, s, n, h)``, computes between the vectors
    ``q[t, n, h]`` and ``k[s, n, h]``:

    * the dot product,
    * the squared Euclidean distance via ``torch.norm``,
    * the squared Euclidean distance via ``|x|^2 + |y|^2 - 2 * x.y``.

    The same quantities are then produced for all pairs at once with
    ``torch.einsum``. All five ``(T, T)`` tables are printed, and the einsum
    results are compared against the loop references.

    Returns:
        None.

    Raises:
        AssertionError: if an einsum result differs from its loop reference
            beyond float tolerance.
    """
    import itertools

    def dot_product_fn(x, y):
        # Inner product of two 1-D vectors.
        return torch.sum(x * y)

    def squared_dist_fn(x, y):
        # Direct definition: squared L2 norm of the difference.
        return torch.norm(x - y)**2

    def squared_dist_fn2(x, y):
        # Expanded form |x|^2 + |y|^2 - 2 x.y; this is what the einsum mirrors.
        return dot_product_fn(x, x) + dot_product_fn(y, y) - 2 * dot_product_fn(x, y)

    def check_close(label, actual, expected):
        # Explicit raise (not a bare assert) so the check survives ``python -O``.
        if not torch.allclose(actual, expected, rtol=1e-4, atol=1e-4):
            raise AssertionError(f"{label}: einsum result differs from loop reference")

    # N and H must stay 1 for the ``reshape(T, T)`` calls in the prints below.
    T, N, H, D = 4, 1, 1, 10
    q = torch.randn(T, N, H, D)
    k = torch.randn(T, N, H, D)

    squared_dist = torch.zeros(T, T, N, H)
    squared_dist2 = torch.zeros(T, T, N, H)
    dot_product = torch.zeros(T, T, N, H)

    # Reference values, one vector pair at a time. Index all four axes so each
    # (n, h) entry gets its own value instead of overwriting the whole slice.
    for t, s, n, h in itertools.product(range(T), range(T), range(N), range(H)):
        x, y = q[t, n, h], k[s, n, h]
        squared_dist[t, s, n, h] = squared_dist_fn(x, y)
        dot_product[t, s, n, h] = dot_product_fn(x, y)
        squared_dist2[t, s, n, h] = squared_dist_fn2(x, y)

    # Vectorized equivalents. The squared norms are (T, N, H); unsqueezing
    # makes them broadcast along the row (t) and column (s) axes respectively.
    q_sq_norm = torch.einsum('tbhd,tbhd->tbh', q, q).unsqueeze(1)
    k_sq_norm = torch.einsum('sbhd,sbhd->sbh', k, k).unsqueeze(0)
    einsum_dotproduct = torch.einsum('tbhd,sbhd->tsbh', q, k)
    einsum_sqdist = q_sq_norm + k_sq_norm - 2 * torch.einsum('tbhd,sbhd->tsbh', q, k)

    print("squared dist", squared_dist.reshape(T, T))
    print("squared dist 2", squared_dist2.reshape(T, T))
    print("einsum squared dist", einsum_sqdist.reshape(T, T))
    print("dot product", dot_product.reshape(T, T))
    print("einsum dot product", einsum_dotproduct.reshape(T, T))

    # Checks run after the prints so the tables are visible on failure.
    check_close("squared dist", einsum_sqdist, squared_dist)
    check_close("squared dist 2", einsum_sqdist, squared_dist2)
    check_close("dot product", einsum_dotproduct, dot_product)