# videoalignment/circulant_temporal_encoding.py
def merge(self, fv_a, fv_b, offsets, max_len):
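    """Score fv_a against fv_b for every candidate temporal offset.

    fv_a, fv_b: Fourier-domain features of shape [batch, feat_dim, 2 * self.m],
        with the real parts of the m coefficients followed by the imaginary parts.
    offsets: candidate time-shifts; max_len is re-derived as offsets.shape[-1] // 2,
        so the max_len argument is effectively ignored.
    Returns a tensor of shape [batch, offsets.shape[-1]] with one score per offset.
    """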
max_len = offsets.shape[-1] // 2
device = fv_a.device
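    # The packed features store the real parts of the m Fourier coefficients
    # first, followed by the imaginary parts.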
ts_ar = fv_a[:, :, : self.m]
ts_ai = fv_a[:, :, self.m :]
ts_br = fv_b[:, :, : self.m]
ts_bi = fv_b[:, :, self.m :]
length_power_2 = 2 ** int(np.ceil(np.log2(max_len)))
Rrs = []
    # s(x, y) contains the scores for y shifted from 0 to max_len - 1
    # Compute s(a, b) and s(b, a), then concatenate flip(s(a, b)) with s(b, a)
    # to cover all possible time-shifts
    for ts_ar, ts_ai, ts_br, ts_bi in [
        (ts_ar, ts_ai, ts_br, ts_bi),
        (ts_br, ts_bi, ts_ar, ts_ai),
    ]:
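        # First pass computes s(a, b); second pass swaps the roles to get s(b, a)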
Qir, Qii = ts_ar, ts_ai
Bir, Bii = ts_br, ts_bi
# See equation (10) in the paper
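        # Per-frequency regularized correlation, summed over the feature dimension:
        #   S = sum_d conj(Q_d) * B_d / (sum_d |Q_d|^2 + lambda)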
Qdenr = torch.sum(Qir * Qir + Qii * Qii, dim=1, keepdim=True) + self.lmbda
Sr = torch.sum((Qir * Bir + Qii * Bii) / Qdenr, dim=1, keepdim=True)
Si = torch.sum((-Qii * Bir + Qir * Bii) / Qdenr, dim=1, keepdim=True)
        # Zero-pad the scores up to the next power of two (length_power_2 >= max_len)
s0, s1, s2 = Sr.size()
padding = length_power_2 - s2
if padding > 0:
zero_pad = torch.zeros(
s0, s1, padding, dtype=torch.float32, device=device
)
Sr = torch.cat([Sr, zero_pad], dim=2)
Si = torch.cat([Si, zero_pad], dim=2)
        # Legacy complex layout for torch.ifft (PyTorch < 1.8): real and imaginary
        # parts are stacked in a trailing dimension of size 2. Newer PyTorch would
        # use torch.complex(Sr, Si) and torch.fft.ifft instead. The inverse FFT
        # turns the per-frequency scores into per-shift scores.
Sir = torch.stack((Sr, Si), dim=-1)
R = torch.ifft(Sir, 1)
        Rr = R[..., 0]  # keep only the real part of the inverse transform
        # Shape [b, 1, length_power_2]: drop the singleton dim and keep only the
        # first max_len entries -> shape [b, max_len]
        Rr = Rr.squeeze(1)[:, :max_len]
Rrs.append(Rr)
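    # Flip s(a, b) so that, once concatenated with s(b, a), there is one score
    # per entry of offsets (output shape [b, 2 * max_len])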
Rrs[0] = Rrs[0].flip(dims=[1])
outs = torch.cat(Rrs, dim=1)
return outs