def forward()

in src/run_paraphrase.py


    def forward(self, x1, x2, len1, len2):
        """Compute binary prediction.

        x1: Tensor of shape (B, L1)
        x2: Tensor of shape (B, L2)
        len1: Tensor of shape (B,)
        len2: Tensor of shape (B,)
        """
        B = x1.shape[0]  # batch size
        reps1 = self.roberta_model(x1, features_only=True,
                                   return_all_hiddens=True)[1]['inner_states']  # list over layers; each is (L1, B, d)
        reps2 = self.roberta_model(x2, features_only=True,
                                   return_all_hiddens=True)[1]['inner_states']  # list over layers; each is (L2, B, d)
        feats = [[] for i in range(B)]  # feats[i] will collect one similarity score per layer
        for layer in self.layers:
            h1 = reps1[layer].permute(1, 0, 2)  # B, L1, d
            h2 = reps2[layer].permute(1, 0, 2)  # B, L2, d
            h1_normed = h1 / torch.linalg.norm(h1, dim=2, keepdim=True)  # unit-normalize: B, L1, d
            h2_normed = h2 / torch.linalg.norm(h2, dim=2, keepdim=True)  # unit-normalize: B, L2, d
            dots = torch.matmul(h1_normed, h2_normed.permute(0, 2, 1))  # pairwise cosine sims: B, L1, L2
            for i in range(B):
                # Restrict to the real (unpadded) tokens of each sentence.
                cur_dots = dots[i, :len1[i], :len2[i]]
                # Greedily match each token to its most similar token in the other sentence.
                s1 = torch.mean(torch.max(cur_dots, dim=0)[0])  # best match in x1 for each x2 token
                s2 = torch.mean(torch.max(cur_dots, dim=1)[0])  # best match in x2 for each x1 token
                score = 2 * s1 * s2 / (s1 + s2)  # harmonic mean (BERTScore-style F1)
                feats[i].append(score)
        feat_mat = torch.stack([torch.stack(v) for v in feats])  # (B, num_layers)
        feat_mat = self.batchnorm(feat_mat)  # batch-normalize the per-layer features
        out = self.logit_scale * self.output(feat_mat).squeeze(-1)  # (B,) logits
        return out
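
Per layer, each pair's score is a BERTScore-style F1: every token of one sentence is matched to its most similar token (by cosine similarity) in the other sentence, the two directional averages are computed, and they are combined with a harmonic mean. Below is a minimal standalone sketch of that per-pair computation in plain PyTorch; the function name and the dummy 768-dimensional vectors are illustrative, not part of the repository.

import torch

def pairwise_f1_score(h1, h2, len1, len2):
    """BERTScore-style F1 between two token representation matrices.

    h1: Tensor of shape (L1, d); h2: Tensor of shape (L2, d)
    len1, len2: true (unpadded) lengths of the two sentences
    """
    h1 = h1[:len1] / torch.linalg.norm(h1[:len1], dim=1, keepdim=True)  # unit-normalize valid tokens
    h2 = h2[:len2] / torch.linalg.norm(h2[:len2], dim=1, keepdim=True)
    dots = h1 @ h2.T                  # (len1, len2) pairwise cosine similarities
    s1 = dots.max(dim=0)[0].mean()    # each token of sentence 2 matched to its best token in sentence 1
    s2 = dots.max(dim=1)[0].mean()    # each token of sentence 1 matched to its best token in sentence 2
    return 2 * s1 * s2 / (s1 + s2)    # harmonic mean of the two directions

score = pairwise_f1_score(torch.randn(7, 768), torch.randn(5, 768), 7, 5)

forward() stacks one such score per layer in self.layers into a (B, num_layers) feature matrix, batch-normalizes it, and maps it to a single scaled logit per pair via self.output and self.logit_scale.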