in align/train.py [0:0]
def extract_scores(batch, criss_features, aligner, info, configs):
    """Score (source, target) word pairs with ``aligner``, one sub-batch at a time.

    For every pair ``(src, trg)`` in ``batch`` a 5-column count feature row is
    built in the order ``[matched_coocc, semi_matched_coocc, coocc, freq_src,
    freq_trg]``. When ``configs.use_criss`` is true, the corresponding entries
    of ``criss_features`` are concatenated in front of those columns. Features
    are fed to ``aligner`` in chunks of ``configs.batch_size`` pairs and the
    per-chunk scores are concatenated.

    Args:
        batch: sequence of ``(src_word, trg_word)`` pairs.
        criss_features: sequence of tensors aligned index-by-index with
            ``batch``; only read when ``configs.use_criss`` is true
            (presumably one ``(1, d)`` row per pair -- TODO confirm).
        aligner: callable applied to a ``(n, num_features)`` float tensor;
            its output is squeezed on the last dim (assumes an ``(n, 1)``
            output -- TODO confirm).
        info: tuple ``(coocc, semi_matched_coocc, matched_coocc, freq_src,
            freq_trg)``. The first three are indexed as ``table[src][trg]``;
            the frequency tables as ``freq_src[src]`` and ``freq_trg[trg]``.
        configs: object providing ``batch_size``, ``device`` and ``use_criss``.

    Returns:
        Tensor of the concatenated ``aligner`` scores in ``batch`` order. An
        empty ``batch`` yields an empty 1-D tensor on ``configs.device``.
    """
    coocc, semi_matched_coocc, matched_coocc, freq_src, freq_trg = info
    if len(batch) == 0:
        # torch.cat([]) raises on an empty list; return an empty result instead.
        return torch.empty(0, device=configs.device)

    all_scores = []
    for start in range(0, len(batch), configs.batch_size):
        end = start + configs.batch_size
        subbatch = batch[start:end]
        # Build the float32 tensor directly on the target device instead of
        # creating a CPU tensor and converting it afterwards.
        features = torch.tensor(
            [
                [
                    matched_coocc[src][trg],
                    semi_matched_coocc[src][trg],
                    coocc[src][trg],
                    freq_src[src],
                    freq_trg[trg],
                ]
                for src, trg in subbatch
            ],
            dtype=torch.float,
            device=configs.device,
        ).reshape(-1, 5)
        if configs.use_criss:
            subbatch_crissfeat = torch.cat(criss_features[start:end], dim=0)
            # Precomputed CRISS features go first; detach so no gradient flows
            # back into whatever produced them.
            features = torch.cat((subbatch_crissfeat, features), dim=-1).detach()
        scores = aligner(features).squeeze(-1)
        all_scores.append(scores)
    return torch.cat(all_scores, dim=0)