in eval_lm.py
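# Imports this excerpt relies on, matching the upstream fairseq module of the
# same vintage (WordStat is a small helper class defined earlier in this file):
import numpy as np
import torch

from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
from fairseq.data import LMContextWindowDataset
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_scorer import SequenceScorer
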
def main(parsed_args):
    assert parsed_args.path is not None, '--path required for evaluation!'

    utils.import_user_module(parsed_args)

    print(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    print('| loading model(s) from {}'.format(parsed_args.path))
    models, args = checkpoint_utils.load_model_ensemble(
        parsed_args.path.split(':'),
        arg_overrides=eval(parsed_args.model_overrides),
        task=task,
    )
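
    # Copy most command-line args onto the checkpoint args; the excluded keys
    # must keep the values the model was trained with.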
    for arg in vars(parsed_args).keys():
        if arg not in {
            'self_target', 'future_target', 'past_target', 'tokens_per_sample',
            'output_size_dictionary', 'add_bos_token',
        }:
            setattr(args, arg, getattr(parsed_args, arg))

    # reduce tokens per sample by the required context window size
    args.tokens_per_sample -= args.context_window
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    dataset = task.dataset(args.gen_subset)
    if args.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=args.tokens_per_sample,
            context_window=args.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset)))

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    assert len(models) > 0

    print('num. model params: {}'.format(sum(p.numel() for p in models[0].parameters())))

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(*[
            model.max_positions() for model in models
        ]),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary, args.softmax_batch)

    score_sum = 0.
    count = 0
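
    # Pre-compute the set of BPE continuation tokens so that sub-word scores
    # can be folded into whole-word scores below.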
    if args.remove_bpe is not None:
        if args.remove_bpe == 'sentencepiece':
            raise NotImplementedError
        else:
            bpe_cont = args.remove_bpe.rstrip()
            bpe_toks = set(
                i
                for i in range(len(task.source_dictionary))
                if task.source_dictionary[i].endswith(bpe_cont)
            )
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            if 'net_input' not in sample:
                continue

            sample = utils.move_to_cuda(sample) if use_cuda else sample

            gen_timer.start()
            hypos = scorer.generate(models, sample)
            gen_timer.stop(sample['ntokens'])

            for hypos_i in hypos:
                hypo = hypos_i[0]

                tokens = hypo['tokens']
                tgt_len = tokens.numel()
                pos_scores = hypo['positional_scores'].float()

                if args.add_bos_token:
                    assert hypo['tokens'][0].item() == task.target_dictionary.bos()
                    tokens = tokens[1:]
                    pos_scores = pos_scores[1:]

                skipped_toks = 0
                if bpe_toks is not None:
                    for i in range(tgt_len - 1):
                        if tokens[i].item() in bpe_toks:
                            skipped_toks += 1
                            pos_scores[i + 1] += pos_scores[i]
                            pos_scores[i] = 0

                inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
                if inf_scores.any():
                    print('| Skipping tokens with inf scores:',
                          task.target_dictionary.string(tokens[inf_scores.nonzero()]))
                    pos_scores = pos_scores[(~inf_scores).nonzero()]
                score_sum += pos_scores.sum().cpu()
                count += pos_scores.numel() - skipped_toks
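
                # Optionally reassemble BPE pieces into whole words and record
                # per-word probabilities and statistics.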
                if args.output_word_probs or args.output_word_stats:
                    w = ''
                    word_prob = []
                    is_bpe = False
                    for i in range(len(tokens)):
                        w_ind = tokens[i].item()
                        w += task.source_dictionary[w_ind]
                        if bpe_toks is not None and w_ind in bpe_toks:
                            w = w[:-bpe_len]
                            is_bpe = True
                        else:
                            word_prob.append((w, pos_scores[i].item()))

                            next_prob = None
                            ind = i + 1
                            while ind < len(tokens):
                                if pos_scores[ind].item() != 0:
                                    next_prob = pos_scores[ind]
                                    break
                                ind += 1

                            word_stats.setdefault(w, WordStat(w, is_bpe)).add(pos_scores[i].item(), next_prob)
                            is_bpe = False
                            w = ''
                    if args.output_word_probs:
                        print('\t'.join('{} [{:.2f}]'.format(x[0], x[1]) for x in word_prob))
            wps_meter.update(sample['ntokens'])
            t.log({'wps': round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, np.exp(avg_nll_loss)))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            print(ws)
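
# The module's usual command-line entry point (as in upstream fairseq), shown
# here so the excerpt is runnable end to end:
def cli_main():
    parser = options.get_eval_lm_parser()
    args = options.parse_args_and_arch(parser)
    main(args)


if __name__ == '__main__':
    cli_main()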