# recipes/sota/2019/rescoring/forward_lm.py
import numpy
import torch


def predict_batch(sentences, model, fairseq_dict, max_len):
    """Score a batch of tokenized sentences with a forward (left-to-right) LM.

    Returns per-sentence summed log-probabilities, the batch total, and the
    number of scored words (each sentence contributes len(sentence) + 1,
    counting the final eos prediction).
    """
    encoded_input = []
    padded_input = []
    ppls = []
    total_loss = 0.0
    nwords = 0
    for sentence in sentences:
        # Map tokens to dictionary indices.
        encoded_input.append([fairseq_dict.index(token) for token in sentence])
        assert (
            len(encoded_input[-1]) <= max_len
        ), "Error in the input length, it should be at most max_len {}".format(
            max_len
        )
        # Prepend eos as the start symbol and pad shorter sentences with eos
        # so that every row has length max_len + 1.
        if len(encoded_input[-1]) < max_len:
            padded_input.append(
                [fairseq_dict.eos()]
                + encoded_input[-1]
                + [fairseq_dict.eos()] * (max_len - len(encoded_input[-1]))
            )
        else:
            padded_input.append([fairseq_dict.eos()] + encoded_input[-1])
    x = torch.LongTensor(padded_input).cuda()
    with torch.no_grad():
        y = model.forward(x)[0]
        if model.adaptive_softmax is not None:
            logprobs = (
                model.adaptive_softmax.get_log_prob(y, None).detach().cpu().numpy()
            )
        else:
            logprobs = torch.nn.functional.log_softmax(y, 2).detach().cpu().numpy()
    for index, input_i in enumerate(encoded_input):
        # Sum the log-probabilities assigned to the reference tokens, then add
        # the log-probability of predicting eos after the last token.
        loss = numpy.sum(logprobs[index, numpy.arange(len(input_i)), input_i])
        loss += logprobs[index, len(input_i), fairseq_dict.eos()]
        ppls.append(loss)
        total_loss += loss
        nwords += len(input_i) + 1
    return ppls, total_loss, nwords
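

# Usage sketch (an assumption, not part of the recipe): rescoring a few
# tokenized hypotheses with a trained fairseq LM on GPU. `load_lm` is a
# hypothetical helper standing in for however the checkpoint and its
# dictionary are loaded (e.g. via fairseq checkpoint utilities); the path
# and hypotheses below are placeholders.
#
#   model, fairseq_dict = load_lm("/path/to/lm_checkpoint.pt")
#   model.cuda().eval()
#   hypotheses = [["the", "cat", "sat"], ["a", "dog", "barked", "loudly"]]
#   max_len = max(len(h) for h in hypotheses)
#   scores, total, nwords = predict_batch(hypotheses, model, fairseq_dict, max_len)
#   # `scores` are summed log-probabilities per hypothesis; -total / nwords
#   # gives an average per-word negative log-likelihood for the batch.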