def merge()

in videoalignment/temporal_match_kernel.py [0:0]


    def merge(self, fv_a, fv_b, offsets):
        """Score two Fourier-domain feature tensors at a set of temporal offsets.

        Both inputs are split along dim 2 at ``self.m``. The first ``self.m``
        slices are used as the "sin" components and the remainder as the
        "cos" components. The shape comments in this method suggest inputs of
        shape ``(b_s, T, 2m, d)`` (TODO: confirm against callers).

        For every offset ``x`` the (unnormalized) score is::

            sum_k [ (s_a.s_b) * cos(phi) + (s_a.c_b) * sin(phi)
                    + (c_a.c_b) * cos(phi) - (c_a.s_b) * sin(phi) ]
            with phi = ms_k * x / T

        Here ``.`` is the dot product over dim 3, ``ms_k`` runs over
        ``self.ms`` along dim 2, and ``T`` comes from ``self.T``, broadcast
        along dim 1. The result is then averaged over dim 1.

        ``self.normalization`` controls optional normalization. The steps are
        applied in this order:

        * Contains ``"feat"``: divide the inputs by the L2 norm over dim 3 of
          ``fv / sqrt(a)``, where ``a`` is ``self.a`` repeated for the sin and
          cos halves. The exact shape of ``self.a`` is not visible here.
        * Contains ``"freq"``: divide the inputs by
          ``sqrt(sum over dim 2 of fv**2 / self.m)``.
        * Equals ``"matrix"``: divide the summed scores by the product of both
          inputs' root-sum-of-squares over dims 2 and 3.
        * Scores are additionally divided by ``self.m`` when the mode is
          ``"freq"``, and by the constant 512 when it is ``"feat"`` or
          ``"feat_freq"``. The 512 is presumably the feature dimensionality
          (TODO: confirm).

        All norms add ``eps`` both inside and outside the square root to avoid
        division by zero.

        Side effect: ``self.ms`` is moved to this module's device and stored
        back on ``self``.

        Args:
            fv_a: First feature tensor.
            fv_b: Second feature tensor, laid out like ``fv_a``.
            offsets: Temporal offsets to score. Presumably ``(b_s, delta)``,
                judging by the broadcasting below (TODO: confirm).

        Returns:
            Scores averaged over dim 1. Under the shape assumptions above this
            is ``(b_s, delta)``.
        """
        device = get_device(self)
        eps = 1e-8
        mode = self.normalization

        if "feat" in mode:
            # Same weights for the sin half and the cos half of dim 2.
            a_half = self.a.unsqueeze(0).unsqueeze(-1)
            sqrt_a = torch.sqrt(torch.cat([a_half, a_half], dim=2))

            def feat_norm(fv):
                # Norm over dim 3 of the a-rescaled features; the unscaled
                # features are what get divided by it.
                rescaled = fv / sqrt_a
                return torch.sqrt((rescaled ** 2).sum(dim=3, keepdim=True) + eps) + eps

            fv_a = fv_a / feat_norm(fv_a)
            fv_b = fv_b / feat_norm(fv_b)

        if "freq" in mode:

            def freq_norm(fv):
                return torch.sqrt((fv ** 2).sum(dim=2, keepdim=True) / self.m + eps) + eps

            fv_a = fv_a / freq_norm(fv_a)
            fv_b = fv_b / freq_norm(fv_b)
        elif mode == "matrix":

            def matrix_norm(fv):
                # Two-step reduction (dim -1, then dim 2), as the scores are
                # sensitive to summation order at the last bit.
                per_slice = (fv ** 2).sum(dim=-1, keepdim=True)
                return torch.sqrt(per_slice.sum(dim=2) + eps) + eps

            norm_a = matrix_norm(fv_a)  # (b_s, T, 1)
            norm_b = matrix_norm(fv_b)  # (b_s, T, 1)

        m = self.m
        a_sin, a_cos = fv_a[:, :, :m], fv_a[:, :, m:]  # each (b_s, T, m, d)
        b_sin, b_cos = fv_b[:, :, :m], fv_b[:, :, m:]  # each (b_s, T, m, d)

        self.ms = self.ms.to(device)

        def dot(u, v):
            # Inner product over the feature axis, kept for broadcasting.
            return (u * v).sum(dim=3, keepdim=True)  # (b_s, T, m, 1)

        ss = dot(a_sin, b_sin)
        sc = dot(a_sin, b_cos)
        cc = dot(a_cos, b_cos)
        cs = dot(a_cos, b_sin)

        periods = torch.tensor(self.T, dtype=torch.float32, requires_grad=False).to(device)
        xs = offsets.float()
        # Broadcast ms along dim 2, offsets along dim 3 and periods along dim 1.
        phase = self.ms[None, None, :, None] * xs[:, None, None] / periods[None, :, None, None]
        cos_delta = torch.cos(phase)  # (b_s, T, m, delta)
        sin_delta = torch.sin(phase)  # (b_s, T, m, delta)

        # Keep the original left-to-right summation order.
        dots = ss * cos_delta + sc * sin_delta + cc * cos_delta - cs * sin_delta
        dots = dots.sum(dim=2)  # (b_s, T, delta)

        if mode == "matrix":
            dots = dots / (norm_a * norm_b)
        elif mode == "freq":
            dots = dots / m
        elif mode in ("feat", "feat_freq"):
            dots = dots / 512
        return dots.mean(dim=1)