in src/run_paraphrase.py [0:0]
def forward(self, x1, x2, len1, len2):
    """Compute binary prediction.

    For every layer in ``self.layers`` the token representations of both
    inputs are L2-normalized and compared with a dot product (cosine
    similarity).  Per example, the best match of every token is averaged
    in both directions and the two averages are combined with a harmonic
    mean, giving one feature per layer.  The features are passed through
    ``self.batchnorm`` and ``self.output`` and scaled by
    ``self.logit_scale``.

    x1: Tensor of shape (B, L1)
    x2: Tensor of shape (B, L2)
    len1: Tensor of shape (B,) -- number of leading positions of x1 to use
    len2: Tensor of shape (B,) -- number of leading positions of x2 to use

    Returns the scaled output of ``self.output`` with its last dimension
    squeezed (presumably shape (B,) -- depends on ``self.output``, which
    is defined outside this block).
    """
    B = x1.shape[0]
    # NOTE(review): assumes the model returns a tuple whose second element
    # is a dict with an 'inner_states' sequence of (L, B, d) tensors,
    # as the indexing and permute below require -- confirm against the
    # model class.
    reps1 = self.roberta_model(
        x1, features_only=True, return_all_hiddens=True
    )[1]['inner_states']  # Layer -> L1, B, d
    reps2 = self.roberta_model(
        x2, features_only=True, return_all_hiddens=True
    )[1]['inner_states']  # Layer -> L2, B, d

    # The lengths do not depend on the layer, so read them out once
    # instead of indexing the length tensors inside the layer loop.
    lengths = [(int(len1[i]), int(len2[i])) for i in range(B)]

    per_layer = []
    for layer in self.layers:
        h1 = reps1[layer].permute(1, 0, 2)  # B, L1, d
        h2 = reps2[layer].permute(1, 0, 2)  # B, L2, d
        # normalize() clamps the norm by a small epsilon, so an all-zero
        # hidden vector becomes a zero vector instead of NaN (plain
        # division by the norm would give 0 / 0).
        h1_normed = torch.nn.functional.normalize(h1, dim=2)  # B, L1, d
        h2_normed = torch.nn.functional.normalize(h2, dim=2)  # B, L2, d
        dots = torch.matmul(h1_normed, h2_normed.permute(0, 2, 1))  # B, L1, L2

        scores = []
        for i, (n1, n2) in enumerate(lengths):
            # Only the first n1 / n2 positions are scored.
            cur_dots = dots[i, :n1, :n2]
            s1 = torch.mean(torch.max(cur_dots, dim=0)[0])  # best x1 match per x2 token
            s2 = torch.mean(torch.max(cur_dots, dim=1)[0])  # best x2 match per x1 token
            # Harmonic mean of s1 and s2.  When s1 + s2 is exactly zero the
            # plain formula divides by zero; use a denominator of 1 in that
            # case so the score stays finite (0 when both means are 0).
            denom = s1 + s2
            safe_denom = torch.where(denom == 0, torch.ones_like(denom), denom)
            scores.append(2 * s1 * s2 / safe_denom)
        per_layer.append(torch.stack(scores))  # B

    feat_mat = torch.stack(per_layer, dim=1)  # B, Layers
    feat_mat = self.batchnorm(feat_mat)
    out = self.logit_scale * self.output(feat_mat).squeeze(-1)
    return out