in fairseq_cli/eval_lm.py [0:0]
def main(parsed_args):
    """Score a dataset split with a language-model ensemble and log the results.

    Loads the checkpoint(s) named by ``parsed_args.path``, builds the task and
    the ``gen_subset`` dataset, scores every batch with ``SequenceScorer`` and
    logs the average loss (base 2) and perplexity.  Optionally logs per-word
    probabilities (``output_word_probs``) and aggregated per-word statistics
    (``output_word_stats``).

    Args:
        parsed_args: parsed command-line namespace.  Must provide ``path``;
            the remaining attributes read here (``cpu``, ``model_overrides``,
            ``context_window``, ``gen_subset``, ``remove_bpe``, ...) are
            copied onto the args stored in the checkpoint.

    Raises:
        AssertionError: if ``parsed_args.path`` is ``None``, if no model was
            loaded, or if ``add_bos_token`` is set and a hypothesis does not
            start with the target dictionary's BOS index.
        NotImplementedError: if ``remove_bpe`` is ``'sentencepiece'``.
    """
    assert parsed_args.path is not None, '--path required for evaluation!'

    utils.import_user_module(parsed_args)

    logger.info(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    logger.info('loading model(s) from {}'.format(parsed_args.path))
    # NOTE(security): eval() executes arbitrary Python taken from the
    # --model-overrides string.  Only pass trusted values; consider
    # ast.literal_eval if overrides are always plain literals.
    models, args = checkpoint_utils.load_model_ensemble(
        parsed_args.path.split(os.pathsep),
        arg_overrides=eval(parsed_args.model_overrides),
        task=task,
    )

    # Command-line values take precedence over the checkpoint's stored args,
    # except for these options, which keep the value stored in the checkpoint.
    keep_from_checkpoint = {
        'self_target', 'future_target', 'past_target', 'tokens_per_sample',
        'output_size_dictionary', 'add_bos_token',
    }
    for arg in vars(parsed_args):
        if arg not in keep_from_checkpoint:
            setattr(args, arg, getattr(parsed_args, arg))

    # reduce tokens per sample by the required context window size
    args.tokens_per_sample -= args.context_window
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    dataset = task.dataset(args.gen_subset)
    if args.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=args.tokens_per_sample,
            context_window=args.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    logger.info('{} {} {} examples'.format(args.data, args.gen_subset, len(dataset)))

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    assert len(models) > 0

    logger.info('num. model params: {}'.format(sum(p.numel() for p in models[0].parameters())))

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(*[
            model.max_positions() for model in models
        ]),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        default_log_format=('tqdm' if not args.no_progress_bar else 'none'),
    )

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary, args.softmax_batch)

    score_sum = 0.
    count = 0

    if args.remove_bpe is not None:
        if args.remove_bpe == 'sentencepiece':
            raise NotImplementedError
        # Dictionary indices whose symbol ends with the continuation marker,
        # i.e. sub-word pieces that are merged into the following token.
        bpe_cont = args.remove_bpe.rstrip()
        bpe_toks = {
            idx
            for idx in range(len(task.source_dictionary))
            if task.source_dictionary[idx].endswith(bpe_cont)
        }
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = {}

    wps_meter = TimeMeter()

    for sample in progress:
        if 'net_input' not in sample:
            continue

        sample = utils.move_to_cuda(sample) if use_cuda else sample

        gen_timer.start()
        hypos = scorer.generate(models, sample)
        gen_timer.stop(sample['ntokens'])

        for hypo_idx, hypos_i in enumerate(hypos):
            hypo = hypos_i[0]
            sample_id = sample['id'][hypo_idx]

            tokens = hypo['tokens']
            pos_scores = hypo['positional_scores'].float()

            if args.add_bos_token:
                assert hypo['tokens'][0].item() == task.target_dictionary.bos()
                tokens = tokens[1:]
                pos_scores = pos_scores[1:]

            # Measure the length AFTER the optional BOS strip: the merge loop
            # below writes to position ``pos + 1``, so it must stop one short
            # of the tokens that actually remain.
            tgt_len = tokens.numel()

            skipped_toks = 0
            if bpe_toks is not None:
                # Fold the score of each continuation piece into the next
                # token so that every word carries a single combined score.
                for pos in range(tgt_len - 1):
                    if tokens[pos].item() in bpe_toks:
                        skipped_toks += 1
                        pos_scores[pos + 1] += pos_scores[pos]
                        pos_scores[pos] = 0

            inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
            if inf_scores.any():
                # The message needs a placeholder: logging applies ``msg % args``
                # and fails to emit the record when there is none.
                logger.info(
                    'skipping tokens with inf scores: %s',
                    task.target_dictionary.string(tokens[inf_scores.nonzero()])
                )
                pos_scores = pos_scores[(~inf_scores).nonzero()]
            score_sum += pos_scores.sum().cpu()
            count += pos_scores.numel() - skipped_toks

            if args.output_word_probs or args.output_word_stats:
                w = ''
                word_prob = []
                is_bpe = False
                for tok_pos in range(len(tokens)):
                    w_ind = tokens[tok_pos].item()
                    w += task.source_dictionary[w_ind]
                    if bpe_toks is not None and w_ind in bpe_toks:
                        # Drop the continuation marker and keep building the word.
                        w = w[:-bpe_len]
                        is_bpe = True
                    else:
                        word_prob.append((w, pos_scores[tok_pos].item()))

                        # Find the next non-zero score (zeroed entries are
                        # merged continuation pieces) to pass along as the
                        # following word's score.
                        next_prob = None
                        ind = tok_pos + 1
                        while ind < len(tokens):
                            if pos_scores[ind].item() != 0:
                                next_prob = pos_scores[ind]
                                break
                            ind += 1

                        word_stats.setdefault(w, WordStat(w, is_bpe)).add(pos_scores[tok_pos].item(), next_prob)
                        is_bpe = False
                        w = ''
                if args.output_word_probs:
                    # NOTE(review): '{:2f}' is a width-2 spec (six decimals),
                    # not two decimals -- confirm this is intended.
                    logger.info(
                        str(int(sample_id)) + " "
                        + ('\t'.join('{} [{:2f}]'.format(x[0], x[1]) for x in word_prob))
                    )

        wps_meter.update(sample['ntokens'])
        progress.log({'wps': round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count / math.log(2)  # convert to base 2
    logger.info('Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(
        gen_timer.n, gen_timer.sum, 1. / gen_timer.avg
    ))
    logger.info('Loss (base 2): {:.4f}, Perplexity: {:.2f}'.format(
        avg_nll_loss, 2**avg_nll_loss
    ))

    if args.output_word_stats:
        # Most frequent words first.
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            logger.info(ws)