in custom/evaluate_utils.py [0:0]
from tqdm import tqdm  # progress bar over the evaluation iterator

# `Metrics` and `batch_input_sequence_by_prefix_length` are defined elsewhere
# in this repo; illustrative sketches of both appear after the function.

def generate_completions(model, generator, fairseq_generator, itr, eval_prefix_length,
                         eval_completion_length, topk, topp, num_samples, beam_size,
                         include_prefix=True):
    completions = []
    completion_metrics = Metrics()
    actual_metrics = Metrics()

    for n, sample in enumerate(tqdm(itr)):
        input_sequence = sample['net_input']['src_tokens']
        # Skip batches whose sequences are too short to yield a full prefix.
        if input_sequence.size(1) < eval_prefix_length:
            continue
        # Re-batch the sequences into fixed-length conditioning prefixes.
        prefix_batch = batch_input_sequence_by_prefix_length(input_sequence, eval_prefix_length)
        prefix_batch = prefix_batch.cuda()

        if beam_size > 1:
            # Beam search is deterministic, so sampling must be disabled.
            assert topk == 1, 'with beam search, topk must be 1'
            assert topp == 0.0, 'with beam search, topp must be 0'
            sample['net_input']['src_tokens'] = prefix_batch
            res = fairseq_generator.generate([model], sample, prefix_batch, bos_token=0)
            # The generator's predictions include the prefix, so strip it
            # (along with the trailing EOS token).
            pred_completion = [res[i][0]['tokens'][eval_prefix_length:-1].cpu().tolist()
                               for i in range(len(res))]
        elif beam_size == 1:
            # Single-beam decoding: greedy, top-k, or top-p (nucleus) sampling.
            pred_completion = generator.generate_completion(
                model, prefix_batch, eval_completion_length, topk, topp)
            pred_completion = pred_completion.cpu().tolist()

        completion_metrics.update(pred_completion)
        actual_metrics.update(input_sequence)

        if include_prefix:
            # Prepend each conditioning prefix to its generated completion.
            prefix_batch = prefix_batch.cpu().tolist()
            pred_completion = [prefix + completion
                               for prefix, completion in zip(prefix_batch, pred_completion)]
        completions.extend(pred_completion)

        # n is zero-based, so this processes num_samples + 1 batches.
        if n == num_samples:
            break

    completion_metrics = completion_metrics.report('generated')
    actual_metrics = actual_metrics.report('actual')
    return completions, completion_metrics, actual_metrics
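
`batch_input_sequence_by_prefix_length` is not shown in this excerpt. Based on how the call site uses its output (a batch of `eval_prefix_length`-token prefixes drawn from a `(batch, seq_len)` tensor), a minimal sketch could look like the following; the reshape-and-truncate behavior is an assumption, not necessarily the repo's exact implementation.

import torch

def batch_input_sequence_by_prefix_length(input_sequence, prefix_length):
    # Assumed behavior: split a (batch, seq_len) LongTensor into consecutive
    # non-overlapping chunks of `prefix_length` tokens, flattened into a new
    # batch of shape (batch * num_chunks, prefix_length). Trailing tokens
    # that do not fill a whole chunk are dropped.
    seq_len = input_sequence.size(1)
    usable_len = (seq_len // prefix_length) * prefix_length
    return input_sequence[:, :usable_len].reshape(-1, prefix_length).contiguous()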
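
`generator.generate_completion` (the `beam_size == 1` path) is likewise outside this excerpt. Below is a sketch of a single decoding step consistent with the call site's conventions, where `topk == 1` with `topp == 0.0` means greedy decoding; it is a generic top-k / top-p (nucleus) filter, not the repo's code, and `sample_next_token` is a hypothetical helper name.

import torch
import torch.nn.functional as F

def sample_next_token(logits, topk=1, topp=0.0):
    # logits: (batch, vocab_size) scores for the next token.
    # Convention matched to the call site: topk == 1 and topp == 0.0 is greedy.
    if topk == 1 and topp == 0.0:
        return logits.argmax(dim=-1)
    if topp > 0.0:
        # Nucleus (top-p) filtering: keep the smallest set of top tokens whose
        # cumulative probability exceeds topp; mask out everything else.
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        remove = cum_probs > topp
        remove[..., 1:] = remove[..., :-1].clone()  # shift right so the first token over the threshold is kept
        remove[..., 0] = False
        sorted_logits = sorted_logits.masked_fill(remove, float('-inf'))
        logits = logits.scatter(-1, sorted_idx, sorted_logits)
    else:
        # Top-k filtering: keep only the k highest-scoring tokens.
        kth_best = torch.topk(logits, topk, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_best, float('-inf'))
    return torch.multinomial(F.softmax(logits, dim=-1), num_samples=1).squeeze(-1)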
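
Finally, `Metrics` accumulates statistics over token sequences and `report(prefix)` turns them into named values; the actual statistics it computes are not visible here. A minimal stand-in with the same interface, computing an illustrative distinct-n ratio (the repo's real metrics may differ), is sketched below.

from collections import Counter

class Metrics:
    # Illustrative stand-in: tracks n-gram counts over everything passed to
    # update(), which at the call site receives both tensors (rows of a 2D
    # LongTensor) and plain lists of token ids.
    def __init__(self, n=1):
        self.n = n
        self.ngram_counts = Counter()
        self.total_ngrams = 0

    def update(self, sequences):
        for seq in sequences:
            tokens = seq.tolist() if hasattr(seq, 'tolist') else list(seq)
            for i in range(len(tokens) - self.n + 1):
                self.ngram_counts[tuple(tokens[i:i + self.n])] += 1
                self.total_ngrams += 1

    def report(self, prefix):
        # Fraction of n-grams that are unique -- a crude repetition measure.
        distinct = len(self.ngram_counts) / max(self.total_ngrams, 1)
        return {'%s/distinct-%d' % (prefix, self.n): distinct}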