in videoalignment/temporal_match_kernel.py [0:0]
def merge(self, fv_a, fv_b, offsets):
    """Score two temporal-match-kernel descriptors at a set of time offsets.

    Each descriptor is split along dim 2 into a "sin" half (the first
    ``self.m`` entries) and a "cos" half (the remaining entries). For every
    offset, the per-frequency dot products between ``fv_a`` and ``fv_b`` are
    combined with the cosine/sine of the phase ``self.ms * offset / period``,
    summed over frequencies (dim 2) and finally averaged over periods (dim 1).

    ``self.normalization`` selects optional normalization by substring:

    * contains ``"feat"``: every vector along dim 3 is divided by its norm,
      where the norm is measured after dividing by ``sqrt(self.a)``.
    * contains ``"freq"``: every vector along dim 2 is divided by
      ``sqrt(sum of squares over dim 2 / self.m)``.
    * equals ``"matrix"``: scores are divided by the product of the norms of
      each ``fv[b, t]`` slice taken over dims 2 and 3 together.

    The final scores are additionally divided by ``self.m`` when the mode is
    exactly ``"freq"`` and by 512 when it is ``"feat"`` or ``"feat_freq"``.
    Any other mode leaves the scores unnormalized.

    Args:
        fv_a: first descriptor. The shape notes below assume
            ``(b_s, n_periods, 2 * m, d)`` -- TODO confirm against callers.
        fv_b: second descriptor, same layout as ``fv_a``.
        offsets: offsets to evaluate. The broadcasting below implies a 2-D
            ``(b_s, n_offsets)`` tensor. It is moved to the module's device
            and cast to float32.

    Returns:
        Scores of shape ``(b_s, n_offsets)`` under the assumptions above.
    """
    device = get_device(self)
    # Added inside and outside every sqrt so the gradient stays finite at 0
    # and the divisors are never exactly zero.
    eps = 1e-8
    normalization = self.normalization
    m = self.m

    if "feat" in normalization:
        # self.a gets a leading and a trailing singleton dim and is repeated
        # along dim 2 so the same weight applies to the sin and cos halves.
        a_xp = self.a.unsqueeze(0).unsqueeze(-1)
        a_xp = torch.cat([a_xp, a_xp], dim=2)
        fv_a_0 = fv_a / torch.sqrt(a_xp)
        fv_b_0 = fv_b / torch.sqrt(a_xp)
        norm_a = torch.sqrt(torch.sum(fv_a_0 ** 2, dim=3, keepdim=True) + eps) + eps
        norm_b = torch.sqrt(torch.sum(fv_b_0 ** 2, dim=3, keepdim=True) + eps) + eps
        # Only the norm uses the weighted features; the raw ones are rescaled.
        fv_a = fv_a / norm_a
        fv_b = fv_b / norm_b

    if "freq" in normalization:
        norm_a = (
            torch.sqrt(torch.sum(fv_a ** 2, dim=2, keepdim=True) / m + eps) + eps
        )
        norm_b = (
            torch.sqrt(torch.sum(fv_b ** 2, dim=2, keepdim=True) / m + eps) + eps
        )
        fv_a = fv_a / norm_a
        fv_b = fv_b / norm_b
    elif normalization == "matrix":
        # (b_s, n_periods, 1); applied to the scores at the end, not here.
        norm_a = (
            torch.sqrt(
                torch.sum(torch.sum(fv_a ** 2, dim=-1, keepdim=True), dim=2) + eps
            )
            + eps
        )
        norm_b = (
            torch.sqrt(
                torch.sum(torch.sum(fv_b ** 2, dim=-1, keepdim=True), dim=2) + eps
            )
            + eps
        )

    fv_a_sin, fv_a_cos = fv_a[:, :, :m], fv_a[:, :, m:]  # each (b_s, n_periods, m, d)
    fv_b_sin, fv_b_cos = fv_b[:, :, :m], fv_b[:, :, m:]

    # Per-frequency dot products over the feature dim: (b_s, n_periods, m, 1).
    dot_sin_sin = torch.sum(fv_a_sin * fv_b_sin, dim=3, keepdim=True)
    dot_sin_cos = torch.sum(fv_a_sin * fv_b_cos, dim=3, keepdim=True)
    dot_cos_cos = torch.sum(fv_a_cos * fv_b_cos, dim=3, keepdim=True)
    dot_cos_sin = torch.sum(fv_a_cos * fv_b_sin, dim=3, keepdim=True)

    # Kept as an attribute write (not a local) to preserve the original side
    # effect of caching self.ms on the module's device.
    self.ms = self.ms.to(device)
    ms = self.ms.unsqueeze(1).unsqueeze(0).unsqueeze(0)  # (1, 1, m, 1)
    # Move offsets alongside ms/periods; previously only cast, so offsets on
    # another device made the product below fail.
    xs = offsets.to(device).float().unsqueeze(1).unsqueeze(1)  # (b_s, 1, 1, n_offsets)
    periods = torch.tensor(self.T, dtype=torch.float32, device=device)
    periods = periods.unsqueeze(0).unsqueeze(2).unsqueeze(2)  # (1, n_periods, 1, 1)

    # Phase computed once and shared by cos and sin: (b_s, n_periods, m, n_offsets).
    angles = ms * xs / periods
    cos_delta = torch.cos(angles)
    sin_delta = torch.sin(angles)

    dots = (
        dot_sin_sin * cos_delta
        + dot_sin_cos * sin_delta
        + dot_cos_cos * cos_delta
        - dot_cos_sin * sin_delta
    )  # (b_s, n_periods, m, n_offsets)
    dots = torch.sum(dots, dim=2)  # (b_s, n_periods, n_offsets)

    if normalization == "matrix":
        dots = dots / (norm_a * norm_b)
    elif normalization == "freq":
        dots = dots / m
    elif normalization in ("feat", "feat_freq"):
        # NOTE(review): hard-coded scale; presumably the feature dim d -- confirm.
        dots = dots / 512
    return torch.mean(dots, dim=1)  # average over periods -> (b_s, n_offsets)