def tmk()

in videoalignment/temporal_match_kernel.py [0:0]


def tmk(ts, xs, a, ms, Ts, block_size=500):
    """Compute Temporal Match Kernel feature vectors, processing time in chunks.

    For every batch item, period ``Ts[j]`` and frequency index ``k`` this
    computes::

        fv[b, j, k,     :] = sum_l a[j, k] * sin(ms[k] / Ts[j] * xs[b, l]) * ts[b, l, :]
        fv[b, j, m + k, :] = sum_l a[j, k] * cos(ms[k] / Ts[j] * xs[b, l]) * ts[b, l, :]

    The sum over the time axis ``l`` is accumulated ``block_size`` frames at a
    time, which bounds the size of the (b_s, T, 2m, block, d) intermediate.

    Args:
        ts: tensor of shape (b_s, len_ts, d) holding per-frame feature vectors.
        xs: tensor of shape (b_s, len_ts) whose values are scaled by
            ``ms / Ts`` inside sin/cos (presumably frame timestamps — confirm
            with callers).
        a: tensor of shape (T, m) of weights applied to both the sin and the
            cos half.
        ms: tensor of shape (m,) of frequency multipliers.
        Ts: tensor of shape (T,) of periods that divide ``ms``.
        block_size: number of frames processed per chunk (default 500).

    Returns:
        Tensor of shape (b_s, T, 2m, d). The sin terms are in ``[:, :, :m]``
        and the cos terms in ``[:, :, m:]``. An empty sequence
        (``len_ts == 0``) yields zeros.

    Raises:
        ValueError: if ``block_size`` is not a positive integer.
    """
    if block_size < 1:
        raise ValueError(f"block_size must be a positive integer, got {block_size}")

    feats = ts.unsqueeze(1).unsqueeze(1)  # (b_s, 1, 1, len_ts, d)
    a_xp = torch.cat([a, a], dim=1).unsqueeze(2).unsqueeze(0)  # (1, T, 2m, 1)
    freqs = (ms.unsqueeze(0) / Ts.unsqueeze(1)).unsqueeze(0).unsqueeze(3)  # (1, T, m, 1)

    n_frames = xs.size(1)
    fv = None
    # max(n_frames, 1) forces at least one pass. For an empty sequence the
    # slices below are empty, and summing over the empty time axis produces
    # zeros of shape (b_s, T, 2m, d) instead of leaving `fv` unbound.
    for t in range(0, max(n_frames, 1), block_size):
        args = freqs * xs[:, t : t + block_size].unsqueeze(1).unsqueeze(1)  # (b_s, T, m, blk)
        sin_cos = a_xp * torch.cat(
            [torch.sin(args), torch.cos(args)], dim=2
        )  # (b_s, T, 2m, blk)
        sin_cos = sin_cos.unsqueeze(4)  # (b_s, T, 2m, blk, 1)
        this_fv = torch.sum(
            sin_cos * feats[:, :, :, t : t + block_size], dim=3
        )  # (b_s, T, 2m, d)
        if fv is None:
            fv = this_fv
        else:
            # In-place accumulation avoids allocating a fresh result per chunk.
            fv += this_fv

    return fv