in videoalignment/temporal_match_kernel.py [0:0]
def tmk(ts, xs, a, ms, Ts, block_size=500):
    """Accumulate a sin/cos-weighted sum of ``ts`` along its time axis.

    For every batch element, every entry of ``Ts`` and every entry of ``ms``
    the function computes ``sum_t a * sin(ms / Ts * xs[t]) * ts[t]`` and the
    matching cosine term, stacking the sine results before the cosine results
    along dimension 2 of the output.

    The time axis is processed in slices of ``block_size`` steps so that the
    intermediate ``(b_s, T, 2m, block, d)`` product never spans the whole
    sequence at once.

    Args:
        ts: tensor of shape ``(b_s, len_ts, d)`` (presumably per-frame
            descriptors -- confirm against the caller).
        xs: tensor of shape ``(b_s, len_ts)`` (presumably the timestamp of
            each frame -- confirm against the caller).
        a: tensor of shape ``(T, m)``; the same weights are applied to the
            sine and to the cosine half.
        ms: 1-D tensor of shape ``(m,)``; divided element-wise by ``Ts`` to
            form the multipliers applied to ``xs``.
        Ts: 1-D tensor of shape ``(T,)`` (presumably periods -- confirm).
        block_size: number of time steps handled per iteration. Defaults to
            500, the value that used to be hard-coded.

    Returns:
        Tensor of shape ``(b_s, T, 2m, d)``. If the time axis is empty the
        result is all zeros (the empty sum) rather than an error.

    Raises:
        ValueError: if ``block_size`` is smaller than 1.
    """
    if block_size < 1:
        raise ValueError("block_size must be a positive integer")

    ts = ts.unsqueeze(1).unsqueeze(1)  # (b_s, 1, 1, len_ts, d)
    a_xp = torch.cat([a, a], dim=1).unsqueeze(2).unsqueeze(0)  # (1, T, 2m, 1)
    ms = (ms.unsqueeze(0) / Ts.unsqueeze(1)).unsqueeze(0).unsqueeze(3)  # (1, T, m, 1)

    len_ts = xs.size(1)
    fv = None
    # max(len_ts, 1) forces one pass even when the time axis is empty: the
    # zero-length slice then sums to a correctly shaped zero tensor instead of
    # leaving ``fv`` unbound.
    for t in range(0, max(len_ts, 1), block_size):
        block = slice(t, t + block_size)
        args = ms * xs[:, block].unsqueeze(1).unsqueeze(1)  # (b_s, T, m, blk)
        sin_cos = a_xp * torch.cat(
            [torch.sin(args), torch.cos(args)], dim=2
        )  # (b_s, T, 2m, blk)
        sin_cos = sin_cos.unsqueeze(4)  # (b_s, T, 2m, blk, 1)
        this_fv = torch.sum(sin_cos * ts[:, :, :, block], dim=3)  # (b_s, T, 2m, d)
        if fv is None:
            fv = this_fv
        else:
            fv += this_fv
    return fv